phase3: HTTP server with converter, one-shot preview store, and middleware

This commit is contained in:
Sergey Filkin
2026-04-18 11:55:42 +03:00
parent d1682813ff
commit 843d8dc710
8 changed files with 1033 additions and 4 deletions
+94
View File
@@ -0,0 +1,94 @@
package server
import (
"fmt"
"os"
"strconv"
"strings"
"time"
)
const (
defaultAddr = ":8080"
defaultMaxMarkdownBytes = int64(1_048_576)
defaultMaxRequestBytes = int64(1_200_000)
defaultPreviewTTL = time.Hour
defaultShutdownTimeout = 10 * time.Second
)
type Config struct {
Addr string
MaxMarkdownBytes int64
MaxRequestBytes int64
PreviewTTL time.Duration
ShutdownTimeout time.Duration
}
func LoadConfig() (Config, error) {
maxMarkdownBytes, err := loadPositiveInt64("MAX_MARKDOWN_BYTES", defaultMaxMarkdownBytes)
if err != nil {
return Config{}, err
}
maxRequestBytes, err := loadPositiveInt64("MAX_REQUEST_BYTES", defaultMaxRequestBytes)
if err != nil {
return Config{}, err
}
previewTTL, err := loadDuration("PREVIEW_TTL", defaultPreviewTTL)
if err != nil {
return Config{}, err
}
shutdownTimeout, err := loadDuration("SHUTDOWN_TIMEOUT", defaultShutdownTimeout)
if err != nil {
return Config{}, err
}
addr := strings.TrimSpace(os.Getenv("ADDR"))
if addr == "" {
addr = defaultAddr
}
return Config{
Addr: addr,
MaxMarkdownBytes: maxMarkdownBytes,
MaxRequestBytes: maxRequestBytes,
PreviewTTL: previewTTL,
ShutdownTimeout: shutdownTimeout,
}, nil
}
func loadPositiveInt64(name string, fallback int64) (int64, error) {
raw := strings.TrimSpace(os.Getenv(name))
if raw == "" {
return fallback, nil
}
value, err := strconv.ParseInt(raw, 10, 64)
if err != nil {
return 0, fmt.Errorf("%s must be an integer: %w", name, err)
}
if value <= 0 {
return 0, fmt.Errorf("%s must be positive", name)
}
return value, nil
}
func loadDuration(name string, fallback time.Duration) (time.Duration, error) {
raw := strings.TrimSpace(os.Getenv(name))
if raw == "" {
return fallback, nil
}
value, err := time.ParseDuration(raw)
if err != nil {
return 0, fmt.Errorf("%s must be a valid duration: %w", name, err)
}
if value <= 0 {
return 0, fmt.Errorf("%s must be positive", name)
}
return value, nil
}
+248
View File
@@ -0,0 +1,248 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"mime"
"net/http"
"strings"
"github.com/fserg/md-to-html/internal/converter"
"github.com/fserg/md-to-html/internal/version"
"github.com/go-chi/chi/v5"
)
const defaultDocumentTitle = "Document"
type Server struct {
cfg Config
conv *converter.Converter
store *PreviewStore
log *slog.Logger
}
type convertRequest struct {
Markdown string `json:"markdown"`
Title string `json:"title,omitempty"`
}
func (s *Server) handleConvert(w http.ResponseWriter, r *http.Request) {
if !hasJSONContentType(r.Header.Get("Content-Type")) {
writeJSON(w, http.StatusUnsupportedMediaType, map[string]string{
"detail": "content-type must be application/json",
})
return
}
var payload convertRequest
if err := decodeJSON(r, &payload); err != nil {
s.writeDecodeError(w, err)
return
}
result, err := s.convertMarkdown(payload.Markdown, payload.Title)
if err != nil {
s.writeConvertError(w, err)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(result.HTML)
}
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
func (s *Server) handleVersion(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, map[string]string{"version": version.Version})
}
func (s *Server) handleReady(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{
"status": "ok",
"template_loaded": s.conv != nil,
})
}
func (s *Server) handleHome(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("UI coming in phase 4"))
}
func (s *Server) handleUIConvert(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
s.writeDecodeError(w, err)
return
}
result, err := s.convertMarkdown(r.Form.Get("markdown"), r.Form.Get("title"))
if err != nil {
s.writeConvertError(w, err)
return
}
filename := htmlFilename(result.Title)
previewID := s.store.Put(result.HTML, "text/html; charset=utf-8", filename)
downloadID := s.store.Put(result.HTML, "text/html; charset=utf-8", filename)
fragment := fmt.Sprintf(
`<div><p>Result ready</p><a href="/preview/%s" target="_blank" rel="noopener">Preview</a> <a href="/download/%s">Download</a></div>`,
previewID,
downloadID,
)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(fragment))
}
func (s *Server) handlePreview(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
item, ok := s.store.Take(id)
if !ok {
http.NotFound(w, r)
return
}
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Content-Type", contentTypeOrDefault(item.mime))
w.WriteHeader(http.StatusOK)
_, _ = w.Write(item.html)
}
func (s *Server) handleDownload(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
item, ok := s.store.Take(id)
if !ok {
http.NotFound(w, r)
return
}
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Content-Type", contentTypeOrDefault(item.mime))
w.Header().Set("Content-Disposition", mime.FormatMediaType("attachment", map[string]string{
"filename": item.filename,
}))
w.WriteHeader(http.StatusOK)
_, _ = w.Write(item.html)
}
func (s *Server) convertMarkdown(markdown, title string) (converter.Result, error) {
if strings.TrimSpace(markdown) == "" {
return converter.Result{}, errEmptyMarkdown
}
if int64(len([]byte(markdown))) > s.cfg.MaxMarkdownBytes {
return converter.Result{}, errMarkdownTooLarge{limit: s.cfg.MaxMarkdownBytes}
}
fallbackTitle := strings.TrimSpace(title)
if fallbackTitle == "" {
fallbackTitle = defaultDocumentTitle
}
result, err := s.conv.Convert([]byte(markdown), fallbackTitle)
if err != nil {
return converter.Result{}, fmt.Errorf("convert markdown: %w", err)
}
return result, nil
}
func (s *Server) writeDecodeError(w http.ResponseWriter, err error) {
var maxBytesErr *http.MaxBytesError
if errors.As(err, &maxBytesErr) {
writeJSON(w, http.StatusRequestEntityTooLarge, map[string]string{
"detail": fmt.Sprintf("request exceeds %d bytes", s.cfg.MaxRequestBytes),
})
return
}
writeJSON(w, http.StatusBadRequest, map[string]string{"detail": "invalid request payload"})
}
func (s *Server) writeConvertError(w http.ResponseWriter, err error) {
var markdownTooLarge errMarkdownTooLarge
switch {
case errors.Is(err, errEmptyMarkdown):
writeJSON(w, http.StatusBadRequest, map[string]string{"detail": err.Error()})
case errors.As(err, &markdownTooLarge):
writeJSON(w, http.StatusRequestEntityTooLarge, map[string]string{
"detail": markdownTooLarge.Error(),
})
default:
s.log.Error("convert_failed", "error", err)
writeJSON(w, http.StatusBadGateway, map[string]string{"detail": err.Error()})
}
}
func hasJSONContentType(value string) bool {
mediaType, _, err := mime.ParseMediaType(value)
return err == nil && mediaType == "application/json"
}
func decodeJSON(r *http.Request, dst any) error {
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err := dec.Decode(dst); err != nil {
return err
}
var extra json.RawMessage
if err := dec.Decode(&extra); err != nil && !errors.Is(err, io.EOF) {
return err
}
if len(extra) > 0 {
return errors.New("unexpected trailing JSON data")
}
return nil
}
func writeJSON(w http.ResponseWriter, status int, payload any) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(status)
enc := json.NewEncoder(w)
enc.SetEscapeHTML(false)
_ = enc.Encode(payload)
}
func htmlFilename(title string) string {
name := strings.TrimSpace(title)
if name == "" {
name = "document"
}
replacer := strings.NewReplacer("/", "-", "\\", "-", "\"", "", "\n", " ", "\r", " ")
name = strings.TrimSpace(replacer.Replace(name))
if name == "" {
name = "document"
}
return name + ".html"
}
func contentTypeOrDefault(value string) string {
if strings.TrimSpace(value) == "" {
return "text/html; charset=utf-8"
}
return value
}
var errEmptyMarkdown = errors.New("markdown must not be empty")
type errMarkdownTooLarge struct {
limit int64
}
func (e errMarkdownTooLarge) Error() string {
return fmt.Sprintf("markdown exceeds %d bytes", e.limit)
}
+59
View File
@@ -0,0 +1,59 @@
package server
import (
"log/slog"
"net/http"
"time"
chimiddleware "github.com/go-chi/chi/v5/middleware"
)
func MaxBytesMiddleware(limit int64) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Body != nil {
r.Body = http.MaxBytesReader(w, r.Body, limit)
}
next.ServeHTTP(w, r)
})
}
}
func CORSMiddleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers := w.Header()
headers.Set("Access-Control-Allow-Origin", "*")
headers.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
headers.Set("Access-Control-Allow-Headers", "content-type")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(w, r)
})
}
}
func RequestLogger(log *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ww := chimiddleware.NewWrapResponseWriter(w, r.ProtoMajor)
start := time.Now()
next.ServeHTTP(ww, r)
log.Info(
"http_request",
"request_id", chimiddleware.GetReqID(r.Context()),
"method", r.Method,
"path", r.URL.Path,
"status", ww.Status(),
"bytes", ww.BytesWritten(),
"duration", time.Since(start),
)
})
}
}
+91
View File
@@ -0,0 +1,91 @@
package server
import (
"context"
"sync"
"time"
"github.com/google/uuid"
)
const janitorInterval = 5 * time.Minute
type PreviewStore struct {
mu sync.Mutex
items map[string]previewItem
ttl time.Duration
now func() time.Time
}
type previewItem struct {
html []byte
mime string
filename string
expires time.Time
}
func NewPreviewStore(ttl time.Duration) *PreviewStore {
return &PreviewStore{
items: make(map[string]previewItem),
ttl: ttl,
now: time.Now,
}
}
func (s *PreviewStore) Put(html []byte, mime, filename string) string {
s.mu.Lock()
defer s.mu.Unlock()
id := uuid.NewString()
s.items[id] = previewItem{
html: append([]byte(nil), html...),
mime: mime,
filename: filename,
expires: s.now().Add(s.ttl),
}
return id
}
func (s *PreviewStore) Take(id string) (previewItem, bool) {
s.mu.Lock()
defer s.mu.Unlock()
item, ok := s.items[id]
if !ok {
return previewItem{}, false
}
delete(s.items, id)
if s.now().After(item.expires) {
return previewItem{}, false
}
item.html = append([]byte(nil), item.html...)
return item, true
}
func (s *PreviewStore) janitor(ctx context.Context) {
ticker := time.NewTicker(janitorInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case now := <-ticker.C:
s.cleanupExpired(now)
}
}
}
func (s *PreviewStore) cleanupExpired(now time.Time) {
s.mu.Lock()
defer s.mu.Unlock()
for id, item := range s.items {
if now.After(item.expires) {
delete(s.items, id)
}
}
}
+80
View File
@@ -0,0 +1,80 @@
package server
import (
"context"
"sync"
"testing"
"time"
)
func TestPreviewStore_OneShot(t *testing.T) {
t.Parallel()
store := NewPreviewStore(time.Hour)
id := store.Put([]byte("<h1>Hello</h1>"), "text/html; charset=utf-8", "hello.html")
item, ok := store.Take(id)
if !ok {
t.Fatalf("expected first take to succeed")
}
if got := string(item.html); got != "<h1>Hello</h1>" {
t.Fatalf("unexpected html: %q", got)
}
if _, ok := store.Take(id); ok {
t.Fatalf("expected second take to miss")
}
}
func TestPreviewStore_TTL(t *testing.T) {
t.Parallel()
store := NewPreviewStore(10 * time.Millisecond)
id := store.Put([]byte("expired"), "text/html; charset=utf-8", "expired.html")
time.Sleep(30 * time.Millisecond)
store.cleanupExpired(time.Now())
if _, ok := store.Take(id); ok {
t.Fatalf("expected expired item to be removed")
}
}
func TestPreviewStore_Concurrent(t *testing.T) {
t.Parallel()
store := NewPreviewStore(time.Hour)
var wg sync.WaitGroup
for i := 0; i < 32; i++ {
wg.Add(1)
go func() {
defer wg.Done()
id := store.Put([]byte("payload"), "text/html; charset=utf-8", "payload.html")
store.Take(id)
}()
}
wg.Wait()
}
func TestPreviewStore_JanitorStopsWithContext(t *testing.T) {
t.Parallel()
store := NewPreviewStore(time.Hour)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
store.janitor(ctx)
close(done)
}()
cancel()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("janitor did not stop after context cancellation")
}
}
+90
View File
@@ -0,0 +1,90 @@
package server
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"time"
"github.com/fserg/md-to-html/internal/converter"
"github.com/go-chi/chi/v5"
chimiddleware "github.com/go-chi/chi/v5/middleware"
)
func New(cfg Config, conv *converter.Converter) (*Server, error) {
if conv == nil {
return nil, errors.New("converter is required")
}
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelInfo,
}))
return &Server{
cfg: cfg,
conv: conv,
store: NewPreviewStore(cfg.PreviewTTL),
log: logger,
}, nil
}
func (s *Server) Router() http.Handler {
r := chi.NewRouter()
r.Use(chimiddleware.RequestID)
r.Use(CORSMiddleware())
r.Use(MaxBytesMiddleware(s.cfg.MaxRequestBytes))
r.Use(RequestLogger(s.log))
r.Use(chimiddleware.Recoverer)
r.Use(chimiddleware.Timeout(30 * time.Second))
r.Get("/", s.handleHome)
r.Post("/convert", s.handleConvert)
r.Get("/health", s.handleHealth)
r.Get("/version", s.handleVersion)
r.Get("/ready", s.handleReady)
r.Post("/ui/convert", s.handleUIConvert)
r.Get("/preview/{id}", s.handlePreview)
r.Get("/download/{id}", s.handleDownload)
return r
}
func (s *Server) Run(ctx context.Context) error {
httpServer := &http.Server{
Addr: s.cfg.Addr,
Handler: s.Router(),
}
errCh := make(chan error, 1)
go s.store.janitor(ctx)
go func() {
s.log.Info("server starting", "addr", s.cfg.Addr)
errCh <- httpServer.ListenAndServe()
}()
select {
case <-ctx.Done():
s.log.Info("shutting down", "timeout", s.cfg.ShutdownTimeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), s.cfg.ShutdownTimeout)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
return fmt.Errorf("shutdown server: %w", err)
}
if err := <-errCh; err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("server exited after shutdown: %w", err)
}
return nil
case err := <-errCh:
if err == nil || errors.Is(err, http.ErrServerClosed) {
return nil
}
return fmt.Errorf("serve: %w", err)
}
}
+329
View File
@@ -0,0 +1,329 @@
package server
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/fserg/md-to-html/internal/converter"
"github.com/fserg/md-to-html/internal/version"
webtemplate "github.com/fserg/md-to-html/web/template"
)
func TestConvertEndpoint(t *testing.T) {
srv := newTestServer(t, Config{
Addr: ":0",
MaxMarkdownBytes: 128,
MaxRequestBytes: 256,
PreviewTTL: time.Hour,
ShutdownTimeout: time.Second,
})
ts := httptest.NewServer(srv.Router())
defer ts.Close()
tests := []struct {
name string
body string
contentType string
wantStatus int
wantType string
wantBody string
}{
{
name: "valid markdown",
body: `{"markdown":"# Hello"}`,
contentType: "application/json",
wantStatus: http.StatusOK,
wantType: "text/html; charset=utf-8",
wantBody: "<!DOCTYPE html>",
},
{
name: "empty markdown",
body: `{"markdown":" "}`,
contentType: "application/json",
wantStatus: http.StatusBadRequest,
wantType: "application/json; charset=utf-8",
wantBody: `{"detail":"markdown must not be empty"}`,
},
{
name: "markdown too large",
body: `{"markdown":"` + strings.Repeat("a", 129) + `"}`,
contentType: "application/json",
wantStatus: http.StatusRequestEntityTooLarge,
wantType: "application/json; charset=utf-8",
wantBody: `{"detail":"markdown exceeds 128 bytes"}`,
},
{
name: "missing content type",
body: `{"markdown":"# Hello"}`,
contentType: "",
wantStatus: http.StatusUnsupportedMediaType,
wantType: "application/json; charset=utf-8",
wantBody: `{"detail":"content-type must be application/json"}`,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, ts.URL+"/convert", strings.NewReader(tc.body))
if err != nil {
t.Fatalf("new request: %v", err)
}
if tc.contentType != "" {
req.Header.Set("Content-Type", tc.contentType)
}
resp, err := ts.Client().Do(req)
if err != nil {
t.Fatalf("do request: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read body: %v", err)
}
if resp.StatusCode != tc.wantStatus {
t.Fatalf("status = %d, want %d; body=%s", resp.StatusCode, tc.wantStatus, body)
}
if got := resp.Header.Get("Content-Type"); got != tc.wantType {
t.Fatalf("content-type = %q, want %q", got, tc.wantType)
}
if !bytes.Contains(body, []byte(tc.wantBody)) {
t.Fatalf("body %q does not contain %q", body, tc.wantBody)
}
})
}
}
func TestConvertEndpoint_RequestLimit(t *testing.T) {
t.Parallel()
srv := newTestServer(t, Config{
Addr: ":0",
MaxMarkdownBytes: 1_048_576,
MaxRequestBytes: 64,
PreviewTTL: time.Hour,
ShutdownTimeout: time.Second,
})
ts := httptest.NewServer(srv.Router())
defer ts.Close()
req, err := http.NewRequest(http.MethodPost, ts.URL+"/convert", strings.NewReader(`{"markdown":"`+strings.Repeat("a", 100)+`"}`))
if err != nil {
t.Fatalf("new request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := ts.Client().Do(req)
if err != nil {
t.Fatalf("do request: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read body: %v", err)
}
if resp.StatusCode != http.StatusRequestEntityTooLarge {
t.Fatalf("status = %d, want %d; body=%s", resp.StatusCode, http.StatusRequestEntityTooLarge, body)
}
if !bytes.Contains(body, []byte(`{"detail":"request exceeds 64 bytes"}`)) {
t.Fatalf("unexpected body: %s", body)
}
}
func TestStatusEndpoints(t *testing.T) {
originalVersion := version.Version
version.Version = "dev"
t.Cleanup(func() {
version.Version = originalVersion
})
srv := newTestServer(t, defaultTestConfig())
ts := httptest.NewServer(srv.Router())
defer ts.Close()
tests := []struct {
path string
want map[string]any
}{
{path: "/health", want: map[string]any{"status": "ok"}},
{path: "/version", want: map[string]any{"version": "dev"}},
{path: "/ready", want: map[string]any{"status": "ok", "template_loaded": true}},
}
for _, tc := range tests {
tc := tc
t.Run(tc.path, func(t *testing.T) {
resp, err := ts.Client().Get(ts.URL + tc.path)
if err != nil {
t.Fatalf("get %s: %v", tc.path, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
}
var got map[string]any
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
t.Fatalf("decode body: %v", err)
}
for key, wantValue := range tc.want {
if got[key] != wantValue {
t.Fatalf("%s[%q] = %v, want %v", tc.path, key, got[key], wantValue)
}
}
})
}
}
func TestPreviewAndDownloadOneShot(t *testing.T) {
t.Parallel()
srv := newTestServer(t, defaultTestConfig())
previewID := srv.store.Put([]byte("<h1>Preview</h1>"), "text/html; charset=utf-8", "preview.html")
downloadID := srv.store.Put([]byte("<h1>Download</h1>"), "text/html; charset=utf-8", "download.html")
ts := httptest.NewServer(srv.Router())
defer ts.Close()
resp, err := ts.Client().Get(ts.URL + "/preview/" + previewID)
if err != nil {
t.Fatalf("get preview: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read preview body: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("preview status = %d, want %d", resp.StatusCode, http.StatusOK)
}
if got := resp.Header.Get("Cache-Control"); got != "no-store" {
t.Fatalf("preview cache-control = %q, want %q", got, "no-store")
}
if string(body) != "<h1>Preview</h1>" {
t.Fatalf("preview body = %q", body)
}
resp, err = ts.Client().Get(ts.URL + "/preview/" + previewID)
if err != nil {
t.Fatalf("get preview second time: %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("second preview status = %d, want %d", resp.StatusCode, http.StatusNotFound)
}
resp, err = ts.Client().Get(ts.URL + "/download/" + downloadID)
if err != nil {
t.Fatalf("get download: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("download status = %d, want %d", resp.StatusCode, http.StatusOK)
}
if got := resp.Header.Get("Content-Disposition"); !strings.Contains(got, `attachment; filename=preview.html`) && !strings.Contains(got, `attachment; filename=download.html`) {
t.Fatalf("unexpected content-disposition: %q", got)
}
resp, err = ts.Client().Get(ts.URL + "/download/" + downloadID)
if err != nil {
t.Fatalf("get download second time: %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("second download status = %d, want %d", resp.StatusCode, http.StatusNotFound)
}
}
func TestPreviewMissing(t *testing.T) {
t.Parallel()
srv := newTestServer(t, defaultTestConfig())
ts := httptest.NewServer(srv.Router())
defer ts.Close()
resp, err := ts.Client().Get(ts.URL + "/preview/nonexistent")
if err != nil {
t.Fatalf("get preview: %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusNotFound)
}
}
func TestCORSPreflight(t *testing.T) {
t.Parallel()
srv := newTestServer(t, defaultTestConfig())
ts := httptest.NewServer(srv.Router())
defer ts.Close()
req, err := http.NewRequest(http.MethodOptions, ts.URL+"/convert", nil)
if err != nil {
t.Fatalf("new request: %v", err)
}
req.Header.Set("Origin", "https://evil.com")
req.Header.Set("Access-Control-Request-Method", http.MethodPost)
resp, err := ts.Client().Do(req)
if err != nil {
t.Fatalf("do request: %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
}
if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "*" {
t.Fatalf("allow-origin = %q, want %q", got, "*")
}
if got := resp.Header.Get("Access-Control-Allow-Methods"); got != "POST, GET, OPTIONS" {
t.Fatalf("allow-methods = %q", got)
}
}
func newTestServer(t *testing.T, cfg Config) *Server {
t.Helper()
conv, err := converter.New(webtemplate.FS)
if err != nil {
t.Fatalf("new converter: %v", err)
}
srv, err := New(cfg, conv)
if err != nil {
t.Fatalf("new server: %v", err)
}
return srv
}
func defaultTestConfig() Config {
return Config{
Addr: ":0",
MaxMarkdownBytes: 1_048_576,
MaxRequestBytes: 1_200_000,
PreviewTTL: time.Hour,
ShutdownTimeout: time.Second,
}
}