From 09acb111a2b22f5794541fac175b024dd0f9100e Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Fri, 20 May 2022 11:17:31 -0600 Subject: Rename api package to http --- srv/src/http/csrf.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 srv/src/http/csrf.go (limited to 'srv/src/http/csrf.go') diff --git a/srv/src/http/csrf.go b/srv/src/http/csrf.go new file mode 100644 index 0000000..1c80dee --- /dev/null +++ b/srv/src/http/csrf.go @@ -0,0 +1,59 @@ +package http + +import ( + "errors" + "net/http" + + "github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil" +) + +const ( + csrfTokenCookieName = "csrf_token" + csrfTokenHeaderName = "X-CSRF-Token" + csrfTokenFormName = "csrfToken" +) + +func setCSRFMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + csrfTok, err := apiutil.GetCookie(r, csrfTokenCookieName, "") + + if err != nil { + apiutil.InternalServerError(rw, r, err) + return + + } else if csrfTok == "" { + http.SetCookie(rw, &http.Cookie{ + Name: csrfTokenCookieName, + Value: apiutil.RandStr(32), + Secure: true, + }) + } + + h.ServeHTTP(rw, r) + }) +} + +func checkCSRFMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + csrfTok, err := apiutil.GetCookie(r, csrfTokenCookieName, "") + + if err != nil { + apiutil.InternalServerError(rw, r, err) + return + } + + givenCSRFTok := r.Header.Get(csrfTokenHeaderName) + if givenCSRFTok == "" { + givenCSRFTok = r.FormValue(csrfTokenFormName) + } + + if csrfTok == "" || givenCSRFTok != csrfTok { + apiutil.BadRequest(rw, r, errors.New("invalid CSRF token")) + return + } + + h.ServeHTTP(rw, r) + }) +} -- cgit v1.2.3