diff options
Diffstat (limited to 'srv/src/api')
-rw-r--r-- | srv/src/api/api.go | 188 | ||||
-rw-r--r-- | srv/src/api/apiutils/apiutils.go | 112 | ||||
-rw-r--r-- | srv/src/api/chat.go | 211 | ||||
-rw-r--r-- | srv/src/api/csrf.go | 58 | ||||
-rw-r--r-- | srv/src/api/mailinglist.go | 88 | ||||
-rw-r--r-- | srv/src/api/middleware.go | 96 | ||||
-rw-r--r-- | srv/src/api/pow.go | 53 |
7 files changed, 806 insertions, 0 deletions
diff --git a/srv/src/api/api.go b/srv/src/api/api.go new file mode 100644 index 0000000..56f33b2 --- /dev/null +++ b/srv/src/api/api.go @@ -0,0 +1,188 @@ +// Package api implements the HTTP-based api for the mediocre-blog. +package api + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" + "os" + + "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg" + "github.com/mediocregopher/blog.mediocregopher.com/srv/chat" + "github.com/mediocregopher/blog.mediocregopher.com/srv/mailinglist" + "github.com/mediocregopher/blog.mediocregopher.com/srv/pow" + "github.com/mediocregopher/mediocre-go-lib/v2/mctx" + "github.com/mediocregopher/mediocre-go-lib/v2/mlog" +) + +// Params are used to instantiate a new API instance. All fields are required +// unless otherwise noted. +type Params struct { + Logger *mlog.Logger + PowManager pow.Manager + MailingList mailinglist.MailingList + GlobalRoom chat.Room + UserIDCalculator *chat.UserIDCalculator + + // ListenProto and ListenAddr are passed into net.Listen to create the + // API's listener. Both "tcp" and "unix" protocols are explicitly + // supported. + ListenProto, ListenAddr string + + // StaticDir and StaticProxy are mutually exclusive. + // + // If StaticDir is set then that directory on the filesystem will be used to + // serve the static site. + // + // Otherwise if StaticProxy is set all requests for the static site will be + // reverse-proxied there. + StaticDir string + StaticProxy *url.URL +} + +// SetupCfg implement the cfg.Cfger interface. +func (p *Params) SetupCfg(cfg *cfg.Cfg) { + + cfg.StringVar(&p.ListenProto, "listen-proto", "tcp", "Protocol to listen for HTTP requests with") + cfg.StringVar(&p.ListenAddr, "listen-addr", ":4000", "Address/path to listen for HTTP requests on") + + cfg.StringVar(&p.StaticDir, "static-dir", "", "Directory from which static files are served (mutually exclusive with -static-proxy-url)") + staticProxyURLStr := cfg.String("static-proxy-url", "", "HTTP address from which static files are served (mutually exclusive with -static-dir)") + + cfg.OnInit(func(ctx context.Context) error { + if *staticProxyURLStr != "" { + var err error + if p.StaticProxy, err = url.Parse(*staticProxyURLStr); err != nil { + return fmt.Errorf("parsing -static-proxy-url: %w", err) + } + + } else if p.StaticDir == "" { + return errors.New("-static-dir or -static-proxy-url is required") + } + + return nil + }) +} + +// Annotate implements mctx.Annotator interface. +func (p *Params) Annotate(a mctx.Annotations) { + a["listenProto"] = p.ListenProto + a["listenAddr"] = p.ListenAddr + + if p.StaticProxy != nil { + a["staticProxy"] = p.StaticProxy.String() + return + } + + a["staticDir"] = p.StaticDir +} + +// API will listen on the port configured for it, and serve HTTP requests for +// the mediocre-blog. +type API interface { + Shutdown(ctx context.Context) error +} + +type api struct { + params Params + srv *http.Server +} + +// New initializes and returns a new API instance, including setting up all +// listening ports. +func New(params Params) (API, error) { + + l, err := net.Listen(params.ListenProto, params.ListenAddr) + if err != nil { + return nil, fmt.Errorf("creating listen socket: %w", err) + } + + if params.ListenProto == "unix" { + if err := os.Chmod(params.ListenAddr, 0777); err != nil { + return nil, fmt.Errorf("chmod-ing unix socket: %w", err) + } + } + + a := &api{ + params: params, + } + + a.srv = &http.Server{Handler: a.handler()} + + go func() { + + err := a.srv.Serve(l) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + ctx := mctx.Annotate(context.Background(), a.params) + params.Logger.Fatal(ctx, "serving http server", err) + } + }() + + return a, nil +} + +func (a *api) Shutdown(ctx context.Context) error { + if err := a.srv.Shutdown(ctx); err != nil { + return err + } + + return nil +} + +func (a *api) handler() http.Handler { + + var staticHandler http.Handler + if a.params.StaticDir != "" { + staticHandler = http.FileServer(http.Dir(a.params.StaticDir)) + } else { + staticHandler = httputil.NewSingleHostReverseProxy(a.params.StaticProxy) + } + + staticHandler = setCSRFMiddleware(staticHandler) + + // sugar + requirePow := func(h http.Handler) http.Handler { + return a.requirePowMiddleware(h) + } + + mux := http.NewServeMux() + + mux.Handle("/", staticHandler) + + apiMux := http.NewServeMux() + apiMux.Handle("/pow/challenge", a.newPowChallengeHandler()) + apiMux.Handle("/pow/check", + requirePow( + http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), + ), + ) + + apiMux.Handle("/mailinglist/subscribe", requirePow(a.mailingListSubscribeHandler())) + apiMux.Handle("/mailinglist/finalize", a.mailingListFinalizeHandler()) + apiMux.Handle("/mailinglist/unsubscribe", a.mailingListUnsubscribeHandler()) + + apiMux.Handle("/chat/global/", http.StripPrefix("/chat/global", newChatHandler( + a.params.GlobalRoom, + a.params.UserIDCalculator, + a.requirePowMiddleware, + ))) + + var apiHandler http.Handler = apiMux + apiHandler = postOnlyMiddleware(apiHandler) + apiHandler = checkCSRFMiddleware(apiHandler) + apiHandler = logMiddleware(a.params.Logger, apiHandler) + apiHandler = annotateMiddleware(apiHandler) + apiHandler = addResponseHeaders(map[string]string{ + "Cache-Control": "no-store, max-age=0", + "Pragma": "no-cache", + "Expires": "0", + }, apiHandler) + + mux.Handle("/api/", http.StripPrefix("/api", apiHandler)) + + return mux +} diff --git a/srv/src/api/apiutils/apiutils.go b/srv/src/api/apiutils/apiutils.go new file mode 100644 index 0000000..223c2b9 --- /dev/null +++ b/srv/src/api/apiutils/apiutils.go @@ -0,0 +1,112 @@ +// Package apiutils contains utilities which are useful for implementing api +// endpoints. +package apiutils + +import ( + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + + "github.com/mediocregopher/mediocre-go-lib/v2/mlog" +) + +type loggerCtxKey int + +// SetRequestLogger sets the given Logger onto the given Request's Context, +// returning a copy. +func SetRequestLogger(r *http.Request, logger *mlog.Logger) *http.Request { + ctx := r.Context() + ctx = context.WithValue(ctx, loggerCtxKey(0), logger) + return r.WithContext(ctx) +} + +// GetRequestLogger returns the Logger which was set by SetRequestLogger onto +// this Request, or nil. +func GetRequestLogger(r *http.Request) *mlog.Logger { + ctx := r.Context() + logger, _ := ctx.Value(loggerCtxKey(0)).(*mlog.Logger) + if logger == nil { + logger = mlog.Null + } + return logger +} + +// JSONResult writes the JSON encoding of the given value as the response body. +func JSONResult(rw http.ResponseWriter, r *http.Request, v interface{}) { + b, err := json.Marshal(v) + if err != nil { + InternalServerError(rw, r, err) + return + } + b = append(b, '\n') + + rw.Header().Set("Content-Type", "application/json") + rw.Write(b) +} + +// BadRequest writes a 400 status and a JSON encoded error struct containing the +// given error as the response body. +func BadRequest(rw http.ResponseWriter, r *http.Request, err error) { + GetRequestLogger(r).Warn(r.Context(), "bad request", err) + + rw.WriteHeader(400) + JSONResult(rw, r, struct { + Error string `json:"error"` + }{ + Error: err.Error(), + }) +} + +// InternalServerError writes a 500 status and a JSON encoded error struct +// containing a generic error as the response body (though it will log the given +// one). +func InternalServerError(rw http.ResponseWriter, r *http.Request, err error) { + GetRequestLogger(r).Error(r.Context(), "internal server error", err) + + rw.WriteHeader(500) + JSONResult(rw, r, struct { + Error string `json:"error"` + }{ + Error: "internal server error", + }) +} + +// StrToInt parses the given string as an integer, or returns the given default +// integer if the string is empty. +func StrToInt(str string, defaultVal int) (int, error) { + if str == "" { + return defaultVal, nil + } + return strconv.Atoi(str) +} + +// GetCookie returns the namd cookie's value, or the given default value if the +// cookie is not set. +// +// This will only return an error if there was an unexpected error parsing the +// Request's cookies. +func GetCookie(r *http.Request, cookieName, defaultVal string) (string, error) { + c, err := r.Cookie(cookieName) + if errors.Is(err, http.ErrNoCookie) { + return defaultVal, nil + } else if err != nil { + return "", fmt.Errorf("reading cookie %q: %w", cookieName, err) + } + + return c.Value, nil +} + +// RandStr returns a human-readable random string with the given number of bytes +// of randomness. +func RandStr(numBytes int) string { + b := make([]byte, numBytes) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return hex.EncodeToString(b) +} diff --git a/srv/src/api/chat.go b/srv/src/api/chat.go new file mode 100644 index 0000000..a1acc5a --- /dev/null +++ b/srv/src/api/chat.go @@ -0,0 +1,211 @@ +package api + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "unicode" + + "github.com/gorilla/websocket" + "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils" + "github.com/mediocregopher/blog.mediocregopher.com/srv/chat" +) + +type chatHandler struct { + *http.ServeMux + + room chat.Room + userIDCalc *chat.UserIDCalculator + + wsUpgrader websocket.Upgrader +} + +func newChatHandler( + room chat.Room, userIDCalc *chat.UserIDCalculator, + requirePowMiddleware func(http.Handler) http.Handler, +) http.Handler { + c := &chatHandler{ + ServeMux: http.NewServeMux(), + room: room, + userIDCalc: userIDCalc, + + wsUpgrader: websocket.Upgrader{}, + } + + c.Handle("/history", c.historyHandler()) + c.Handle("/user-id", requirePowMiddleware(c.userIDHandler())) + c.Handle("/append", requirePowMiddleware(c.appendHandler())) + c.Handle("/listen", c.listenHandler()) + + return c +} + +func (c *chatHandler) historyHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + limit, err := apiutils.StrToInt(r.PostFormValue("limit"), 0) + if err != nil { + apiutils.BadRequest(rw, r, fmt.Errorf("invalid limit parameter: %w", err)) + return + } + + cursor := r.PostFormValue("cursor") + + cursor, msgs, err := c.room.History(r.Context(), chat.HistoryOpts{ + Limit: limit, + Cursor: cursor, + }) + + if argErr := (chat.ErrInvalidArg{}); errors.As(err, &argErr) { + apiutils.BadRequest(rw, r, argErr.Err) + return + } else if err != nil { + apiutils.InternalServerError(rw, r, err) + } + + apiutils.JSONResult(rw, r, struct { + Cursor string `json:"cursor"` + Messages []chat.Message `json:"messages"` + }{ + Cursor: cursor, + Messages: msgs, + }) + }) +} + +func (c *chatHandler) userID(r *http.Request) (chat.UserID, error) { + name := r.PostFormValue("name") + if l := len(name); l == 0 { + return chat.UserID{}, errors.New("name is required") + } else if l > 16 { + return chat.UserID{}, errors.New("name too long") + } + + nameClean := strings.Map(func(r rune) rune { + if !unicode.IsPrint(r) { + return -1 + } + return r + }, name) + + if nameClean != name { + return chat.UserID{}, errors.New("name contains invalid characters") + } + + password := r.PostFormValue("password") + if l := len(password); l == 0 { + return chat.UserID{}, errors.New("password is required") + } else if l > 128 { + return chat.UserID{}, errors.New("password too long") + } + + return c.userIDCalc.Calculate(name, password), nil +} + +func (c *chatHandler) userIDHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + userID, err := c.userID(r) + if err != nil { + apiutils.BadRequest(rw, r, err) + return + } + + apiutils.JSONResult(rw, r, struct { + UserID chat.UserID `json:"userID"` + }{ + UserID: userID, + }) + }) +} + +func (c *chatHandler) appendHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + userID, err := c.userID(r) + if err != nil { + apiutils.BadRequest(rw, r, err) + return + } + + body := r.PostFormValue("body") + + if l := len(body); l == 0 { + apiutils.BadRequest(rw, r, errors.New("body is required")) + return + + } else if l > 300 { + apiutils.BadRequest(rw, r, errors.New("body too long")) + return + } + + msg, err := c.room.Append(r.Context(), chat.Message{ + UserID: userID, + Body: body, + }) + + if err != nil { + apiutils.InternalServerError(rw, r, err) + return + } + + apiutils.JSONResult(rw, r, struct { + MessageID string `json:"messageID"` + }{ + MessageID: msg.ID, + }) + }) +} + +func (c *chatHandler) listenHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + ctx := r.Context() + sinceID := r.FormValue("sinceID") + + conn, err := c.wsUpgrader.Upgrade(rw, r, nil) + if err != nil { + apiutils.BadRequest(rw, r, err) + return + } + defer conn.Close() + + it, err := c.room.Listen(ctx, sinceID) + + if errors.As(err, new(chat.ErrInvalidArg)) { + apiutils.BadRequest(rw, r, err) + return + + } else if errors.Is(err, context.Canceled) { + return + + } else if err != nil { + apiutils.InternalServerError(rw, r, err) + return + } + + defer it.Close() + + for { + + msg, err := it.Next(ctx) + if errors.Is(err, context.Canceled) { + return + + } else if err != nil { + apiutils.InternalServerError(rw, r, err) + return + } + + err = conn.WriteJSON(struct { + Message chat.Message `json:"message"` + }{ + Message: msg, + }) + + if err != nil { + apiutils.GetRequestLogger(r).Error(ctx, "couldn't write message", err) + return + } + } + }) +} diff --git a/srv/src/api/csrf.go b/srv/src/api/csrf.go new file mode 100644 index 0000000..13b6ec6 --- /dev/null +++ b/srv/src/api/csrf.go @@ -0,0 +1,58 @@ +package api + +import ( + "errors" + "net/http" + + "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils" +) + +const ( + csrfTokenCookieName = "csrf_token" + csrfTokenHeaderName = "X-CSRF-Token" +) + +func setCSRFMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + csrfTok, err := apiutils.GetCookie(r, csrfTokenCookieName, "") + + if err != nil { + apiutils.InternalServerError(rw, r, err) + return + + } else if csrfTok == "" { + http.SetCookie(rw, &http.Cookie{ + Name: csrfTokenCookieName, + Value: apiutils.RandStr(32), + Secure: true, + }) + } + + h.ServeHTTP(rw, r) + }) +} + +func checkCSRFMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + csrfTok, err := apiutils.GetCookie(r, csrfTokenCookieName, "") + + if err != nil { + apiutils.InternalServerError(rw, r, err) + return + } + + givenCSRFTok := r.Header.Get(csrfTokenHeaderName) + if givenCSRFTok == "" { + givenCSRFTok = r.FormValue("csrfToken") + } + + if csrfTok == "" || givenCSRFTok != csrfTok { + apiutils.BadRequest(rw, r, errors.New("invalid CSRF token")) + return + } + + h.ServeHTTP(rw, r) + }) +} diff --git a/srv/src/api/mailinglist.go b/srv/src/api/mailinglist.go new file mode 100644 index 0000000..d89fe2a --- /dev/null +++ b/srv/src/api/mailinglist.go @@ -0,0 +1,88 @@ +package api + +import ( + "errors" + "net/http" + "strings" + + "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils" + "github.com/mediocregopher/blog.mediocregopher.com/srv/mailinglist" +) + +func (a *api) mailingListSubscribeHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + email := r.PostFormValue("email") + if parts := strings.Split(email, "@"); len(parts) != 2 || + parts[0] == "" || + parts[1] == "" || + len(email) >= 512 { + apiutils.BadRequest(rw, r, errors.New("invalid email")) + return + } + + err := a.params.MailingList.BeginSubscription(email) + + if errors.Is(err, mailinglist.ErrAlreadyVerified) { + // just eat the error, make it look to the user like the + // verification email was sent. + } else if err != nil { + apiutils.InternalServerError(rw, r, err) + return + } + + apiutils.JSONResult(rw, r, struct{}{}) + }) +} + +func (a *api) mailingListFinalizeHandler() http.Handler { + var errInvalidSubToken = errors.New("invalid subToken") + + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + subToken := r.PostFormValue("subToken") + if l := len(subToken); l == 0 || l > 128 { + apiutils.BadRequest(rw, r, errInvalidSubToken) + return + } + + err := a.params.MailingList.FinalizeSubscription(subToken) + + if errors.Is(err, mailinglist.ErrNotFound) { + apiutils.BadRequest(rw, r, errInvalidSubToken) + return + + } else if errors.Is(err, mailinglist.ErrAlreadyVerified) { + // no problem + + } else if err != nil { + apiutils.InternalServerError(rw, r, err) + return + } + + apiutils.JSONResult(rw, r, struct{}{}) + }) +} + +func (a *api) mailingListUnsubscribeHandler() http.Handler { + var errInvalidUnsubToken = errors.New("invalid unsubToken") + + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + unsubToken := r.PostFormValue("unsubToken") + if l := len(unsubToken); l == 0 || l > 128 { + apiutils.BadRequest(rw, r, errInvalidUnsubToken) + return + } + + err := a.params.MailingList.Unsubscribe(unsubToken) + + if errors.Is(err, mailinglist.ErrNotFound) { + apiutils.BadRequest(rw, r, errInvalidUnsubToken) + return + + } else if err != nil { + apiutils.InternalServerError(rw, r, err) + return + } + + apiutils.JSONResult(rw, r, struct{}{}) + }) +} diff --git a/srv/src/api/middleware.go b/srv/src/api/middleware.go new file mode 100644 index 0000000..6ea0d13 --- /dev/null +++ b/srv/src/api/middleware.go @@ -0,0 +1,96 @@ +package api + +import ( + "net" + "net/http" + "time" + + "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils" + "github.com/mediocregopher/mediocre-go-lib/v2/mctx" + "github.com/mediocregopher/mediocre-go-lib/v2/mlog" +) + +func addResponseHeaders(headers map[string]string, h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + for k, v := range headers { + rw.Header().Set(k, v) + } + h.ServeHTTP(rw, r) + }) +} + +func annotateMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + type reqInfoKey string + + ip, _, _ := net.SplitHostPort(r.RemoteAddr) + + ctx := r.Context() + ctx = mctx.Annotate(ctx, + reqInfoKey("remote_ip"), ip, + reqInfoKey("url"), r.URL, + reqInfoKey("method"), r.Method, + ) + + r = r.WithContext(ctx) + h.ServeHTTP(rw, r) + }) +} + +type logResponseWriter struct { + http.ResponseWriter + http.Hijacker + statusCode int +} + +func newLogResponseWriter(rw http.ResponseWriter) *logResponseWriter { + h, _ := rw.(http.Hijacker) + return &logResponseWriter{ + ResponseWriter: rw, + Hijacker: h, + statusCode: 200, + } +} + +func (lrw *logResponseWriter) WriteHeader(statusCode int) { + lrw.statusCode = statusCode + lrw.ResponseWriter.WriteHeader(statusCode) +} + +func logMiddleware(logger *mlog.Logger, h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + r = apiutils.SetRequestLogger(r, logger) + + lrw := newLogResponseWriter(rw) + + started := time.Now() + h.ServeHTTP(lrw, r) + took := time.Since(started) + + type logCtxKey string + + ctx := r.Context() + ctx = mctx.Annotate(ctx, + logCtxKey("took"), took.String(), + logCtxKey("response_code"), lrw.statusCode, + ) + + logger.Info(ctx, "handled HTTP request") + }) +} + +func postOnlyMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + // we allow websockets to not be POSTs because, well, they can't be + if r.Method == "POST" || r.Header.Get("Upgrade") == "websocket" { + h.ServeHTTP(rw, r) + return + } + + apiutils.GetRequestLogger(r).WarnString(r.Context(), "method not allowed") + rw.WriteHeader(405) + }) +} diff --git a/srv/src/api/pow.go b/srv/src/api/pow.go new file mode 100644 index 0000000..1b232b1 --- /dev/null +++ b/srv/src/api/pow.go @@ -0,0 +1,53 @@ +package api + +import ( + "encoding/hex" + "errors" + "fmt" + "net/http" + + "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils" +) + +func (a *api) newPowChallengeHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + challenge := a.params.PowManager.NewChallenge() + + apiutils.JSONResult(rw, r, struct { + Seed string `json:"seed"` + Target uint32 `json:"target"` + }{ + Seed: hex.EncodeToString(challenge.Seed), + Target: challenge.Target, + }) + }) +} + +func (a *api) requirePowMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + seedHex := r.FormValue("powSeed") + seed, err := hex.DecodeString(seedHex) + if err != nil || len(seed) == 0 { + apiutils.BadRequest(rw, r, errors.New("invalid powSeed")) + return + } + + solutionHex := r.FormValue("powSolution") + solution, err := hex.DecodeString(solutionHex) + if err != nil || len(seed) == 0 { + apiutils.BadRequest(rw, r, errors.New("invalid powSolution")) + return + } + + err = a.params.PowManager.CheckSolution(seed, solution) + + if err != nil { + apiutils.BadRequest(rw, r, fmt.Errorf("checking proof-of-work solution: %w", err)) + return + } + + h.ServeHTTP(rw, r) + }) +} |