phase3: HTTP server with converter, one-shot preview store, and middleware
This commit is contained in:
@@ -0,0 +1,94 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAddr = ":8080"
|
||||
defaultMaxMarkdownBytes = int64(1_048_576)
|
||||
defaultMaxRequestBytes = int64(1_200_000)
|
||||
defaultPreviewTTL = time.Hour
|
||||
defaultShutdownTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Addr string
|
||||
MaxMarkdownBytes int64
|
||||
MaxRequestBytes int64
|
||||
PreviewTTL time.Duration
|
||||
ShutdownTimeout time.Duration
|
||||
}
|
||||
|
||||
func LoadConfig() (Config, error) {
|
||||
maxMarkdownBytes, err := loadPositiveInt64("MAX_MARKDOWN_BYTES", defaultMaxMarkdownBytes)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
maxRequestBytes, err := loadPositiveInt64("MAX_REQUEST_BYTES", defaultMaxRequestBytes)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
previewTTL, err := loadDuration("PREVIEW_TTL", defaultPreviewTTL)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
shutdownTimeout, err := loadDuration("SHUTDOWN_TIMEOUT", defaultShutdownTimeout)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
addr := strings.TrimSpace(os.Getenv("ADDR"))
|
||||
if addr == "" {
|
||||
addr = defaultAddr
|
||||
}
|
||||
|
||||
return Config{
|
||||
Addr: addr,
|
||||
MaxMarkdownBytes: maxMarkdownBytes,
|
||||
MaxRequestBytes: maxRequestBytes,
|
||||
PreviewTTL: previewTTL,
|
||||
ShutdownTimeout: shutdownTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func loadPositiveInt64(name string, fallback int64) (int64, error) {
|
||||
raw := strings.TrimSpace(os.Getenv(name))
|
||||
if raw == "" {
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
value, err := strconv.ParseInt(raw, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%s must be an integer: %w", name, err)
|
||||
}
|
||||
if value <= 0 {
|
||||
return 0, fmt.Errorf("%s must be positive", name)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func loadDuration(name string, fallback time.Duration) (time.Duration, error) {
|
||||
raw := strings.TrimSpace(os.Getenv(name))
|
||||
if raw == "" {
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
value, err := time.ParseDuration(raw)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%s must be a valid duration: %w", name, err)
|
||||
}
|
||||
if value <= 0 {
|
||||
return 0, fmt.Errorf("%s must be positive", name)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/fserg/md-to-html/internal/converter"
|
||||
"github.com/fserg/md-to-html/internal/version"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
const defaultDocumentTitle = "Document"
|
||||
|
||||
type Server struct {
|
||||
cfg Config
|
||||
conv *converter.Converter
|
||||
store *PreviewStore
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
type convertRequest struct {
|
||||
Markdown string `json:"markdown"`
|
||||
Title string `json:"title,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) handleConvert(w http.ResponseWriter, r *http.Request) {
|
||||
if !hasJSONContentType(r.Header.Get("Content-Type")) {
|
||||
writeJSON(w, http.StatusUnsupportedMediaType, map[string]string{
|
||||
"detail": "content-type must be application/json",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var payload convertRequest
|
||||
if err := decodeJSON(r, &payload); err != nil {
|
||||
s.writeDecodeError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := s.convertMarkdown(payload.Markdown, payload.Title)
|
||||
if err != nil {
|
||||
s.writeConvertError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(result.HTML)
|
||||
}
|
||||
|
||||
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
func (s *Server) handleVersion(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]string{"version": version.Version})
|
||||
}
|
||||
|
||||
func (s *Server) handleReady(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"status": "ok",
|
||||
"template_loaded": s.conv != nil,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleHome(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("UI coming in phase 4"))
|
||||
}
|
||||
|
||||
func (s *Server) handleUIConvert(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
s.writeDecodeError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := s.convertMarkdown(r.Form.Get("markdown"), r.Form.Get("title"))
|
||||
if err != nil {
|
||||
s.writeConvertError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
filename := htmlFilename(result.Title)
|
||||
previewID := s.store.Put(result.HTML, "text/html; charset=utf-8", filename)
|
||||
downloadID := s.store.Put(result.HTML, "text/html; charset=utf-8", filename)
|
||||
|
||||
fragment := fmt.Sprintf(
|
||||
`<div><p>Result ready</p><a href="/preview/%s" target="_blank" rel="noopener">Preview</a> <a href="/download/%s">Download</a></div>`,
|
||||
previewID,
|
||||
downloadID,
|
||||
)
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(fragment))
|
||||
}
|
||||
|
||||
func (s *Server) handlePreview(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
item, ok := s.store.Take(id)
|
||||
if !ok {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", contentTypeOrDefault(item.mime))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(item.html)
|
||||
}
|
||||
|
||||
func (s *Server) handleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
item, ok := s.store.Take(id)
|
||||
if !ok {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Content-Type", contentTypeOrDefault(item.mime))
|
||||
w.Header().Set("Content-Disposition", mime.FormatMediaType("attachment", map[string]string{
|
||||
"filename": item.filename,
|
||||
}))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(item.html)
|
||||
}
|
||||
|
||||
func (s *Server) convertMarkdown(markdown, title string) (converter.Result, error) {
|
||||
if strings.TrimSpace(markdown) == "" {
|
||||
return converter.Result{}, errEmptyMarkdown
|
||||
}
|
||||
|
||||
if int64(len([]byte(markdown))) > s.cfg.MaxMarkdownBytes {
|
||||
return converter.Result{}, errMarkdownTooLarge{limit: s.cfg.MaxMarkdownBytes}
|
||||
}
|
||||
|
||||
fallbackTitle := strings.TrimSpace(title)
|
||||
if fallbackTitle == "" {
|
||||
fallbackTitle = defaultDocumentTitle
|
||||
}
|
||||
|
||||
result, err := s.conv.Convert([]byte(markdown), fallbackTitle)
|
||||
if err != nil {
|
||||
return converter.Result{}, fmt.Errorf("convert markdown: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Server) writeDecodeError(w http.ResponseWriter, err error) {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
writeJSON(w, http.StatusRequestEntityTooLarge, map[string]string{
|
||||
"detail": fmt.Sprintf("request exceeds %d bytes", s.cfg.MaxRequestBytes),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusBadRequest, map[string]string{"detail": "invalid request payload"})
|
||||
}
|
||||
|
||||
func (s *Server) writeConvertError(w http.ResponseWriter, err error) {
|
||||
var markdownTooLarge errMarkdownTooLarge
|
||||
switch {
|
||||
case errors.Is(err, errEmptyMarkdown):
|
||||
writeJSON(w, http.StatusBadRequest, map[string]string{"detail": err.Error()})
|
||||
case errors.As(err, &markdownTooLarge):
|
||||
writeJSON(w, http.StatusRequestEntityTooLarge, map[string]string{
|
||||
"detail": markdownTooLarge.Error(),
|
||||
})
|
||||
default:
|
||||
s.log.Error("convert_failed", "error", err)
|
||||
writeJSON(w, http.StatusBadGateway, map[string]string{"detail": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
func hasJSONContentType(value string) bool {
|
||||
mediaType, _, err := mime.ParseMediaType(value)
|
||||
return err == nil && mediaType == "application/json"
|
||||
}
|
||||
|
||||
func decodeJSON(r *http.Request, dst any) error {
|
||||
dec := json.NewDecoder(r.Body)
|
||||
dec.DisallowUnknownFields()
|
||||
if err := dec.Decode(dst); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var extra json.RawMessage
|
||||
if err := dec.Decode(&extra); err != nil && !errors.Is(err, io.EOF) {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(extra) > 0 {
|
||||
return errors.New("unexpected trailing JSON data")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(status)
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetEscapeHTML(false)
|
||||
_ = enc.Encode(payload)
|
||||
}
|
||||
|
||||
func htmlFilename(title string) string {
|
||||
name := strings.TrimSpace(title)
|
||||
if name == "" {
|
||||
name = "document"
|
||||
}
|
||||
|
||||
replacer := strings.NewReplacer("/", "-", "\\", "-", "\"", "", "\n", " ", "\r", " ")
|
||||
name = strings.TrimSpace(replacer.Replace(name))
|
||||
if name == "" {
|
||||
name = "document"
|
||||
}
|
||||
|
||||
return name + ".html"
|
||||
}
|
||||
|
||||
func contentTypeOrDefault(value string) string {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return "text/html; charset=utf-8"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
var errEmptyMarkdown = errors.New("markdown must not be empty")
|
||||
|
||||
type errMarkdownTooLarge struct {
|
||||
limit int64
|
||||
}
|
||||
|
||||
func (e errMarkdownTooLarge) Error() string {
|
||||
return fmt.Sprintf("markdown exceeds %d bytes", e.limit)
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
chimiddleware "github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
func MaxBytesMiddleware(limit int64) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Body != nil {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, limit)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func CORSMiddleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
headers := w.Header()
|
||||
headers.Set("Access-Control-Allow-Origin", "*")
|
||||
headers.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
|
||||
headers.Set("Access-Control-Allow-Headers", "content-type")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RequestLogger(log *slog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ww := chimiddleware.NewWrapResponseWriter(w, r.ProtoMajor)
|
||||
start := time.Now()
|
||||
|
||||
next.ServeHTTP(ww, r)
|
||||
|
||||
log.Info(
|
||||
"http_request",
|
||||
"request_id", chimiddleware.GetReqID(r.Context()),
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", ww.Status(),
|
||||
"bytes", ww.BytesWritten(),
|
||||
"duration", time.Since(start),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const janitorInterval = 5 * time.Minute
|
||||
|
||||
type PreviewStore struct {
|
||||
mu sync.Mutex
|
||||
items map[string]previewItem
|
||||
ttl time.Duration
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type previewItem struct {
|
||||
html []byte
|
||||
mime string
|
||||
filename string
|
||||
expires time.Time
|
||||
}
|
||||
|
||||
func NewPreviewStore(ttl time.Duration) *PreviewStore {
|
||||
return &PreviewStore{
|
||||
items: make(map[string]previewItem),
|
||||
ttl: ttl,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PreviewStore) Put(html []byte, mime, filename string) string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
id := uuid.NewString()
|
||||
s.items[id] = previewItem{
|
||||
html: append([]byte(nil), html...),
|
||||
mime: mime,
|
||||
filename: filename,
|
||||
expires: s.now().Add(s.ttl),
|
||||
}
|
||||
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *PreviewStore) Take(id string) (previewItem, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
item, ok := s.items[id]
|
||||
if !ok {
|
||||
return previewItem{}, false
|
||||
}
|
||||
|
||||
delete(s.items, id)
|
||||
if s.now().After(item.expires) {
|
||||
return previewItem{}, false
|
||||
}
|
||||
|
||||
item.html = append([]byte(nil), item.html...)
|
||||
return item, true
|
||||
}
|
||||
|
||||
func (s *PreviewStore) janitor(ctx context.Context) {
|
||||
ticker := time.NewTicker(janitorInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case now := <-ticker.C:
|
||||
s.cleanupExpired(now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PreviewStore) cleanupExpired(now time.Time) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for id, item := range s.items {
|
||||
if now.After(item.expires) {
|
||||
delete(s.items, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPreviewStore_OneShot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := NewPreviewStore(time.Hour)
|
||||
id := store.Put([]byte("<h1>Hello</h1>"), "text/html; charset=utf-8", "hello.html")
|
||||
|
||||
item, ok := store.Take(id)
|
||||
if !ok {
|
||||
t.Fatalf("expected first take to succeed")
|
||||
}
|
||||
if got := string(item.html); got != "<h1>Hello</h1>" {
|
||||
t.Fatalf("unexpected html: %q", got)
|
||||
}
|
||||
|
||||
if _, ok := store.Take(id); ok {
|
||||
t.Fatalf("expected second take to miss")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreviewStore_TTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := NewPreviewStore(10 * time.Millisecond)
|
||||
id := store.Put([]byte("expired"), "text/html; charset=utf-8", "expired.html")
|
||||
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
store.cleanupExpired(time.Now())
|
||||
|
||||
if _, ok := store.Take(id); ok {
|
||||
t.Fatalf("expected expired item to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreviewStore_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := NewPreviewStore(time.Hour)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 32; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
id := store.Put([]byte("payload"), "text/html; charset=utf-8", "payload.html")
|
||||
store.Take(id)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPreviewStore_JanitorStopsWithContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := NewPreviewStore(time.Hour)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
store.janitor(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("janitor did not stop after context cancellation")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/fserg/md-to-html/internal/converter"
|
||||
"github.com/go-chi/chi/v5"
|
||||
chimiddleware "github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
func New(cfg Config, conv *converter.Converter) (*Server, error) {
|
||||
if conv == nil {
|
||||
return nil, errors.New("converter is required")
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
}))
|
||||
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
conv: conv,
|
||||
store: NewPreviewStore(cfg.PreviewTTL),
|
||||
log: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Router() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(chimiddleware.RequestID)
|
||||
r.Use(CORSMiddleware())
|
||||
r.Use(MaxBytesMiddleware(s.cfg.MaxRequestBytes))
|
||||
r.Use(RequestLogger(s.log))
|
||||
r.Use(chimiddleware.Recoverer)
|
||||
r.Use(chimiddleware.Timeout(30 * time.Second))
|
||||
|
||||
r.Get("/", s.handleHome)
|
||||
r.Post("/convert", s.handleConvert)
|
||||
r.Get("/health", s.handleHealth)
|
||||
r.Get("/version", s.handleVersion)
|
||||
r.Get("/ready", s.handleReady)
|
||||
r.Post("/ui/convert", s.handleUIConvert)
|
||||
r.Get("/preview/{id}", s.handlePreview)
|
||||
r.Get("/download/{id}", s.handleDownload)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
httpServer := &http.Server{
|
||||
Addr: s.cfg.Addr,
|
||||
Handler: s.Router(),
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go s.store.janitor(ctx)
|
||||
go func() {
|
||||
s.log.Info("server starting", "addr", s.cfg.Addr)
|
||||
errCh <- httpServer.ListenAndServe()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.log.Info("shutting down", "timeout", s.cfg.ShutdownTimeout)
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), s.cfg.ShutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
return fmt.Errorf("shutdown server: %w", err)
|
||||
}
|
||||
|
||||
if err := <-errCh; err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("server exited after shutdown: %w", err)
|
||||
}
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
if err == nil || errors.Is(err, http.ErrServerClosed) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("serve: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,329 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fserg/md-to-html/internal/converter"
|
||||
"github.com/fserg/md-to-html/internal/version"
|
||||
webtemplate "github.com/fserg/md-to-html/web/template"
|
||||
)
|
||||
|
||||
func TestConvertEndpoint(t *testing.T) {
|
||||
srv := newTestServer(t, Config{
|
||||
Addr: ":0",
|
||||
MaxMarkdownBytes: 128,
|
||||
MaxRequestBytes: 256,
|
||||
PreviewTTL: time.Hour,
|
||||
ShutdownTimeout: time.Second,
|
||||
})
|
||||
|
||||
ts := httptest.NewServer(srv.Router())
|
||||
defer ts.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
contentType string
|
||||
wantStatus int
|
||||
wantType string
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "valid markdown",
|
||||
body: `{"markdown":"# Hello"}`,
|
||||
contentType: "application/json",
|
||||
wantStatus: http.StatusOK,
|
||||
wantType: "text/html; charset=utf-8",
|
||||
wantBody: "<!DOCTYPE html>",
|
||||
},
|
||||
{
|
||||
name: "empty markdown",
|
||||
body: `{"markdown":" "}`,
|
||||
contentType: "application/json",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantType: "application/json; charset=utf-8",
|
||||
wantBody: `{"detail":"markdown must not be empty"}`,
|
||||
},
|
||||
{
|
||||
name: "markdown too large",
|
||||
body: `{"markdown":"` + strings.Repeat("a", 129) + `"}`,
|
||||
contentType: "application/json",
|
||||
wantStatus: http.StatusRequestEntityTooLarge,
|
||||
wantType: "application/json; charset=utf-8",
|
||||
wantBody: `{"detail":"markdown exceeds 128 bytes"}`,
|
||||
},
|
||||
{
|
||||
name: "missing content type",
|
||||
body: `{"markdown":"# Hello"}`,
|
||||
contentType: "",
|
||||
wantStatus: http.StatusUnsupportedMediaType,
|
||||
wantType: "application/json; charset=utf-8",
|
||||
wantBody: `{"detail":"content-type must be application/json"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, ts.URL+"/convert", strings.NewReader(tc.body))
|
||||
if err != nil {
|
||||
t.Fatalf("new request: %v", err)
|
||||
}
|
||||
if tc.contentType != "" {
|
||||
req.Header.Set("Content-Type", tc.contentType)
|
||||
}
|
||||
|
||||
resp, err := ts.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != tc.wantStatus {
|
||||
t.Fatalf("status = %d, want %d; body=%s", resp.StatusCode, tc.wantStatus, body)
|
||||
}
|
||||
if got := resp.Header.Get("Content-Type"); got != tc.wantType {
|
||||
t.Fatalf("content-type = %q, want %q", got, tc.wantType)
|
||||
}
|
||||
if !bytes.Contains(body, []byte(tc.wantBody)) {
|
||||
t.Fatalf("body %q does not contain %q", body, tc.wantBody)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertEndpoint_RequestLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := newTestServer(t, Config{
|
||||
Addr: ":0",
|
||||
MaxMarkdownBytes: 1_048_576,
|
||||
MaxRequestBytes: 64,
|
||||
PreviewTTL: time.Hour,
|
||||
ShutdownTimeout: time.Second,
|
||||
})
|
||||
|
||||
ts := httptest.NewServer(srv.Router())
|
||||
defer ts.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, ts.URL+"/convert", strings.NewReader(`{"markdown":"`+strings.Repeat("a", 100)+`"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("new request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := ts.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusRequestEntityTooLarge {
|
||||
t.Fatalf("status = %d, want %d; body=%s", resp.StatusCode, http.StatusRequestEntityTooLarge, body)
|
||||
}
|
||||
if !bytes.Contains(body, []byte(`{"detail":"request exceeds 64 bytes"}`)) {
|
||||
t.Fatalf("unexpected body: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusEndpoints(t *testing.T) {
|
||||
originalVersion := version.Version
|
||||
version.Version = "dev"
|
||||
t.Cleanup(func() {
|
||||
version.Version = originalVersion
|
||||
})
|
||||
|
||||
srv := newTestServer(t, defaultTestConfig())
|
||||
ts := httptest.NewServer(srv.Router())
|
||||
defer ts.Close()
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
want map[string]any
|
||||
}{
|
||||
{path: "/health", want: map[string]any{"status": "ok"}},
|
||||
{path: "/version", want: map[string]any{"version": "dev"}},
|
||||
{path: "/ready", want: map[string]any{"status": "ok", "template_loaded": true}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.path, func(t *testing.T) {
|
||||
resp, err := ts.Client().Get(ts.URL + tc.path)
|
||||
if err != nil {
|
||||
t.Fatalf("get %s: %v", tc.path, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("decode body: %v", err)
|
||||
}
|
||||
|
||||
for key, wantValue := range tc.want {
|
||||
if got[key] != wantValue {
|
||||
t.Fatalf("%s[%q] = %v, want %v", tc.path, key, got[key], wantValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreviewAndDownloadOneShot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := newTestServer(t, defaultTestConfig())
|
||||
previewID := srv.store.Put([]byte("<h1>Preview</h1>"), "text/html; charset=utf-8", "preview.html")
|
||||
downloadID := srv.store.Put([]byte("<h1>Download</h1>"), "text/html; charset=utf-8", "download.html")
|
||||
|
||||
ts := httptest.NewServer(srv.Router())
|
||||
defer ts.Close()
|
||||
|
||||
resp, err := ts.Client().Get(ts.URL + "/preview/" + previewID)
|
||||
if err != nil {
|
||||
t.Fatalf("get preview: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read preview body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("preview status = %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
if got := resp.Header.Get("Cache-Control"); got != "no-store" {
|
||||
t.Fatalf("preview cache-control = %q, want %q", got, "no-store")
|
||||
}
|
||||
if string(body) != "<h1>Preview</h1>" {
|
||||
t.Fatalf("preview body = %q", body)
|
||||
}
|
||||
|
||||
resp, err = ts.Client().Get(ts.URL + "/preview/" + previewID)
|
||||
if err != nil {
|
||||
t.Fatalf("get preview second time: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Fatalf("second preview status = %d, want %d", resp.StatusCode, http.StatusNotFound)
|
||||
}
|
||||
|
||||
resp, err = ts.Client().Get(ts.URL + "/download/" + downloadID)
|
||||
if err != nil {
|
||||
t.Fatalf("get download: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("download status = %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
if got := resp.Header.Get("Content-Disposition"); !strings.Contains(got, `attachment; filename=preview.html`) && !strings.Contains(got, `attachment; filename=download.html`) {
|
||||
t.Fatalf("unexpected content-disposition: %q", got)
|
||||
}
|
||||
|
||||
resp, err = ts.Client().Get(ts.URL + "/download/" + downloadID)
|
||||
if err != nil {
|
||||
t.Fatalf("get download second time: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Fatalf("second download status = %d, want %d", resp.StatusCode, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreviewMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := newTestServer(t, defaultTestConfig())
|
||||
ts := httptest.NewServer(srv.Router())
|
||||
defer ts.Close()
|
||||
|
||||
resp, err := ts.Client().Get(ts.URL + "/preview/nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("get preview: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSPreflight(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := newTestServer(t, defaultTestConfig())
|
||||
ts := httptest.NewServer(srv.Router())
|
||||
defer ts.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodOptions, ts.URL+"/convert", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("new request: %v", err)
|
||||
}
|
||||
req.Header.Set("Origin", "https://evil.com")
|
||||
req.Header.Set("Access-Control-Request-Method", http.MethodPost)
|
||||
|
||||
resp, err := ts.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("do request: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Fatalf("allow-origin = %q, want %q", got, "*")
|
||||
}
|
||||
if got := resp.Header.Get("Access-Control-Allow-Methods"); got != "POST, GET, OPTIONS" {
|
||||
t.Fatalf("allow-methods = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T, cfg Config) *Server {
|
||||
t.Helper()
|
||||
|
||||
conv, err := converter.New(webtemplate.FS)
|
||||
if err != nil {
|
||||
t.Fatalf("new converter: %v", err)
|
||||
}
|
||||
|
||||
srv, err := New(cfg, conv)
|
||||
if err != nil {
|
||||
t.Fatalf("new server: %v", err)
|
||||
}
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
func defaultTestConfig() Config {
|
||||
return Config{
|
||||
Addr: ":0",
|
||||
MaxMarkdownBytes: 1_048_576,
|
||||
MaxRequestBytes: 1_200_000,
|
||||
PreviewTTL: time.Hour,
|
||||
ShutdownTimeout: time.Second,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user