summaryrefslogtreecommitdiff
path: root/srv/api
diff options
context:
space:
mode:
Diffstat (limited to 'srv/api')
-rw-r--r--srv/api/api.go2
-rw-r--r--srv/api/chat.go61
-rw-r--r--srv/api/csrf.go8
-rw-r--r--srv/api/middleware.go9
-rw-r--r--srv/api/pow.go4
5 files changed, 78 insertions, 6 deletions
diff --git a/srv/api/api.go b/srv/api/api.go
index 6ba7ce0..adaf6a1 100644
--- a/srv/api/api.go
+++ b/srv/api/api.go
@@ -172,7 +172,7 @@ func (a *api) handler() http.Handler {
)))
var apiHandler http.Handler = apiMux
- apiHandler = allowedMethod("POST", apiHandler)
+ apiHandler = postOnlyMiddleware(apiHandler)
apiHandler = checkCSRFMiddleware(apiHandler)
apiHandler = logMiddleware(a.params.Logger, apiHandler)
apiHandler = annotateMiddleware(apiHandler)
diff --git a/srv/api/chat.go b/srv/api/chat.go
index 4ac32e4..a1acc5a 100644
--- a/srv/api/chat.go
+++ b/srv/api/chat.go
@@ -1,12 +1,14 @@
package api
import (
+ "context"
"errors"
"fmt"
"net/http"
"strings"
"unicode"
+ "github.com/gorilla/websocket"
"github.com/mediocregopher/blog.mediocregopher.com/srv/api/apiutils"
"github.com/mediocregopher/blog.mediocregopher.com/srv/chat"
)
@@ -16,6 +18,8 @@ type chatHandler struct {
room chat.Room
userIDCalc *chat.UserIDCalculator
+
+ wsUpgrader websocket.Upgrader
}
func newChatHandler(
@@ -26,11 +30,14 @@ func newChatHandler(
ServeMux: http.NewServeMux(),
room: room,
userIDCalc: userIDCalc,
+
+ wsUpgrader: websocket.Upgrader{},
}
c.Handle("/history", c.historyHandler())
c.Handle("/user-id", requirePowMiddleware(c.userIDHandler()))
c.Handle("/append", requirePowMiddleware(c.appendHandler()))
+ c.Handle("/listen", c.listenHandler())
return c
}
@@ -148,3 +155,57 @@ func (c *chatHandler) appendHandler() http.Handler {
})
})
}
+
+func (c *chatHandler) listenHandler() http.Handler {
+ return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+
+ ctx := r.Context()
+ sinceID := r.FormValue("sinceID")
+
+ conn, err := c.wsUpgrader.Upgrade(rw, r, nil)
+ if err != nil {
+ apiutils.BadRequest(rw, r, err)
+ return
+ }
+ defer conn.Close()
+
+ it, err := c.room.Listen(ctx, sinceID)
+
+ if errors.As(err, new(chat.ErrInvalidArg)) {
+ apiutils.BadRequest(rw, r, err)
+ return
+
+ } else if errors.Is(err, context.Canceled) {
+ return
+
+ } else if err != nil {
+ apiutils.InternalServerError(rw, r, err)
+ return
+ }
+
+ defer it.Close()
+
+ for {
+
+ msg, err := it.Next(ctx)
+ if errors.Is(err, context.Canceled) {
+ return
+
+ } else if err != nil {
+ apiutils.InternalServerError(rw, r, err)
+ return
+ }
+
+ err = conn.WriteJSON(struct {
+ Message chat.Message `json:"message"`
+ }{
+ Message: msg,
+ })
+
+ if err != nil {
+ apiutils.GetRequestLogger(r).Error(ctx, "couldn't write message", err)
+ return
+ }
+ }
+ })
+}
diff --git a/srv/api/csrf.go b/srv/api/csrf.go
index 0802d8a..13b6ec6 100644
--- a/srv/api/csrf.go
+++ b/srv/api/csrf.go
@@ -41,8 +41,14 @@ func checkCSRFMiddleware(h http.Handler) http.Handler {
if err != nil {
apiutils.InternalServerError(rw, r, err)
return
+ }
+
+ givenCSRFTok := r.Header.Get(csrfTokenHeaderName)
+ if givenCSRFTok == "" {
+ givenCSRFTok = r.FormValue("csrfToken")
+ }
- } else if csrfTok == "" || r.Header.Get(csrfTokenHeaderName) != csrfTok {
+ if csrfTok == "" || givenCSRFTok != csrfTok {
apiutils.BadRequest(rw, r, errors.New("invalid CSRF token"))
return
}
diff --git a/srv/api/middleware.go b/srv/api/middleware.go
index 2605d93..6ea0d13 100644
--- a/srv/api/middleware.go
+++ b/srv/api/middleware.go
@@ -40,12 +40,15 @@ func annotateMiddleware(h http.Handler) http.Handler {
type logResponseWriter struct {
http.ResponseWriter
+ http.Hijacker
statusCode int
}
func newLogResponseWriter(rw http.ResponseWriter) *logResponseWriter {
+ h, _ := rw.(http.Hijacker)
return &logResponseWriter{
ResponseWriter: rw,
+ Hijacker: h,
statusCode: 200,
}
}
@@ -78,9 +81,11 @@ func logMiddleware(logger *mlog.Logger, h http.Handler) http.Handler {
})
}
-func allowedMethod(method string, h http.Handler) http.Handler {
+func postOnlyMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- if r.Method == method {
+
+ // we allow websockets to not be POSTs because, well, they can't be
+ if r.Method == "POST" || r.Header.Get("Upgrade") == "websocket" {
h.ServeHTTP(rw, r)
return
}
diff --git a/srv/api/pow.go b/srv/api/pow.go
index 6d11061..1b232b1 100644
--- a/srv/api/pow.go
+++ b/srv/api/pow.go
@@ -27,14 +27,14 @@ func (a *api) newPowChallengeHandler() http.Handler {
func (a *api) requirePowMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- seedHex := r.PostFormValue("powSeed")
+ seedHex := r.FormValue("powSeed")
seed, err := hex.DecodeString(seedHex)
if err != nil || len(seed) == 0 {
apiutils.BadRequest(rw, r, errors.New("invalid powSeed"))
return
}
- solutionHex := r.PostFormValue("powSolution")
+ solutionHex := r.FormValue("powSolution")
solution, err := hex.DecodeString(solutionHex)
if err != nil || len(seed) == 0 {
apiutils.BadRequest(rw, r, errors.New("invalid powSolution"))