summaryrefslogtreecommitdiff
path: root/srv/api
diff options
context:
space:
mode:
Diffstat (limited to 'srv/api')
-rw-r--r--srv/api/api.go9
-rw-r--r--srv/api/apiutils/apiutils.go112
-rw-r--r--srv/api/chat.go90
-rw-r--r--srv/api/csrf.go14
-rw-r--r--srv/api/mailinglist.go23
-rw-r--r--srv/api/middleware.go15
-rw-r--r--srv/api/pow.go10
-rw-r--r--srv/api/utils.go91
8 files changed, 236 insertions, 128 deletions
diff --git a/srv/api/api.go b/srv/api/api.go
index bbb677a..6ba7ce0 100644
--- a/srv/api/api.go
+++ b/srv/api/api.go
@@ -26,7 +26,7 @@ type Params struct {
PowManager pow.Manager
MailingList mailinglist.MailingList
GlobalRoom chat.Room
- UserIDCalculator chat.UserIDCalculator
+ UserIDCalculator *chat.UserIDCalculator
// ListenProto and ListenAddr are passed into net.Listen to create the
// API's listener. Both "tcp" and "unix" protocols are explicitly
@@ -165,7 +165,14 @@ func (a *api) handler() http.Handler {
apiMux.Handle("/mailinglist/finalize", a.mailingListFinalizeHandler())
apiMux.Handle("/mailinglist/unsubscribe", a.mailingListUnsubscribeHandler())
+ apiMux.Handle("/chat/global/", http.StripPrefix("/chat/global", newChatHandler(
+ a.params.GlobalRoom,
+ a.params.UserIDCalculator,
+ a.requirePowMiddleware,
+ )))
+
var apiHandler http.Handler = apiMux
+ apiHandler = allowedMethod("POST", apiHandler)
apiHandler = checkCSRFMiddleware(apiHandler)
apiHandler = logMiddleware(a.params.Logger, apiHandler)
apiHandler = annotateMiddleware(apiHandler)
diff --git a/srv/api/apiutils/apiutils.go b/srv/api/apiutils/apiutils.go
new file mode 100644
index 0000000..223c2b9
--- /dev/null
+++ b/srv/api/apiutils/apiutils.go
@@ -0,0 +1,112 @@
+// Package apiutils contains utilities which are useful for implementing api
+// endpoints.
+package apiutils
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "strconv"
+
+ "github.com/mediocregopher/mediocre-go-lib/v2/mlog"
+)
+
+type loggerCtxKey int
+
+// SetRequestLogger sets the given Logger onto the given Request's Context,
+// returning a copy.
+func SetRequestLogger(r *http.Request, logger *mlog.Logger) *http.Request {
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, loggerCtxKey(0), logger)
+ return r.WithContext(ctx)
+}
+
+// GetRequestLogger returns the Logger which was set by SetRequestLogger onto
+// this Request, or nil.
+func GetRequestLogger(r *http.Request) *mlog.Logger {
+ ctx := r.Context()
+ logger, _ := ctx.Value(loggerCtxKey(0)).(*mlog.Logger)
+ if logger == nil {
+ logger = mlog.Null
+ }
+ return logger
+}
+
+// JSONResult writes the JSON encoding of the given value as the response body.
+func JSONResult(rw http.ResponseWriter, r *http.Request, v interface{}) {
+ b, err := json.Marshal(v)
+ if err != nil {
+ InternalServerError(rw, r, err)
+ return
+ }
+ b = append(b, '\n')
+
+ rw.Header().Set("Content-Type", "application/json")
+ rw.Write(b)
+}
+
+// BadRequest writes a 400 status and a JSON encoded error struct containing the
+// given error as the response body.
+func BadRequest(rw http.ResponseWriter, r *http.Request, err error) {
+ GetRequestLogger(r).Warn(r.Context(), "bad request", err)
+
+ rw.WriteHeader(400)
+ JSONResult(rw, r, struct {
+ Error string `json:"error"`
+ }{
+ Error: err.Error(),
+ })
+}
+
+// InternalServerError writes a 500 status and a JSON encoded error struct
+// containing a generic error as the response body (though it will log the given
+// one).
+func InternalServerError(rw http.ResponseWriter, r *http.Request, err error) {
+ GetRequestLogger(r).Error(r.Context(), "internal server error", err)
+
+ rw.WriteHeader(500)
+ JSONResult(rw, r, struct {
+ Error string `json:"error"`
+ }{
+ Error: "internal server error",
+ })
+}
+
+// StrToInt parses the given string as an integer, or returns the given default
+// integer if the string is empty.
+func StrToInt(str string, defaultVal int) (int, error) {
+ if str == "" {
+ return defaultVal, nil
+ }
+ return strconv.Atoi(str)
+}
+
+// GetCookie returns the namd cookie's value, or the given default value if the
+// cookie is not set.
+//
+// This will only return an error if there was an unexpected error parsing the
+// Request's cookies.
+func GetCookie(r *http.Request, cookieName, defaultVal string) (string, error) {
+ c, err := r.Cookie(cookieName)
+ if errors.Is(err, http.ErrNoCookie) {
+ return defaultVal, nil
+ } else if err != nil {
+ return "", fmt.Errorf("reading cookie %q: %w", cookieName, err)
+ }
+
+ return c.Value, nil
+}
+
+// RandStr returns a human-readable random string with the given number of bytes
+// of randomness.
+func RandStr(numBytes int) string {
+ b := make([]byte, numBytes)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return hex.EncodeToString(b)
+}
diff --git a/srv/api/chat.go b/srv/api/chat.go
index 55d9d02..4ac32e4 100644
--- a/srv/api/chat.go
+++ b/srv/api/chat.go
@@ -7,32 +7,57 @@ import (
"strings"
"unicode"
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils"
"github.com/mediocregopher/blog.mediocregopher.com/srv/chat"
)
-func (a *api) chatHistoryHandler() http.Handler {
+type chatHandler struct {
+ *http.ServeMux
+
+ room chat.Room
+ userIDCalc *chat.UserIDCalculator
+}
+
+func newChatHandler(
+ room chat.Room, userIDCalc *chat.UserIDCalculator,
+ requirePowMiddleware func(http.Handler) http.Handler,
+) http.Handler {
+ c := &chatHandler{
+ ServeMux: http.NewServeMux(),
+ room: room,
+ userIDCalc: userIDCalc,
+ }
+
+ c.Handle("/history", c.historyHandler())
+ c.Handle("/user-id", requirePowMiddleware(c.userIDHandler()))
+ c.Handle("/append", requirePowMiddleware(c.appendHandler()))
+
+ return c
+}
+
+func (c *chatHandler) historyHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- limit, err := strToInt(r.FormValue("limit"), 0)
+ limit, err := apiutils.StrToInt(r.PostFormValue("limit"), 0)
if err != nil {
- badRequest(rw, r, fmt.Errorf("invalid limit parameter: %w", err))
+ apiutils.BadRequest(rw, r, fmt.Errorf("invalid limit parameter: %w", err))
return
}
- cursor := r.FormValue("cursor")
+ cursor := r.PostFormValue("cursor")
- cursor, msgs, err := a.params.GlobalRoom.History(r.Context(), chat.HistoryOpts{
+ cursor, msgs, err := c.room.History(r.Context(), chat.HistoryOpts{
Limit: limit,
Cursor: cursor,
})
if argErr := (chat.ErrInvalidArg{}); errors.As(err, &argErr) {
- badRequest(rw, r, argErr.Err)
+ apiutils.BadRequest(rw, r, argErr.Err)
return
} else if err != nil {
- internalServerError(rw, r, err)
+ apiutils.InternalServerError(rw, r, err)
}
- jsonResult(rw, r, struct {
+ apiutils.JSONResult(rw, r, struct {
Cursor string `json:"cursor"`
Messages []chat.Message `json:"messages"`
}{
@@ -42,7 +67,7 @@ func (a *api) chatHistoryHandler() http.Handler {
})
}
-func (a *api) getUserID(r *http.Request) (chat.UserID, error) {
+func (c *chatHandler) userID(r *http.Request) (chat.UserID, error) {
name := r.PostFormValue("name")
if l := len(name); l == 0 {
return chat.UserID{}, errors.New("name is required")
@@ -68,21 +93,58 @@ func (a *api) getUserID(r *http.Request) (chat.UserID, error) {
return chat.UserID{}, errors.New("password too long")
}
- return a.params.UserIDCalculator.Calculate(name, password), nil
+ return c.userIDCalc.Calculate(name, password), nil
}
-func (a *api) getUserIDHandler() http.Handler {
+func (c *chatHandler) userIDHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- userID, err := a.getUserID(r)
+ userID, err := c.userID(r)
if err != nil {
- badRequest(rw, r, err)
+ apiutils.BadRequest(rw, r, err)
return
}
- jsonResult(rw, r, struct {
+ apiutils.JSONResult(rw, r, struct {
UserID chat.UserID `json:"userID"`
}{
UserID: userID,
})
})
}
+
+func (c *chatHandler) appendHandler() http.Handler {
+ return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ userID, err := c.userID(r)
+ if err != nil {
+ apiutils.BadRequest(rw, r, err)
+ return
+ }
+
+ body := r.PostFormValue("body")
+
+ if l := len(body); l == 0 {
+ apiutils.BadRequest(rw, r, errors.New("body is required"))
+ return
+
+ } else if l > 300 {
+ apiutils.BadRequest(rw, r, errors.New("body too long"))
+ return
+ }
+
+ msg, err := c.room.Append(r.Context(), chat.Message{
+ UserID: userID,
+ Body: body,
+ })
+
+ if err != nil {
+ apiutils.InternalServerError(rw, r, err)
+ return
+ }
+
+ apiutils.JSONResult(rw, r, struct {
+ MessageID string `json:"messageID"`
+ }{
+ MessageID: msg.ID,
+ })
+ })
+}
diff --git a/srv/api/csrf.go b/srv/api/csrf.go
index d705adb..0802d8a 100644
--- a/srv/api/csrf.go
+++ b/srv/api/csrf.go
@@ -3,6 +3,8 @@ package api
import (
"errors"
"net/http"
+
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils"
)
const (
@@ -13,16 +15,16 @@ const (
func setCSRFMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- csrfTok, err := getCookie(r, csrfTokenCookieName, "")
+ csrfTok, err := apiutils.GetCookie(r, csrfTokenCookieName, "")
if err != nil {
- internalServerError(rw, r, err)
+ apiutils.InternalServerError(rw, r, err)
return
} else if csrfTok == "" {
http.SetCookie(rw, &http.Cookie{
Name: csrfTokenCookieName,
- Value: randStr(32),
+ Value: apiutils.RandStr(32),
Secure: true,
})
}
@@ -34,14 +36,14 @@ func setCSRFMiddleware(h http.Handler) http.Handler {
func checkCSRFMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- csrfTok, err := getCookie(r, csrfTokenCookieName, "")
+ csrfTok, err := apiutils.GetCookie(r, csrfTokenCookieName, "")
if err != nil {
- internalServerError(rw, r, err)
+ apiutils.InternalServerError(rw, r, err)
return
} else if csrfTok == "" || r.Header.Get(csrfTokenHeaderName) != csrfTok {
- badRequest(rw, r, errors.New("invalid CSRF token"))
+ apiutils.BadRequest(rw, r, errors.New("invalid CSRF token"))
return
}
diff --git a/srv/api/mailinglist.go b/srv/api/mailinglist.go
index 2ddfbe6..d89fe2a 100644
--- a/srv/api/mailinglist.go
+++ b/srv/api/mailinglist.go
@@ -5,6 +5,7 @@ import (
"net/http"
"strings"
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils"
"github.com/mediocregopher/blog.mediocregopher.com/srv/mailinglist"
)
@@ -15,7 +16,7 @@ func (a *api) mailingListSubscribeHandler() http.Handler {
parts[0] == "" ||
parts[1] == "" ||
len(email) >= 512 {
- badRequest(rw, r, errors.New("invalid email"))
+ apiutils.BadRequest(rw, r, errors.New("invalid email"))
return
}
@@ -25,11 +26,11 @@ func (a *api) mailingListSubscribeHandler() http.Handler {
// just eat the error, make it look to the user like the
// verification email was sent.
} else if err != nil {
- internalServerError(rw, r, err)
+ apiutils.InternalServerError(rw, r, err)
return
}
- jsonResult(rw, r, struct{}{})
+ apiutils.JSONResult(rw, r, struct{}{})
})
}
@@ -39,25 +40,25 @@ func (a *api) mailingListFinalizeHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
subToken := r.PostFormValue("subToken")
if l := len(subToken); l == 0 || l > 128 {
- badRequest(rw, r, errInvalidSubToken)
+ apiutils.BadRequest(rw, r, errInvalidSubToken)
return
}
err := a.params.MailingList.FinalizeSubscription(subToken)
if errors.Is(err, mailinglist.ErrNotFound) {
- badRequest(rw, r, errInvalidSubToken)
+ apiutils.BadRequest(rw, r, errInvalidSubToken)
return
} else if errors.Is(err, mailinglist.ErrAlreadyVerified) {
// no problem
} else if err != nil {
- internalServerError(rw, r, err)
+ apiutils.InternalServerError(rw, r, err)
return
}
- jsonResult(rw, r, struct{}{})
+ apiutils.JSONResult(rw, r, struct{}{})
})
}
@@ -67,21 +68,21 @@ func (a *api) mailingListUnsubscribeHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
unsubToken := r.PostFormValue("unsubToken")
if l := len(unsubToken); l == 0 || l > 128 {
- badRequest(rw, r, errInvalidUnsubToken)
+ apiutils.BadRequest(rw, r, errInvalidUnsubToken)
return
}
err := a.params.MailingList.Unsubscribe(unsubToken)
if errors.Is(err, mailinglist.ErrNotFound) {
- badRequest(rw, r, errInvalidUnsubToken)
+ apiutils.BadRequest(rw, r, errInvalidUnsubToken)
return
} else if err != nil {
- internalServerError(rw, r, err)
+ apiutils.InternalServerError(rw, r, err)
return
}
- jsonResult(rw, r, struct{}{})
+ apiutils.JSONResult(rw, r, struct{}{})
})
}
diff --git a/srv/api/middleware.go b/srv/api/middleware.go
index e3e85bb..2605d93 100644
--- a/srv/api/middleware.go
+++ b/srv/api/middleware.go
@@ -5,6 +5,7 @@ import (
"net/http"
"time"
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils"
"github.com/mediocregopher/mediocre-go-lib/v2/mctx"
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
)
@@ -57,7 +58,7 @@ func (lrw *logResponseWriter) WriteHeader(statusCode int) {
func logMiddleware(logger *mlog.Logger, h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- r = setRequestLogger(r, logger)
+ r = apiutils.SetRequestLogger(r, logger)
lrw := newLogResponseWriter(rw)
@@ -76,3 +77,15 @@ func logMiddleware(logger *mlog.Logger, h http.Handler) http.Handler {
logger.Info(ctx, "handled HTTP request")
})
}
+
+func allowedMethod(method string, h http.Handler) http.Handler {
+ return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ if r.Method == method {
+ h.ServeHTTP(rw, r)
+ return
+ }
+
+ apiutils.GetRequestLogger(r).WarnString(r.Context(), "method not allowed")
+ rw.WriteHeader(405)
+ })
+}
diff --git a/srv/api/pow.go b/srv/api/pow.go
index 096e252..6d11061 100644
--- a/srv/api/pow.go
+++ b/srv/api/pow.go
@@ -5,6 +5,8 @@ import (
"errors"
"fmt"
"net/http"
+
+ "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils"
)
func (a *api) newPowChallengeHandler() http.Handler {
@@ -12,7 +14,7 @@ func (a *api) newPowChallengeHandler() http.Handler {
challenge := a.params.PowManager.NewChallenge()
- jsonResult(rw, r, struct {
+ apiutils.JSONResult(rw, r, struct {
Seed string `json:"seed"`
Target uint32 `json:"target"`
}{
@@ -28,21 +30,21 @@ func (a *api) requirePowMiddleware(h http.Handler) http.Handler {
seedHex := r.PostFormValue("powSeed")
seed, err := hex.DecodeString(seedHex)
if err != nil || len(seed) == 0 {
- badRequest(rw, r, errors.New("invalid powSeed"))
+ apiutils.BadRequest(rw, r, errors.New("invalid powSeed"))
return
}
solutionHex := r.PostFormValue("powSolution")
solution, err := hex.DecodeString(solutionHex)
if err != nil || len(seed) == 0 {
- badRequest(rw, r, errors.New("invalid powSolution"))
+ apiutils.BadRequest(rw, r, errors.New("invalid powSolution"))
return
}
err = a.params.PowManager.CheckSolution(seed, solution)
if err != nil {
- badRequest(rw, r, fmt.Errorf("checking proof-of-work solution: %w", err))
+ apiutils.BadRequest(rw, r, fmt.Errorf("checking proof-of-work solution: %w", err))
return
}
diff --git a/srv/api/utils.go b/srv/api/utils.go
deleted file mode 100644
index 2cf40b6..0000000
--- a/srv/api/utils.go
+++ /dev/null
@@ -1,91 +0,0 @@
-package api
-
-import (
- "context"
- "crypto/rand"
- "encoding/hex"
- "encoding/json"
- "errors"
- "fmt"
- "net/http"
- "strconv"
-
- "github.com/mediocregopher/mediocre-go-lib/v2/mlog"
-)
-
-type loggerCtxKey int
-
-func setRequestLogger(r *http.Request, logger *mlog.Logger) *http.Request {
- ctx := r.Context()
- ctx = context.WithValue(ctx, loggerCtxKey(0), logger)
- return r.WithContext(ctx)
-}
-
-func getRequestLogger(r *http.Request) *mlog.Logger {
- ctx := r.Context()
- logger, _ := ctx.Value(loggerCtxKey(0)).(*mlog.Logger)
- if logger == nil {
- logger = mlog.Null
- }
- return logger
-}
-
-func jsonResult(rw http.ResponseWriter, r *http.Request, v interface{}) {
- b, err := json.Marshal(v)
- if err != nil {
- internalServerError(rw, r, err)
- return
- }
- b = append(b, '\n')
-
- rw.Header().Set("Content-Type", "application/json")
- rw.Write(b)
-}
-
-func badRequest(rw http.ResponseWriter, r *http.Request, err error) {
- getRequestLogger(r).Warn(r.Context(), "bad request", err)
-
- rw.WriteHeader(400)
- jsonResult(rw, r, struct {
- Error string `json:"error"`
- }{
- Error: err.Error(),
- })
-}
-
-func internalServerError(rw http.ResponseWriter, r *http.Request, err error) {
- getRequestLogger(r).Error(r.Context(), "internal server error", err)
-
- rw.WriteHeader(500)
- jsonResult(rw, r, struct {
- Error string `json:"error"`
- }{
- Error: "internal server error",
- })
-}
-
-func strToInt(str string, defaultVal int) (int, error) {
- if str == "" {
- return defaultVal, nil
- }
- return strconv.Atoi(str)
-}
-
-func getCookie(r *http.Request, cookieName, defaultVal string) (string, error) {
- c, err := r.Cookie(cookieName)
- if errors.Is(err, http.ErrNoCookie) {
- return defaultVal, nil
- } else if err != nil {
- return "", fmt.Errorf("reading cookie %q: %w", cookieName, err)
- }
-
- return c.Value, nil
-}
-
-func randStr(numBytesEntropy int) string {
- b := make([]byte, numBytesEntropy)
- if _, err := rand.Read(b); err != nil {
- panic(err)
- }
- return hex.EncodeToString(b)
-}