diff options
author | Brian Picciano <mediocregopher@gmail.com> | 2021-08-30 20:08:51 -0600 |
---|---|---|
committer | Brian Picciano <mediocregopher@gmail.com> | 2021-08-30 20:44:45 -0600 |
commit | 9343d2ea697f13e52e9199fce62a959f1954f580 (patch) | |
tree | e1e36e330a3c9891bfd8a625229a9b417ad89afa /srv/api | |
parent | 3e9a17abb9a9d63af3c260fba9dc404dd9c59ade (diff) |
add chat handlers and only allow POST methods
Diffstat (limited to 'srv/api')
-rw-r--r-- | srv/api/api.go | 9 | ||||
-rw-r--r-- | srv/api/apiutils/apiutils.go | 112 | ||||
-rw-r--r-- | srv/api/chat.go | 90 | ||||
-rw-r--r-- | srv/api/csrf.go | 14 | ||||
-rw-r--r-- | srv/api/mailinglist.go | 23 | ||||
-rw-r--r-- | srv/api/middleware.go | 15 | ||||
-rw-r--r-- | srv/api/pow.go | 10 | ||||
-rw-r--r-- | srv/api/utils.go | 91 |
8 files changed, 236 insertions, 128 deletions
diff --git a/srv/api/api.go b/srv/api/api.go index bbb677a..6ba7ce0 100644 --- a/srv/api/api.go +++ b/srv/api/api.go @@ -26,7 +26,7 @@ type Params struct { PowManager pow.Manager MailingList mailinglist.MailingList GlobalRoom chat.Room - UserIDCalculator chat.UserIDCalculator + UserIDCalculator *chat.UserIDCalculator // ListenProto and ListenAddr are passed into net.Listen to create the // API's listener. Both "tcp" and "unix" protocols are explicitly @@ -165,7 +165,14 @@ func (a *api) handler() http.Handler { apiMux.Handle("/mailinglist/finalize", a.mailingListFinalizeHandler()) apiMux.Handle("/mailinglist/unsubscribe", a.mailingListUnsubscribeHandler()) + apiMux.Handle("/chat/global/", http.StripPrefix("/chat/global", newChatHandler( + a.params.GlobalRoom, + a.params.UserIDCalculator, + a.requirePowMiddleware, + ))) + var apiHandler http.Handler = apiMux + apiHandler = allowedMethod("POST", apiHandler) apiHandler = checkCSRFMiddleware(apiHandler) apiHandler = logMiddleware(a.params.Logger, apiHandler) apiHandler = annotateMiddleware(apiHandler) diff --git a/srv/api/apiutils/apiutils.go b/srv/api/apiutils/apiutils.go new file mode 100644 index 0000000..223c2b9 --- /dev/null +++ b/srv/api/apiutils/apiutils.go @@ -0,0 +1,112 @@ +// Package apiutils contains utilities which are useful for implementing api +// endpoints. +package apiutils + +import ( + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + + "github.com/mediocregopher/mediocre-go-lib/v2/mlog" +) + +type loggerCtxKey int + +// SetRequestLogger sets the given Logger onto the given Request's Context, +// returning a copy. +func SetRequestLogger(r *http.Request, logger *mlog.Logger) *http.Request { + ctx := r.Context() + ctx = context.WithValue(ctx, loggerCtxKey(0), logger) + return r.WithContext(ctx) +} + +// GetRequestLogger returns the Logger which was set by SetRequestLogger onto +// this Request, or nil. +func GetRequestLogger(r *http.Request) *mlog.Logger { + ctx := r.Context() + logger, _ := ctx.Value(loggerCtxKey(0)).(*mlog.Logger) + if logger == nil { + logger = mlog.Null + } + return logger +} + +// JSONResult writes the JSON encoding of the given value as the response body. +func JSONResult(rw http.ResponseWriter, r *http.Request, v interface{}) { + b, err := json.Marshal(v) + if err != nil { + InternalServerError(rw, r, err) + return + } + b = append(b, '\n') + + rw.Header().Set("Content-Type", "application/json") + rw.Write(b) +} + +// BadRequest writes a 400 status and a JSON encoded error struct containing the +// given error as the response body. +func BadRequest(rw http.ResponseWriter, r *http.Request, err error) { + GetRequestLogger(r).Warn(r.Context(), "bad request", err) + + rw.WriteHeader(400) + JSONResult(rw, r, struct { + Error string `json:"error"` + }{ + Error: err.Error(), + }) +} + +// InternalServerError writes a 500 status and a JSON encoded error struct +// containing a generic error as the response body (though it will log the given +// one). +func InternalServerError(rw http.ResponseWriter, r *http.Request, err error) { + GetRequestLogger(r).Error(r.Context(), "internal server error", err) + + rw.WriteHeader(500) + JSONResult(rw, r, struct { + Error string `json:"error"` + }{ + Error: "internal server error", + }) +} + +// StrToInt parses the given string as an integer, or returns the given default +// integer if the string is empty. +func StrToInt(str string, defaultVal int) (int, error) { + if str == "" { + return defaultVal, nil + } + return strconv.Atoi(str) +} + +// GetCookie returns the namd cookie's value, or the given default value if the +// cookie is not set. +// +// This will only return an error if there was an unexpected error parsing the +// Request's cookies. +func GetCookie(r *http.Request, cookieName, defaultVal string) (string, error) { + c, err := r.Cookie(cookieName) + if errors.Is(err, http.ErrNoCookie) { + return defaultVal, nil + } else if err != nil { + return "", fmt.Errorf("reading cookie %q: %w", cookieName, err) + } + + return c.Value, nil +} + +// RandStr returns a human-readable random string with the given number of bytes +// of randomness. +func RandStr(numBytes int) string { + b := make([]byte, numBytes) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return hex.EncodeToString(b) +} diff --git a/srv/api/chat.go b/srv/api/chat.go index 55d9d02..4ac32e4 100644 --- a/srv/api/chat.go +++ b/srv/api/chat.go @@ -7,32 +7,57 @@ import ( "strings" "unicode" + "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils" "github.com/mediocregopher/blog.mediocregopher.com/srv/chat" ) -func (a *api) chatHistoryHandler() http.Handler { +type chatHandler struct { + *http.ServeMux + + room chat.Room + userIDCalc *chat.UserIDCalculator +} + +func newChatHandler( + room chat.Room, userIDCalc *chat.UserIDCalculator, + requirePowMiddleware func(http.Handler) http.Handler, +) http.Handler { + c := &chatHandler{ + ServeMux: http.NewServeMux(), + room: room, + userIDCalc: userIDCalc, + } + + c.Handle("/history", c.historyHandler()) + c.Handle("/user-id", requirePowMiddleware(c.userIDHandler())) + c.Handle("/append", requirePowMiddleware(c.appendHandler())) + + return c +} + +func (c *chatHandler) historyHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - limit, err := strToInt(r.FormValue("limit"), 0) + limit, err := apiutils.StrToInt(r.PostFormValue("limit"), 0) if err != nil { - badRequest(rw, r, fmt.Errorf("invalid limit parameter: %w", err)) + apiutils.BadRequest(rw, r, fmt.Errorf("invalid limit parameter: %w", err)) return } - cursor := r.FormValue("cursor") + cursor := r.PostFormValue("cursor") - cursor, msgs, err := a.params.GlobalRoom.History(r.Context(), chat.HistoryOpts{ + cursor, msgs, err := c.room.History(r.Context(), chat.HistoryOpts{ Limit: limit, Cursor: cursor, }) if argErr := (chat.ErrInvalidArg{}); errors.As(err, &argErr) { - badRequest(rw, r, argErr.Err) + apiutils.BadRequest(rw, r, argErr.Err) return } else if err != nil { - internalServerError(rw, r, err) + apiutils.InternalServerError(rw, r, err) } - jsonResult(rw, r, struct { + apiutils.JSONResult(rw, r, struct { Cursor string `json:"cursor"` Messages []chat.Message `json:"messages"` }{ @@ -42,7 +67,7 @@ func (a *api) chatHistoryHandler() http.Handler { }) } -func (a *api) getUserID(r *http.Request) (chat.UserID, error) { +func (c *chatHandler) userID(r *http.Request) (chat.UserID, error) { name := r.PostFormValue("name") if l := len(name); l == 0 { return chat.UserID{}, errors.New("name is required") @@ -68,21 +93,58 @@ func (a *api) getUserID(r *http.Request) (chat.UserID, error) { return chat.UserID{}, errors.New("password too long") } - return a.params.UserIDCalculator.Calculate(name, password), nil + return c.userIDCalc.Calculate(name, password), nil } -func (a *api) getUserIDHandler() http.Handler { +func (c *chatHandler) userIDHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - userID, err := a.getUserID(r) + userID, err := c.userID(r) if err != nil { - badRequest(rw, r, err) + apiutils.BadRequest(rw, r, err) return } - jsonResult(rw, r, struct { + apiutils.JSONResult(rw, r, struct { UserID chat.UserID `json:"userID"` }{ UserID: userID, }) }) } + +func (c *chatHandler) appendHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + userID, err := c.userID(r) + if err != nil { + apiutils.BadRequest(rw, r, err) + return + } + + body := r.PostFormValue("body") + + if l := len(body); l == 0 { + apiutils.BadRequest(rw, r, errors.New("body is required")) + return + + } else if l > 300 { + apiutils.BadRequest(rw, r, errors.New("body too long")) + return + } + + msg, err := c.room.Append(r.Context(), chat.Message{ + UserID: userID, + Body: body, + }) + + if err != nil { + apiutils.InternalServerError(rw, r, err) + return + } + + apiutils.JSONResult(rw, r, struct { + MessageID string `json:"messageID"` + }{ + MessageID: msg.ID, + }) + }) +} diff --git a/srv/api/csrf.go b/srv/api/csrf.go index d705adb..0802d8a 100644 --- a/srv/api/csrf.go +++ b/srv/api/csrf.go @@ -3,6 +3,8 @@ package api import ( "errors" "net/http" + + "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils" ) const ( @@ -13,16 +15,16 @@ const ( func setCSRFMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - csrfTok, err := getCookie(r, csrfTokenCookieName, "") + csrfTok, err := apiutils.GetCookie(r, csrfTokenCookieName, "") if err != nil { - internalServerError(rw, r, err) + apiutils.InternalServerError(rw, r, err) return } else if csrfTok == "" { http.SetCookie(rw, &http.Cookie{ Name: csrfTokenCookieName, - Value: randStr(32), + Value: apiutils.RandStr(32), Secure: true, }) } @@ -34,14 +36,14 @@ func setCSRFMiddleware(h http.Handler) http.Handler { func checkCSRFMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - csrfTok, err := getCookie(r, csrfTokenCookieName, "") + csrfTok, err := apiutils.GetCookie(r, csrfTokenCookieName, "") if err != nil { - internalServerError(rw, r, err) + apiutils.InternalServerError(rw, r, err) return } else if csrfTok == "" || r.Header.Get(csrfTokenHeaderName) != csrfTok { - badRequest(rw, r, errors.New("invalid CSRF token")) + apiutils.BadRequest(rw, r, errors.New("invalid CSRF token")) return } diff --git a/srv/api/mailinglist.go b/srv/api/mailinglist.go index 2ddfbe6..d89fe2a 100644 --- a/srv/api/mailinglist.go +++ b/srv/api/mailinglist.go @@ -5,6 +5,7 @@ import ( "net/http" "strings" + "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils" "github.com/mediocregopher/blog.mediocregopher.com/srv/mailinglist" ) @@ -15,7 +16,7 @@ func (a *api) mailingListSubscribeHandler() http.Handler { parts[0] == "" || parts[1] == "" || len(email) >= 512 { - badRequest(rw, r, errors.New("invalid email")) + apiutils.BadRequest(rw, r, errors.New("invalid email")) return } @@ -25,11 +26,11 @@ func (a *api) mailingListSubscribeHandler() http.Handler { // just eat the error, make it look to the user like the // verification email was sent. } else if err != nil { - internalServerError(rw, r, err) + apiutils.InternalServerError(rw, r, err) return } - jsonResult(rw, r, struct{}{}) + apiutils.JSONResult(rw, r, struct{}{}) }) } @@ -39,25 +40,25 @@ func (a *api) mailingListFinalizeHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { subToken := r.PostFormValue("subToken") if l := len(subToken); l == 0 || l > 128 { - badRequest(rw, r, errInvalidSubToken) + apiutils.BadRequest(rw, r, errInvalidSubToken) return } err := a.params.MailingList.FinalizeSubscription(subToken) if errors.Is(err, mailinglist.ErrNotFound) { - badRequest(rw, r, errInvalidSubToken) + apiutils.BadRequest(rw, r, errInvalidSubToken) return } else if errors.Is(err, mailinglist.ErrAlreadyVerified) { // no problem } else if err != nil { - internalServerError(rw, r, err) + apiutils.InternalServerError(rw, r, err) return } - jsonResult(rw, r, struct{}{}) + apiutils.JSONResult(rw, r, struct{}{}) }) } @@ -67,21 +68,21 @@ func (a *api) mailingListUnsubscribeHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { unsubToken := r.PostFormValue("unsubToken") if l := len(unsubToken); l == 0 || l > 128 { - badRequest(rw, r, errInvalidUnsubToken) + apiutils.BadRequest(rw, r, errInvalidUnsubToken) return } err := a.params.MailingList.Unsubscribe(unsubToken) if errors.Is(err, mailinglist.ErrNotFound) { - badRequest(rw, r, errInvalidUnsubToken) + apiutils.BadRequest(rw, r, errInvalidUnsubToken) return } else if err != nil { - internalServerError(rw, r, err) + apiutils.InternalServerError(rw, r, err) return } - jsonResult(rw, r, struct{}{}) + apiutils.JSONResult(rw, r, struct{}{}) }) } diff --git a/srv/api/middleware.go b/srv/api/middleware.go index e3e85bb..2605d93 100644 --- a/srv/api/middleware.go +++ b/srv/api/middleware.go @@ -5,6 +5,7 @@ import ( "net/http" "time" + "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils" "github.com/mediocregopher/mediocre-go-lib/v2/mctx" "github.com/mediocregopher/mediocre-go-lib/v2/mlog" ) @@ -57,7 +58,7 @@ func (lrw *logResponseWriter) WriteHeader(statusCode int) { func logMiddleware(logger *mlog.Logger, h http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - r = setRequestLogger(r, logger) + r = apiutils.SetRequestLogger(r, logger) lrw := newLogResponseWriter(rw) @@ -76,3 +77,15 @@ func logMiddleware(logger *mlog.Logger, h http.Handler) http.Handler { logger.Info(ctx, "handled HTTP request") }) } + +func allowedMethod(method string, h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.Method == method { + h.ServeHTTP(rw, r) + return + } + + apiutils.GetRequestLogger(r).WarnString(r.Context(), "method not allowed") + rw.WriteHeader(405) + }) +} diff --git a/srv/api/pow.go b/srv/api/pow.go index 096e252..6d11061 100644 --- a/srv/api/pow.go +++ b/srv/api/pow.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "net/http" + + "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils" ) func (a *api) newPowChallengeHandler() http.Handler { @@ -12,7 +14,7 @@ func (a *api) newPowChallengeHandler() http.Handler { challenge := a.params.PowManager.NewChallenge() - jsonResult(rw, r, struct { + apiutils.JSONResult(rw, r, struct { Seed string `json:"seed"` Target uint32 `json:"target"` }{ @@ -28,21 +30,21 @@ func (a *api) requirePowMiddleware(h http.Handler) http.Handler { seedHex := r.PostFormValue("powSeed") seed, err := hex.DecodeString(seedHex) if err != nil || len(seed) == 0 { - badRequest(rw, r, errors.New("invalid powSeed")) + apiutils.BadRequest(rw, r, errors.New("invalid powSeed")) return } solutionHex := r.PostFormValue("powSolution") solution, err := hex.DecodeString(solutionHex) if err != nil || len(seed) == 0 { - badRequest(rw, r, errors.New("invalid powSolution")) + apiutils.BadRequest(rw, r, errors.New("invalid powSolution")) return } err = a.params.PowManager.CheckSolution(seed, solution) if err != nil { - badRequest(rw, r, fmt.Errorf("checking proof-of-work solution: %w", err)) + apiutils.BadRequest(rw, r, fmt.Errorf("checking proof-of-work solution: %w", err)) return } diff --git a/srv/api/utils.go b/srv/api/utils.go deleted file mode 100644 index 2cf40b6..0000000 --- a/srv/api/utils.go +++ /dev/null @@ -1,91 +0,0 @@ -package api - -import ( - "context" - "crypto/rand" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "net/http" - "strconv" - - "github.com/mediocregopher/mediocre-go-lib/v2/mlog" -) - -type loggerCtxKey int - -func setRequestLogger(r *http.Request, logger *mlog.Logger) *http.Request { - ctx := r.Context() - ctx = context.WithValue(ctx, loggerCtxKey(0), logger) - return r.WithContext(ctx) -} - -func getRequestLogger(r *http.Request) *mlog.Logger { - ctx := r.Context() - logger, _ := ctx.Value(loggerCtxKey(0)).(*mlog.Logger) - if logger == nil { - logger = mlog.Null - } - return logger -} - -func jsonResult(rw http.ResponseWriter, r *http.Request, v interface{}) { - b, err := json.Marshal(v) - if err != nil { - internalServerError(rw, r, err) - return - } - b = append(b, '\n') - - rw.Header().Set("Content-Type", "application/json") - rw.Write(b) -} - -func badRequest(rw http.ResponseWriter, r *http.Request, err error) { - getRequestLogger(r).Warn(r.Context(), "bad request", err) - - rw.WriteHeader(400) - jsonResult(rw, r, struct { - Error string `json:"error"` - }{ - Error: err.Error(), - }) -} - -func internalServerError(rw http.ResponseWriter, r *http.Request, err error) { - getRequestLogger(r).Error(r.Context(), "internal server error", err) - - rw.WriteHeader(500) - jsonResult(rw, r, struct { - Error string `json:"error"` - }{ - Error: "internal server error", - }) -} - -func strToInt(str string, defaultVal int) (int, error) { - if str == "" { - return defaultVal, nil - } - return strconv.Atoi(str) -} - -func getCookie(r *http.Request, cookieName, defaultVal string) (string, error) { - c, err := r.Cookie(cookieName) - if errors.Is(err, http.ErrNoCookie) { - return defaultVal, nil - } else if err != nil { - return "", fmt.Errorf("reading cookie %q: %w", cookieName, err) - } - - return c.Value, nil -} - -func randStr(numBytesEntropy int) string { - b := make([]byte, numBytesEntropy) - if _, err := rand.Read(b); err != nil { - panic(err) - } - return hex.EncodeToString(b) -} |