diff options
Diffstat (limited to 'srv')
-rw-r--r-- | srv/api/api.go | 2 | ||||
-rw-r--r-- | srv/api/chat.go | 61 | ||||
-rw-r--r-- | srv/api/csrf.go | 8 | ||||
-rw-r--r-- | srv/api/middleware.go | 9 | ||||
-rw-r--r-- | srv/api/pow.go | 4 | ||||
-rw-r--r-- | srv/chat/chat.go | 13 | ||||
-rw-r--r-- | srv/default.nix | 2 | ||||
-rw-r--r-- | srv/go.mod | 1 | ||||
-rw-r--r-- | srv/go.sum | 2 |
9 files changed, 90 insertions, 12 deletions
diff --git a/srv/api/api.go b/srv/api/api.go index 6ba7ce0..adaf6a1 100644 --- a/srv/api/api.go +++ b/srv/api/api.go @@ -172,7 +172,7 @@ func (a *api) handler() http.Handler { ))) var apiHandler http.Handler = apiMux - apiHandler = allowedMethod("POST", apiHandler) + apiHandler = postOnlyMiddleware(apiHandler) apiHandler = checkCSRFMiddleware(apiHandler) apiHandler = logMiddleware(a.params.Logger, apiHandler) apiHandler = annotateMiddleware(apiHandler) diff --git a/srv/api/chat.go b/srv/api/chat.go index 4ac32e4..a1acc5a 100644 --- a/srv/api/chat.go +++ b/srv/api/chat.go @@ -1,12 +1,14 @@ package api import ( + "context" "errors" "fmt" "net/http" "strings" "unicode" + "github.com/gorilla/websocket" "github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils" "github.com/mediocregopher/blog.mediocregopher.com/srv/chat" ) @@ -16,6 +18,8 @@ type chatHandler struct { room chat.Room userIDCalc *chat.UserIDCalculator + + wsUpgrader websocket.Upgrader } func newChatHandler( @@ -26,11 +30,14 @@ func newChatHandler( ServeMux: http.NewServeMux(), room: room, userIDCalc: userIDCalc, + + wsUpgrader: websocket.Upgrader{}, } c.Handle("/history", c.historyHandler()) c.Handle("/user-id", requirePowMiddleware(c.userIDHandler())) c.Handle("/append", requirePowMiddleware(c.appendHandler())) + c.Handle("/listen", c.listenHandler()) return c } @@ -148,3 +155,57 @@ func (c *chatHandler) appendHandler() http.Handler { }) }) } + +func (c *chatHandler) listenHandler() http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + ctx := r.Context() + sinceID := r.FormValue("sinceID") + + conn, err := c.wsUpgrader.Upgrade(rw, r, nil) + if err != nil { + apiutils.BadRequest(rw, r, err) + return + } + defer conn.Close() + + it, err := c.room.Listen(ctx, sinceID) + + if errors.As(err, new(chat.ErrInvalidArg)) { + apiutils.BadRequest(rw, r, err) + return + + } else if errors.Is(err, context.Canceled) { + return + + } else if err != nil { + apiutils.InternalServerError(rw, r, err) + return + } + + defer it.Close() + + for { + + msg, err := it.Next(ctx) + if errors.Is(err, context.Canceled) { + return + + } else if err != nil { + apiutils.InternalServerError(rw, r, err) + return + } + + err = conn.WriteJSON(struct { + Message chat.Message `json:"message"` + }{ + Message: msg, + }) + + if err != nil { + apiutils.GetRequestLogger(r).Error(ctx, "couldn't write message", err) + return + } + } + }) +} diff --git a/srv/api/csrf.go b/srv/api/csrf.go index 0802d8a..13b6ec6 100644 --- a/srv/api/csrf.go +++ b/srv/api/csrf.go @@ -41,8 +41,14 @@ func checkCSRFMiddleware(h http.Handler) http.Handler { if err != nil { apiutils.InternalServerError(rw, r, err) return + } + + givenCSRFTok := r.Header.Get(csrfTokenHeaderName) + if givenCSRFTok == "" { + givenCSRFTok = r.FormValue("csrfToken") + } - } else if csrfTok == "" || r.Header.Get(csrfTokenHeaderName) != csrfTok { + if csrfTok == "" || givenCSRFTok != csrfTok { apiutils.BadRequest(rw, r, errors.New("invalid CSRF token")) return } diff --git a/srv/api/middleware.go b/srv/api/middleware.go index 2605d93..6ea0d13 100644 --- a/srv/api/middleware.go +++ b/srv/api/middleware.go @@ -40,12 +40,15 @@ func annotateMiddleware(h http.Handler) http.Handler { type logResponseWriter struct { http.ResponseWriter + http.Hijacker statusCode int } func newLogResponseWriter(rw http.ResponseWriter) *logResponseWriter { + h, _ := rw.(http.Hijacker) return &logResponseWriter{ ResponseWriter: rw, + Hijacker: h, statusCode: 200, } } @@ -78,9 +81,11 @@ func logMiddleware(logger *mlog.Logger, h http.Handler) http.Handler { }) } -func allowedMethod(method string, h http.Handler) http.Handler { +func postOnlyMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - if r.Method == method { + + // we allow websockets to not be POSTs because, well, they can't be + if r.Method == "POST" || r.Header.Get("Upgrade") == "websocket" { h.ServeHTTP(rw, r) return } diff --git a/srv/api/pow.go b/srv/api/pow.go index 6d11061..1b232b1 100644 --- a/srv/api/pow.go +++ b/srv/api/pow.go @@ -27,14 +27,14 @@ func (a *api) newPowChallengeHandler() http.Handler { func (a *api) requirePowMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - seedHex := r.PostFormValue("powSeed") + seedHex := r.FormValue("powSeed") seed, err := hex.DecodeString(seedHex) if err != nil || len(seed) == 0 { apiutils.BadRequest(rw, r, errors.New("invalid powSeed")) return } - solutionHex := r.PostFormValue("powSolution") + solutionHex := r.FormValue("powSolution") solution, err := hex.DecodeString(solutionHex) if err != nil || len(seed) == 0 { apiutils.BadRequest(rw, r, errors.New("invalid powSolution")) diff --git a/srv/chat/chat.go b/srv/chat/chat.go index acb7b2d..0a88d3b 100644 --- a/srv/chat/chat.go +++ b/srv/chat/chat.go @@ -31,9 +31,10 @@ var ( // Message describes a message which has been posted to a Room. type Message struct { - ID string `json:"id"` - UserID UserID `json:"userID"` - Body string `json:"body"` + ID string `json:"id"` + UserID UserID `json:"userID"` + Body string `json:"body"` + CreatedAt int64 `json:"createdAt,omitempty"` } func msgFromStreamEntry(entry radix.StreamEntry) (Message, error) { @@ -59,6 +60,7 @@ func msgFromStreamEntry(entry radix.StreamEntry) (Message, error) { } msg.ID = entry.ID.String() + msg.CreatedAt = int64(entry.ID.Time / 1000) return msg, nil } @@ -211,7 +213,7 @@ func (r *room) Append(ctx context.Context, msg Message) (Message, error) { maxLen := strconv.Itoa(r.params.MaxMessages) body := string(b) - var id string + var id radix.StreamEntryID err = r.params.Redis.Do(ctx, radix.Cmd( &id, "XADD", key, "MAXLEN", "=", maxLen, "*", "json", body, @@ -221,7 +223,8 @@ func (r *room) Append(ctx context.Context, msg Message) (Message, error) { return Message{}, fmt.Errorf("posting message to redis: %w", err) } - msg.ID = id + msg.ID = id.String() + msg.CreatedAt = int64(id.Time / 1000) return msg, nil } diff --git a/srv/default.nix b/srv/default.nix index a36739a..bc828a0 100644 --- a/srv/default.nix +++ b/srv/default.nix @@ -23,7 +23,7 @@ pname = "mediocre-blog-srv"; version = "dev"; src = ./.; - vendorSha256 = "0c6j989q6r2q967gx90cl4l8skflkx2npmxd3f5l16bwj2ldw11j"; + vendorSha256 = "02szg1lisfjk8pk9pflbyv97ykg9362r4fhd0w0p2a7c81kf9b8y"; # disable tests checkPhase = ''''; @@ -6,6 +6,7 @@ require ( github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 github.com/emersion/go-smtp v0.15.0 github.com/google/uuid v1.3.0 + github.com/gorilla/websocket v1.4.2 // indirect github.com/mattn/go-sqlite3 v1.14.8 github.com/mediocregopher/mediocre-go-lib/v2 v2.0.0-beta.0 github.com/mediocregopher/radix/v4 v4.0.0-beta.1.0.20210726230805-d62fa1b2e3cb // indirect @@ -61,6 +61,8 @@ github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= |