summaryrefslogtreecommitdiff
path: root/srv/api
diff options
context:
space:
mode:
authorBrian Picciano <mediocregopher@gmail.com>2021-08-29 22:15:58 -0600
committerBrian Picciano <mediocregopher@gmail.com>2021-08-29 22:15:58 -0600
commit15ae483fadbd136acefcd602b2f2ac5a83165c73 (patch)
tree0f25ed1dd81e4fffeed6055dd02da48a567c8fb2 /srv/api
parent5746a510fc569fd464e46b646d4979a976ad769b (diff)
add CSRF checking
Diffstat (limited to 'srv/api')
-rw-r--r--srv/api/api.go6
-rw-r--r--srv/api/csrf.go50
-rw-r--r--srv/api/utils.go23
3 files changed, 78 insertions, 1 deletions
diff --git a/srv/api/api.go b/srv/api/api.go
index 39d73d9..bbb677a 100644
--- a/srv/api/api.go
+++ b/srv/api/api.go
@@ -142,6 +142,8 @@ func (a *api) handler() http.Handler {
staticHandler = httputil.NewSingleHostReverseProxy(a.params.StaticProxy)
}
+ staticHandler = setCSRFMiddleware(staticHandler)
+
// sugar
requirePow := func(h http.Handler) http.Handler {
return a.requirePowMiddleware(h)
@@ -163,7 +165,9 @@ func (a *api) handler() http.Handler {
apiMux.Handle("/mailinglist/finalize", a.mailingListFinalizeHandler())
apiMux.Handle("/mailinglist/unsubscribe", a.mailingListUnsubscribeHandler())
- apiHandler := logMiddleware(a.params.Logger, apiMux)
+ var apiHandler http.Handler = apiMux
+ apiHandler = checkCSRFMiddleware(apiHandler)
+ apiHandler = logMiddleware(a.params.Logger, apiHandler)
apiHandler = annotateMiddleware(apiHandler)
apiHandler = addResponseHeaders(map[string]string{
"Cache-Control": "no-store, max-age=0",
diff --git a/srv/api/csrf.go b/srv/api/csrf.go
new file mode 100644
index 0000000..d705adb
--- /dev/null
+++ b/srv/api/csrf.go
@@ -0,0 +1,50 @@
+package api
+
+import (
+ "errors"
+ "net/http"
+)
+
+const (
+ csrfTokenCookieName = "csrf_token"
+ csrfTokenHeaderName = "X-CSRF-Token"
+)
+
+func setCSRFMiddleware(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+
+ csrfTok, err := getCookie(r, csrfTokenCookieName, "")
+
+ if err != nil {
+ internalServerError(rw, r, err)
+ return
+
+ } else if csrfTok == "" {
+ http.SetCookie(rw, &http.Cookie{
+ Name: csrfTokenCookieName,
+ Value: randStr(32),
+ Secure: true,
+ })
+ }
+
+ h.ServeHTTP(rw, r)
+ })
+}
+
+func checkCSRFMiddleware(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+
+ csrfTok, err := getCookie(r, csrfTokenCookieName, "")
+
+ if err != nil {
+ internalServerError(rw, r, err)
+ return
+
+ } else if csrfTok == "" || r.Header.Get(csrfTokenHeaderName) != csrfTok {
+ badRequest(rw, r, errors.New("invalid CSRF token"))
+ return
+ }
+
+ h.ServeHTTP(rw, r)
+ })
+}
diff --git a/srv/api/utils.go b/srv/api/utils.go
index 7662e17..2cf40b6 100644
--- a/srv/api/utils.go
+++ b/srv/api/utils.go
@@ -2,7 +2,11 @@ package api
import (
"context"
+ "crypto/rand"
+ "encoding/hex"
"encoding/json"
+ "errors"
+ "fmt"
"net/http"
"strconv"
@@ -66,3 +70,22 @@ func strToInt(str string, defaultVal int) (int, error) {
}
return strconv.Atoi(str)
}
+
+func getCookie(r *http.Request, cookieName, defaultVal string) (string, error) {
+ c, err := r.Cookie(cookieName)
+ if errors.Is(err, http.ErrNoCookie) {
+ return defaultVal, nil
+ } else if err != nil {
+ return "", fmt.Errorf("reading cookie %q: %w", cookieName, err)
+ }
+
+ return c.Value, nil
+}
+
+func randStr(numBytesEntropy int) string {
+ b := make([]byte, numBytesEntropy)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return hex.EncodeToString(b)
+}