diff options
author | Brian Picciano <mediocregopher@gmail.com> | 2021-08-29 22:15:58 -0600 |
---|---|---|
committer | Brian Picciano <mediocregopher@gmail.com> | 2021-08-29 22:15:58 -0600 |
commit | 15ae483fadbd136acefcd602b2f2ac5a83165c73 (patch) | |
tree | 0f25ed1dd81e4fffeed6055dd02da48a567c8fb2 /srv/api | |
parent | 5746a510fc569fd464e46b646d4979a976ad769b (diff) |
add CSRF checking
Diffstat (limited to 'srv/api')
-rw-r--r-- | srv/api/api.go | 6 | ||||
-rw-r--r-- | srv/api/csrf.go | 50 | ||||
-rw-r--r-- | srv/api/utils.go | 23 |
3 files changed, 78 insertions, 1 deletions
diff --git a/srv/api/api.go b/srv/api/api.go index 39d73d9..bbb677a 100644 --- a/srv/api/api.go +++ b/srv/api/api.go @@ -142,6 +142,8 @@ func (a *api) handler() http.Handler { staticHandler = httputil.NewSingleHostReverseProxy(a.params.StaticProxy) } + staticHandler = setCSRFMiddleware(staticHandler) + // sugar requirePow := func(h http.Handler) http.Handler { return a.requirePowMiddleware(h) @@ -163,7 +165,9 @@ func (a *api) handler() http.Handler { apiMux.Handle("/mailinglist/finalize", a.mailingListFinalizeHandler()) apiMux.Handle("/mailinglist/unsubscribe", a.mailingListUnsubscribeHandler()) - apiHandler := logMiddleware(a.params.Logger, apiMux) + var apiHandler http.Handler = apiMux + apiHandler = checkCSRFMiddleware(apiHandler) + apiHandler = logMiddleware(a.params.Logger, apiHandler) apiHandler = annotateMiddleware(apiHandler) apiHandler = addResponseHeaders(map[string]string{ "Cache-Control": "no-store, max-age=0", diff --git a/srv/api/csrf.go b/srv/api/csrf.go new file mode 100644 index 0000000..d705adb --- /dev/null +++ b/srv/api/csrf.go @@ -0,0 +1,50 @@ +package api + +import ( + "errors" + "net/http" +) + +const ( + csrfTokenCookieName = "csrf_token" + csrfTokenHeaderName = "X-CSRF-Token" +) + +func setCSRFMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + csrfTok, err := getCookie(r, csrfTokenCookieName, "") + + if err != nil { + internalServerError(rw, r, err) + return + + } else if csrfTok == "" { + http.SetCookie(rw, &http.Cookie{ + Name: csrfTokenCookieName, + Value: randStr(32), + Secure: true, + }) + } + + h.ServeHTTP(rw, r) + }) +} + +func checkCSRFMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + csrfTok, err := getCookie(r, csrfTokenCookieName, "") + + if err != nil { + internalServerError(rw, r, err) + return + + } else if csrfTok == "" || r.Header.Get(csrfTokenHeaderName) != csrfTok { + badRequest(rw, r, errors.New("invalid CSRF token")) + return + } + + h.ServeHTTP(rw, r) + }) +} diff --git a/srv/api/utils.go b/srv/api/utils.go index 7662e17..2cf40b6 100644 --- a/srv/api/utils.go +++ b/srv/api/utils.go @@ -2,7 +2,11 @@ package api import ( "context" + "crypto/rand" + "encoding/hex" "encoding/json" + "errors" + "fmt" "net/http" "strconv" @@ -66,3 +70,22 @@ func strToInt(str string, defaultVal int) (int, error) { } return strconv.Atoi(str) } + +func getCookie(r *http.Request, cookieName, defaultVal string) (string, error) { + c, err := r.Cookie(cookieName) + if errors.Is(err, http.ErrNoCookie) { + return defaultVal, nil + } else if err != nil { + return "", fmt.Errorf("reading cookie %q: %w", cookieName, err) + } + + return c.Value, nil +} + +func randStr(numBytesEntropy int) string { + b := make([]byte, numBytesEntropy) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return hex.EncodeToString(b) +} |