diff options
Diffstat (limited to 'srv/api')
-rw-r--r-- | srv/api/api.go | 2 | ||||
-rw-r--r-- | srv/api/chat.go | 61 | ||||
-rw-r--r-- | srv/api/csrf.go | 8 | ||||
-rw-r--r-- | srv/api/middleware.go | 9 | ||||
-rw-r--r-- | srv/api/pow.go | 4 |
5 files changed, 78 insertions, 6 deletions
diff --git a/srv/api/api.go b/srv/api/api.go index 6ba7ce0..adaf6a1 100644 --- a/srv/api/api.go +++ b/srv/api/api.go @@ -172,7 +172,7 @@ func (a *api) handler() http.Handler { ))) var apiHandler http.Handler = apiMux - apiHandler = allowedMethod("POST", apiHandler) + apiHandler = postOnlyMiddleware(apiHandler) apiHandler = checkCSRFMiddleware(apiHandler) apiHandler = logMiddleware(a.params.Logger, apiHandler) apiHandler = annotateMiddleware(apiHandler) diff --git a/srv/api/chat.go b/srv/api/chat.go index 4ac32e4..a1acc5a 100644 --- a/srv/api/chat.go +++ b/srv/api/chat.go @@ -1,12 +1,14 @@ package api import ( + "context" "errors" "fmt" "net/http" "strings" "unicode" + "github.com/gorilla/websocket" "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils" "github.com/mediocregopher/blog.mediocregopher.com/srv/chat" ) @@ -16,6 +18,8 @@ type chatHandler struct { room chat.Room userIDCalc *chat.UserIDCalculator + + wsUpgrader websocket.Upgrader } func newChatHandler( @@ -26,11 +30,14 @@ func newChatHandler( ServeMux: http.NewServeMux(), room: room, userIDCalc: userIDCalc, + + wsUpgrader: websocket.Upgrader{}, } c.Handle("/history", c.historyHandler()) c.Handle("/user-id", requirePowMiddleware(c.userIDHandler())) c.Handle("/append", requirePowMiddleware(c.appendHandler())) + c.Handle("/listen", c.listenHandler()) return c } @@ -148,3 +155,57 @@ func (c *chatHandler) appendHandler() http.Handler { }) }) } + +func (c *chatHandler) listenHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + ctx := r.Context() + sinceID := r.FormValue("sinceID") + + conn, err := c.wsUpgrader.Upgrade(rw, r, nil) + if err != nil { + apiutils.BadRequest(rw, r, err) + return + } + defer conn.Close() + + it, err := c.room.Listen(ctx, sinceID) + + if errors.As(err, new(chat.ErrInvalidArg)) { + apiutils.BadRequest(rw, r, err) + return + + } else if errors.Is(err, context.Canceled) { + return + + } else if err != nil { + apiutils.InternalServerError(rw, r, err) + return + } + + defer it.Close() + + for { + + msg, err := it.Next(ctx) + if errors.Is(err, context.Canceled) { + return + + } else if err != nil { + apiutils.InternalServerError(rw, r, err) + return + } + + err = conn.WriteJSON(struct { + Message chat.Message `json:"message"` + }{ + Message: msg, + }) + + if err != nil { + apiutils.GetRequestLogger(r).Error(ctx, "couldn't write message", err) + return + } + } + }) +} diff --git a/srv/api/csrf.go b/srv/api/csrf.go index 0802d8a..13b6ec6 100644 --- a/srv/api/csrf.go +++ b/srv/api/csrf.go @@ -41,8 +41,14 @@ func checkCSRFMiddleware(h http.Handler) http.Handler { if err != nil { apiutils.InternalServerError(rw, r, err) return + } + + givenCSRFTok := r.Header.Get(csrfTokenHeaderName) + if givenCSRFTok == "" { + givenCSRFTok = r.FormValue("csrfToken") + } - } else if csrfTok == "" || r.Header.Get(csrfTokenHeaderName) != csrfTok { + if csrfTok == "" || givenCSRFTok != csrfTok { apiutils.BadRequest(rw, r, errors.New("invalid CSRF token")) return } diff --git a/srv/api/middleware.go b/srv/api/middleware.go index 2605d93..6ea0d13 100644 --- a/srv/api/middleware.go +++ b/srv/api/middleware.go @@ -40,12 +40,15 @@ func annotateMiddleware(h http.Handler) http.Handler { type logResponseWriter struct { http.ResponseWriter + http.Hijacker statusCode int } func newLogResponseWriter(rw http.ResponseWriter) *logResponseWriter { + h, _ := rw.(http.Hijacker) return &logResponseWriter{ ResponseWriter: rw, + Hijacker: h, statusCode: 200, } } @@ -78,9 +81,11 @@ func logMiddleware(logger *mlog.Logger, h http.Handler) http.Handler { }) } -func allowedMethod(method string, h http.Handler) http.Handler { +func postOnlyMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - if r.Method == method { + + // we allow websockets to not be POSTs because, well, they can't be + if r.Method == "POST" || r.Header.Get("Upgrade") == "websocket" { h.ServeHTTP(rw, r) return } diff --git a/srv/api/pow.go b/srv/api/pow.go index 6d11061..1b232b1 100644 --- a/srv/api/pow.go +++ b/srv/api/pow.go @@ -27,14 +27,14 @@ func (a *api) newPowChallengeHandler() http.Handler { func (a *api) requirePowMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - seedHex := r.PostFormValue("powSeed") + seedHex := r.FormValue("powSeed") seed, err := hex.DecodeString(seedHex) if err != nil || len(seed) == 0 { apiutils.BadRequest(rw, r, errors.New("invalid powSeed")) return } - solutionHex := r.PostFormValue("powSolution") + solutionHex := r.FormValue("powSolution") solution, err := hex.DecodeString(solutionHex) if err != nil || len(seed) == 0 { apiutils.BadRequest(rw, r, errors.New("invalid powSolution")) |