From 843d8dc710512f1d9bfc3bf845c449c049ff90c9 Mon Sep 17 00:00:00 2001 From: Sergey Filkin Date: Sat, 18 Apr 2026 11:55:42 +0300 Subject: [PATCH] phase3: HTTP server with converter, one-shot preview store, and middleware --- cmd/md-to-html/main.go | 46 +++- internal/server/config.go | 94 ++++++++ internal/server/handlers.go | 248 +++++++++++++++++++ internal/server/middleware.go | 59 +++++ internal/server/preview_store.go | 91 +++++++ internal/server/preview_store_test.go | 80 +++++++ internal/server/server.go | 90 +++++++ internal/server/server_test.go | 329 ++++++++++++++++++++++++++ 8 files changed, 1033 insertions(+), 4 deletions(-) create mode 100644 internal/server/config.go create mode 100644 internal/server/handlers.go create mode 100644 internal/server/middleware.go create mode 100644 internal/server/preview_store.go create mode 100644 internal/server/preview_store_test.go create mode 100644 internal/server/server.go create mode 100644 internal/server/server_test.go diff --git a/cmd/md-to-html/main.go b/cmd/md-to-html/main.go index badd8bc..ace3c45 100644 --- a/cmd/md-to-html/main.go +++ b/cmd/md-to-html/main.go @@ -1,12 +1,18 @@ package main import ( + "context" "flag" "fmt" "io" "os" + "os/signal" + "syscall" + "github.com/fserg/md-to-html/internal/converter" + "github.com/fserg/md-to-html/internal/server" "github.com/fserg/md-to-html/internal/version" + webtemplate "github.com/fserg/md-to-html/web/template" ) func main() { @@ -24,7 +30,7 @@ func run(args []string, stdout, stderr io.Writer) int { printUsage(stdout) return 0 case "serve": - return runServe(args[1:], stdout) + return runServe(args[1:], stdout, stderr) case "cli": return runCLI(args[1:], stdout, stderr) case "version": @@ -36,13 +42,45 @@ func run(args []string, stdout, stderr io.Writer) int { } } -func runServe(args []string, stdout io.Writer) int { +func runServe(args []string, stdout, stderr io.Writer) int { fs := flag.NewFlagSet("serve", flag.ContinueOnError) fs.SetOutput(io.Discard) if err := fs.Parse(args); err != nil { return 2 } - fmt.Fprintln(stdout, "serve not implemented yet") + + if fs.NArg() != 0 { + fmt.Fprintln(stderr, "usage: md-to-html serve") + return 2 + } + + cfg, err := server.LoadConfig() + if err != nil { + fmt.Fprintf(stderr, "load config: %v\n", err) + return 1 + } + + conv, err := converter.New(webtemplate.FS) + if err != nil { + fmt.Fprintf(stderr, "load converter: %v\n", err) + return 1 + } + + srv, err := server.New(cfg, conv) + if err != nil { + fmt.Fprintf(stderr, "create server: %v\n", err) + return 1 + } + + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + if err := srv.Run(ctx); err != nil { + fmt.Fprintf(stderr, "run server: %v\n", err) + return 1 + } + + _ = stdout return 0 } @@ -85,7 +123,7 @@ func printUsage(w io.Writer) { md-to-html version Commands: - serve Start the HTTP server stub + serve Start the HTTP server cli Convert a Markdown file stub version Print the build version `) diff --git a/internal/server/config.go b/internal/server/config.go new file mode 100644 index 0000000..bbc976d --- /dev/null +++ b/internal/server/config.go @@ -0,0 +1,94 @@ +package server + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" +) + +const ( + defaultAddr = ":8080" + defaultMaxMarkdownBytes = int64(1_048_576) + defaultMaxRequestBytes = int64(1_200_000) + defaultPreviewTTL = time.Hour + defaultShutdownTimeout = 10 * time.Second +) + +type Config struct { + Addr string + MaxMarkdownBytes int64 + MaxRequestBytes int64 + PreviewTTL time.Duration + ShutdownTimeout time.Duration +} + +func LoadConfig() (Config, error) { + maxMarkdownBytes, err := loadPositiveInt64("MAX_MARKDOWN_BYTES", defaultMaxMarkdownBytes) + if err != nil { + return Config{}, err + } + + maxRequestBytes, err := loadPositiveInt64("MAX_REQUEST_BYTES", defaultMaxRequestBytes) + if err != nil { + return Config{}, err + } + + previewTTL, err := loadDuration("PREVIEW_TTL", defaultPreviewTTL) + if err != nil { + return Config{}, err + } + + shutdownTimeout, err := loadDuration("SHUTDOWN_TIMEOUT", defaultShutdownTimeout) + if err != nil { + return Config{}, err + } + + addr := strings.TrimSpace(os.Getenv("ADDR")) + if addr == "" { + addr = defaultAddr + } + + return Config{ + Addr: addr, + MaxMarkdownBytes: maxMarkdownBytes, + MaxRequestBytes: maxRequestBytes, + PreviewTTL: previewTTL, + ShutdownTimeout: shutdownTimeout, + }, nil +} + +func loadPositiveInt64(name string, fallback int64) (int64, error) { + raw := strings.TrimSpace(os.Getenv(name)) + if raw == "" { + return fallback, nil + } + + value, err := strconv.ParseInt(raw, 10, 64) + if err != nil { + return 0, fmt.Errorf("%s must be an integer: %w", name, err) + } + if value <= 0 { + return 0, fmt.Errorf("%s must be positive", name) + } + + return value, nil +} + +func loadDuration(name string, fallback time.Duration) (time.Duration, error) { + raw := strings.TrimSpace(os.Getenv(name)) + if raw == "" { + return fallback, nil + } + + value, err := time.ParseDuration(raw) + if err != nil { + return 0, fmt.Errorf("%s must be a valid duration: %w", name, err) + } + if value <= 0 { + return 0, fmt.Errorf("%s must be positive", name) + } + + return value, nil +} diff --git a/internal/server/handlers.go b/internal/server/handlers.go new file mode 100644 index 0000000..809f4c1 --- /dev/null +++ b/internal/server/handlers.go @@ -0,0 +1,248 @@ +package server + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "mime" + "net/http" + "strings" + + "github.com/fserg/md-to-html/internal/converter" + "github.com/fserg/md-to-html/internal/version" + "github.com/go-chi/chi/v5" +) + +const defaultDocumentTitle = "Document" + +type Server struct { + cfg Config + conv *converter.Converter + store *PreviewStore + log *slog.Logger +} + +type convertRequest struct { + Markdown string `json:"markdown"` + Title string `json:"title,omitempty"` +} + +func (s *Server) handleConvert(w http.ResponseWriter, r *http.Request) { + if !hasJSONContentType(r.Header.Get("Content-Type")) { + writeJSON(w, http.StatusUnsupportedMediaType, map[string]string{ + "detail": "content-type must be application/json", + }) + return + } + + var payload convertRequest + if err := decodeJSON(r, &payload); err != nil { + s.writeDecodeError(w, err) + return + } + + result, err := s.convertMarkdown(payload.Markdown, payload.Title) + if err != nil { + s.writeConvertError(w, err) + return + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(result.HTML) +} + +func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) +} + +func (s *Server) handleVersion(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, map[string]string{"version": version.Version}) +} + +func (s *Server) handleReady(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, map[string]any{ + "status": "ok", + "template_loaded": s.conv != nil, + }) +} + +func (s *Server) handleHome(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("UI coming in phase 4")) +} + +func (s *Server) handleUIConvert(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + s.writeDecodeError(w, err) + return + } + + result, err := s.convertMarkdown(r.Form.Get("markdown"), r.Form.Get("title")) + if err != nil { + s.writeConvertError(w, err) + return + } + + filename := htmlFilename(result.Title) + previewID := s.store.Put(result.HTML, "text/html; charset=utf-8", filename) + downloadID := s.store.Put(result.HTML, "text/html; charset=utf-8", filename) + + fragment := fmt.Sprintf( + `

Result ready

Preview Download
`, + previewID, + downloadID, + ) + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(fragment)) +} + +func (s *Server) handlePreview(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + item, ok := s.store.Take(id) + if !ok { + http.NotFound(w, r) + return + } + + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Content-Type", contentTypeOrDefault(item.mime)) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(item.html) +} + +func (s *Server) handleDownload(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + item, ok := s.store.Take(id) + if !ok { + http.NotFound(w, r) + return + } + + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Content-Type", contentTypeOrDefault(item.mime)) + w.Header().Set("Content-Disposition", mime.FormatMediaType("attachment", map[string]string{ + "filename": item.filename, + })) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(item.html) +} + +func (s *Server) convertMarkdown(markdown, title string) (converter.Result, error) { + if strings.TrimSpace(markdown) == "" { + return converter.Result{}, errEmptyMarkdown + } + + if int64(len([]byte(markdown))) > s.cfg.MaxMarkdownBytes { + return converter.Result{}, errMarkdownTooLarge{limit: s.cfg.MaxMarkdownBytes} + } + + fallbackTitle := strings.TrimSpace(title) + if fallbackTitle == "" { + fallbackTitle = defaultDocumentTitle + } + + result, err := s.conv.Convert([]byte(markdown), fallbackTitle) + if err != nil { + return converter.Result{}, fmt.Errorf("convert markdown: %w", err) + } + + return result, nil +} + +func (s *Server) writeDecodeError(w http.ResponseWriter, err error) { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + writeJSON(w, http.StatusRequestEntityTooLarge, map[string]string{ + "detail": fmt.Sprintf("request exceeds %d bytes", s.cfg.MaxRequestBytes), + }) + return + } + + writeJSON(w, http.StatusBadRequest, map[string]string{"detail": "invalid request payload"}) +} + +func (s *Server) writeConvertError(w http.ResponseWriter, err error) { + var markdownTooLarge errMarkdownTooLarge + switch { + case errors.Is(err, errEmptyMarkdown): + writeJSON(w, http.StatusBadRequest, map[string]string{"detail": err.Error()}) + case errors.As(err, &markdownTooLarge): + writeJSON(w, http.StatusRequestEntityTooLarge, map[string]string{ + "detail": markdownTooLarge.Error(), + }) + default: + s.log.Error("convert_failed", "error", err) + writeJSON(w, http.StatusBadGateway, map[string]string{"detail": err.Error()}) + } +} + +func hasJSONContentType(value string) bool { + mediaType, _, err := mime.ParseMediaType(value) + return err == nil && mediaType == "application/json" +} + +func decodeJSON(r *http.Request, dst any) error { + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(dst); err != nil { + return err + } + + var extra json.RawMessage + if err := dec.Decode(&extra); err != nil && !errors.Is(err, io.EOF) { + return err + } + + if len(extra) > 0 { + return errors.New("unexpected trailing JSON data") + } + + return nil +} + +func writeJSON(w http.ResponseWriter, status int, payload any) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + + enc := json.NewEncoder(w) + enc.SetEscapeHTML(false) + _ = enc.Encode(payload) +} + +func htmlFilename(title string) string { + name := strings.TrimSpace(title) + if name == "" { + name = "document" + } + + replacer := strings.NewReplacer("/", "-", "\\", "-", "\"", "", "\n", " ", "\r", " ") + name = strings.TrimSpace(replacer.Replace(name)) + if name == "" { + name = "document" + } + + return name + ".html" +} + +func contentTypeOrDefault(value string) string { + if strings.TrimSpace(value) == "" { + return "text/html; charset=utf-8" + } + return value +} + +var errEmptyMarkdown = errors.New("markdown must not be empty") + +type errMarkdownTooLarge struct { + limit int64 +} + +func (e errMarkdownTooLarge) Error() string { + return fmt.Sprintf("markdown exceeds %d bytes", e.limit) +} diff --git a/internal/server/middleware.go b/internal/server/middleware.go new file mode 100644 index 0000000..504b3da --- /dev/null +++ b/internal/server/middleware.go @@ -0,0 +1,59 @@ +package server + +import ( + "log/slog" + "net/http" + "time" + + chimiddleware "github.com/go-chi/chi/v5/middleware" +) + +func MaxBytesMiddleware(limit int64) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Body != nil { + r.Body = http.MaxBytesReader(w, r.Body, limit) + } + next.ServeHTTP(w, r) + }) + } +} + +func CORSMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headers := w.Header() + headers.Set("Access-Control-Allow-Origin", "*") + headers.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") + headers.Set("Access-Control-Allow-Headers", "content-type") + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +func RequestLogger(log *slog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ww := chimiddleware.NewWrapResponseWriter(w, r.ProtoMajor) + start := time.Now() + + next.ServeHTTP(ww, r) + + log.Info( + "http_request", + "request_id", chimiddleware.GetReqID(r.Context()), + "method", r.Method, + "path", r.URL.Path, + "status", ww.Status(), + "bytes", ww.BytesWritten(), + "duration", time.Since(start), + ) + }) + } +} diff --git a/internal/server/preview_store.go b/internal/server/preview_store.go new file mode 100644 index 0000000..e5cbb8f --- /dev/null +++ b/internal/server/preview_store.go @@ -0,0 +1,91 @@ +package server + +import ( + "context" + "sync" + "time" + + "github.com/google/uuid" +) + +const janitorInterval = 5 * time.Minute + +type PreviewStore struct { + mu sync.Mutex + items map[string]previewItem + ttl time.Duration + now func() time.Time +} + +type previewItem struct { + html []byte + mime string + filename string + expires time.Time +} + +func NewPreviewStore(ttl time.Duration) *PreviewStore { + return &PreviewStore{ + items: make(map[string]previewItem), + ttl: ttl, + now: time.Now, + } +} + +func (s *PreviewStore) Put(html []byte, mime, filename string) string { + s.mu.Lock() + defer s.mu.Unlock() + + id := uuid.NewString() + s.items[id] = previewItem{ + html: append([]byte(nil), html...), + mime: mime, + filename: filename, + expires: s.now().Add(s.ttl), + } + + return id +} + +func (s *PreviewStore) Take(id string) (previewItem, bool) { + s.mu.Lock() + defer s.mu.Unlock() + + item, ok := s.items[id] + if !ok { + return previewItem{}, false + } + + delete(s.items, id) + if s.now().After(item.expires) { + return previewItem{}, false + } + + item.html = append([]byte(nil), item.html...) + return item, true +} + +func (s *PreviewStore) janitor(ctx context.Context) { + ticker := time.NewTicker(janitorInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case now := <-ticker.C: + s.cleanupExpired(now) + } + } +} + +func (s *PreviewStore) cleanupExpired(now time.Time) { + s.mu.Lock() + defer s.mu.Unlock() + + for id, item := range s.items { + if now.After(item.expires) { + delete(s.items, id) + } + } +} diff --git a/internal/server/preview_store_test.go b/internal/server/preview_store_test.go new file mode 100644 index 0000000..2d80aeb --- /dev/null +++ b/internal/server/preview_store_test.go @@ -0,0 +1,80 @@ +package server + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestPreviewStore_OneShot(t *testing.T) { + t.Parallel() + + store := NewPreviewStore(time.Hour) + id := store.Put([]byte("

Hello

"), "text/html; charset=utf-8", "hello.html") + + item, ok := store.Take(id) + if !ok { + t.Fatalf("expected first take to succeed") + } + if got := string(item.html); got != "

Hello

" { + t.Fatalf("unexpected html: %q", got) + } + + if _, ok := store.Take(id); ok { + t.Fatalf("expected second take to miss") + } +} + +func TestPreviewStore_TTL(t *testing.T) { + t.Parallel() + + store := NewPreviewStore(10 * time.Millisecond) + id := store.Put([]byte("expired"), "text/html; charset=utf-8", "expired.html") + + time.Sleep(30 * time.Millisecond) + store.cleanupExpired(time.Now()) + + if _, ok := store.Take(id); ok { + t.Fatalf("expected expired item to be removed") + } +} + +func TestPreviewStore_Concurrent(t *testing.T) { + t.Parallel() + + store := NewPreviewStore(time.Hour) + var wg sync.WaitGroup + + for i := 0; i < 32; i++ { + wg.Add(1) + go func() { + defer wg.Done() + id := store.Put([]byte("payload"), "text/html; charset=utf-8", "payload.html") + store.Take(id) + }() + } + + wg.Wait() +} + +func TestPreviewStore_JanitorStopsWithContext(t *testing.T) { + t.Parallel() + + store := NewPreviewStore(time.Hour) + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + + go func() { + store.janitor(ctx) + close(done) + }() + + cancel() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("janitor did not stop after context cancellation") + } +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..74820fd --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,90 @@ +package server + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + "os" + "time" + + "github.com/fserg/md-to-html/internal/converter" + "github.com/go-chi/chi/v5" + chimiddleware "github.com/go-chi/chi/v5/middleware" +) + +func New(cfg Config, conv *converter.Converter) (*Server, error) { + if conv == nil { + return nil, errors.New("converter is required") + } + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelInfo, + })) + + return &Server{ + cfg: cfg, + conv: conv, + store: NewPreviewStore(cfg.PreviewTTL), + log: logger, + }, nil +} + +func (s *Server) Router() http.Handler { + r := chi.NewRouter() + r.Use(chimiddleware.RequestID) + r.Use(CORSMiddleware()) + r.Use(MaxBytesMiddleware(s.cfg.MaxRequestBytes)) + r.Use(RequestLogger(s.log)) + r.Use(chimiddleware.Recoverer) + r.Use(chimiddleware.Timeout(30 * time.Second)) + + r.Get("/", s.handleHome) + r.Post("/convert", s.handleConvert) + r.Get("/health", s.handleHealth) + r.Get("/version", s.handleVersion) + r.Get("/ready", s.handleReady) + r.Post("/ui/convert", s.handleUIConvert) + r.Get("/preview/{id}", s.handlePreview) + r.Get("/download/{id}", s.handleDownload) + + return r +} + +func (s *Server) Run(ctx context.Context) error { + httpServer := &http.Server{ + Addr: s.cfg.Addr, + Handler: s.Router(), + } + + errCh := make(chan error, 1) + + go s.store.janitor(ctx) + go func() { + s.log.Info("server starting", "addr", s.cfg.Addr) + errCh <- httpServer.ListenAndServe() + }() + + select { + case <-ctx.Done(): + s.log.Info("shutting down", "timeout", s.cfg.ShutdownTimeout) + + shutdownCtx, cancel := context.WithTimeout(context.Background(), s.cfg.ShutdownTimeout) + defer cancel() + + if err := httpServer.Shutdown(shutdownCtx); err != nil { + return fmt.Errorf("shutdown server: %w", err) + } + + if err := <-errCh; err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("server exited after shutdown: %w", err) + } + return nil + case err := <-errCh: + if err == nil || errors.Is(err, http.ErrServerClosed) { + return nil + } + return fmt.Errorf("serve: %w", err) + } +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..628de75 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,329 @@ +package server + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/fserg/md-to-html/internal/converter" + "github.com/fserg/md-to-html/internal/version" + webtemplate "github.com/fserg/md-to-html/web/template" +) + +func TestConvertEndpoint(t *testing.T) { + srv := newTestServer(t, Config{ + Addr: ":0", + MaxMarkdownBytes: 128, + MaxRequestBytes: 256, + PreviewTTL: time.Hour, + ShutdownTimeout: time.Second, + }) + + ts := httptest.NewServer(srv.Router()) + defer ts.Close() + + tests := []struct { + name string + body string + contentType string + wantStatus int + wantType string + wantBody string + }{ + { + name: "valid markdown", + body: `{"markdown":"# Hello"}`, + contentType: "application/json", + wantStatus: http.StatusOK, + wantType: "text/html; charset=utf-8", + wantBody: "", + }, + { + name: "empty markdown", + body: `{"markdown":" "}`, + contentType: "application/json", + wantStatus: http.StatusBadRequest, + wantType: "application/json; charset=utf-8", + wantBody: `{"detail":"markdown must not be empty"}`, + }, + { + name: "markdown too large", + body: `{"markdown":"` + strings.Repeat("a", 129) + `"}`, + contentType: "application/json", + wantStatus: http.StatusRequestEntityTooLarge, + wantType: "application/json; charset=utf-8", + wantBody: `{"detail":"markdown exceeds 128 bytes"}`, + }, + { + name: "missing content type", + body: `{"markdown":"# Hello"}`, + contentType: "", + wantStatus: http.StatusUnsupportedMediaType, + wantType: "application/json; charset=utf-8", + wantBody: `{"detail":"content-type must be application/json"}`, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, ts.URL+"/convert", strings.NewReader(tc.body)) + if err != nil { + t.Fatalf("new request: %v", err) + } + if tc.contentType != "" { + req.Header.Set("Content-Type", tc.contentType) + } + + resp, err := ts.Client().Do(req) + if err != nil { + t.Fatalf("do request: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + + if resp.StatusCode != tc.wantStatus { + t.Fatalf("status = %d, want %d; body=%s", resp.StatusCode, tc.wantStatus, body) + } + if got := resp.Header.Get("Content-Type"); got != tc.wantType { + t.Fatalf("content-type = %q, want %q", got, tc.wantType) + } + if !bytes.Contains(body, []byte(tc.wantBody)) { + t.Fatalf("body %q does not contain %q", body, tc.wantBody) + } + }) + } +} + +func TestConvertEndpoint_RequestLimit(t *testing.T) { + t.Parallel() + + srv := newTestServer(t, Config{ + Addr: ":0", + MaxMarkdownBytes: 1_048_576, + MaxRequestBytes: 64, + PreviewTTL: time.Hour, + ShutdownTimeout: time.Second, + }) + + ts := httptest.NewServer(srv.Router()) + defer ts.Close() + + req, err := http.NewRequest(http.MethodPost, ts.URL+"/convert", strings.NewReader(`{"markdown":"`+strings.Repeat("a", 100)+`"}`)) + if err != nil { + t.Fatalf("new request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := ts.Client().Do(req) + if err != nil { + t.Fatalf("do request: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + + if resp.StatusCode != http.StatusRequestEntityTooLarge { + t.Fatalf("status = %d, want %d; body=%s", resp.StatusCode, http.StatusRequestEntityTooLarge, body) + } + if !bytes.Contains(body, []byte(`{"detail":"request exceeds 64 bytes"}`)) { + t.Fatalf("unexpected body: %s", body) + } +} + +func TestStatusEndpoints(t *testing.T) { + originalVersion := version.Version + version.Version = "dev" + t.Cleanup(func() { + version.Version = originalVersion + }) + + srv := newTestServer(t, defaultTestConfig()) + ts := httptest.NewServer(srv.Router()) + defer ts.Close() + + tests := []struct { + path string + want map[string]any + }{ + {path: "/health", want: map[string]any{"status": "ok"}}, + {path: "/version", want: map[string]any{"version": "dev"}}, + {path: "/ready", want: map[string]any{"status": "ok", "template_loaded": true}}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.path, func(t *testing.T) { + resp, err := ts.Client().Get(ts.URL + tc.path) + if err != nil { + t.Fatalf("get %s: %v", tc.path, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + + var got map[string]any + if err := json.NewDecoder(resp.Body).Decode(&got); err != nil { + t.Fatalf("decode body: %v", err) + } + + for key, wantValue := range tc.want { + if got[key] != wantValue { + t.Fatalf("%s[%q] = %v, want %v", tc.path, key, got[key], wantValue) + } + } + }) + } +} + +func TestPreviewAndDownloadOneShot(t *testing.T) { + t.Parallel() + + srv := newTestServer(t, defaultTestConfig()) + previewID := srv.store.Put([]byte("

Preview

"), "text/html; charset=utf-8", "preview.html") + downloadID := srv.store.Put([]byte("

Download

"), "text/html; charset=utf-8", "download.html") + + ts := httptest.NewServer(srv.Router()) + defer ts.Close() + + resp, err := ts.Client().Get(ts.URL + "/preview/" + previewID) + if err != nil { + t.Fatalf("get preview: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read preview body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("preview status = %d, want %d", resp.StatusCode, http.StatusOK) + } + if got := resp.Header.Get("Cache-Control"); got != "no-store" { + t.Fatalf("preview cache-control = %q, want %q", got, "no-store") + } + if string(body) != "

Preview

" { + t.Fatalf("preview body = %q", body) + } + + resp, err = ts.Client().Get(ts.URL + "/preview/" + previewID) + if err != nil { + t.Fatalf("get preview second time: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("second preview status = %d, want %d", resp.StatusCode, http.StatusNotFound) + } + + resp, err = ts.Client().Get(ts.URL + "/download/" + downloadID) + if err != nil { + t.Fatalf("get download: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("download status = %d, want %d", resp.StatusCode, http.StatusOK) + } + if got := resp.Header.Get("Content-Disposition"); !strings.Contains(got, `attachment; filename=preview.html`) && !strings.Contains(got, `attachment; filename=download.html`) { + t.Fatalf("unexpected content-disposition: %q", got) + } + + resp, err = ts.Client().Get(ts.URL + "/download/" + downloadID) + if err != nil { + t.Fatalf("get download second time: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("second download status = %d, want %d", resp.StatusCode, http.StatusNotFound) + } +} + +func TestPreviewMissing(t *testing.T) { + t.Parallel() + + srv := newTestServer(t, defaultTestConfig()) + ts := httptest.NewServer(srv.Router()) + defer ts.Close() + + resp, err := ts.Client().Get(ts.URL + "/preview/nonexistent") + if err != nil { + t.Fatalf("get preview: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) + } +} + +func TestCORSPreflight(t *testing.T) { + t.Parallel() + + srv := newTestServer(t, defaultTestConfig()) + ts := httptest.NewServer(srv.Router()) + defer ts.Close() + + req, err := http.NewRequest(http.MethodOptions, ts.URL+"/convert", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + req.Header.Set("Origin", "https://evil.com") + req.Header.Set("Access-Control-Request-Method", http.MethodPost) + + resp, err := ts.Client().Do(req) + if err != nil { + t.Fatalf("do request: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "*" { + t.Fatalf("allow-origin = %q, want %q", got, "*") + } + if got := resp.Header.Get("Access-Control-Allow-Methods"); got != "POST, GET, OPTIONS" { + t.Fatalf("allow-methods = %q", got) + } +} + +func newTestServer(t *testing.T, cfg Config) *Server { + t.Helper() + + conv, err := converter.New(webtemplate.FS) + if err != nil { + t.Fatalf("new converter: %v", err) + } + + srv, err := New(cfg, conv) + if err != nil { + t.Fatalf("new server: %v", err) + } + + return srv +} + +func defaultTestConfig() Config { + return Config{ + Addr: ":0", + MaxMarkdownBytes: 1_048_576, + MaxRequestBytes: 1_200_000, + PreviewTTL: time.Hour, + ShutdownTimeout: time.Second, + } +}