summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrian Picciano <mediocregopher@gmail.com>2022-05-21 09:17:43 -0600
committerBrian Picciano <mediocregopher@gmail.com>2022-05-21 09:17:43 -0600
commit1de0ab3b720cf7b83a7e29de4dbe35c117ccea0e (patch)
tree3d5441d0b793c23c02b228707acca36649691ed3
parent034342421bd0b3d40276df79a0f2450d1fbd643f (diff)
Define an actual middleware type, use that to set up API routes
-rw-r--r--srv/src/http/api.go54
-rw-r--r--srv/src/http/apiutil/apiutil.go14
-rw-r--r--srv/src/http/auth.go28
-rw-r--r--srv/src/http/middleware.go55
4 files changed, 86 insertions, 65 deletions
diff --git a/srv/src/http/api.go b/srv/src/http/api.go
index bcd0150..11d092d 100644
--- a/srv/src/http/api.go
+++ b/srv/src/http/api.go
@@ -166,24 +166,6 @@ func (a *api) handler() http.Handler {
return a.requirePowMiddleware(h)
}
- formMiddleware := func(h http.Handler) http.Handler {
- wh := checkCSRFMiddleware(h)
- wh = logReqMiddleware(wh)
- wh = addResponseHeaders(map[string]string{
- "Cache-Control": "no-store, max-age=0",
- "Pragma": "no-cache",
- "Expires": "0",
- }, wh)
-
- return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- if r.Method != "GET" {
- wh.ServeHTTP(rw, r)
- } else {
- h.ServeHTTP(rw, r)
- }
- })
- }
-
mux := http.NewServeMux()
{
@@ -215,17 +197,17 @@ func (a *api) handler() http.Handler {
mux.Handle("/posts/", http.StripPrefix("/posts",
apiutil.MethodMux(map[string]http.Handler{
"GET": a.renderPostHandler(),
- "POST": authMiddleware(a.auther, a.postPostHandler()),
- "DELETE": authMiddleware(a.auther, a.deletePostHandler()),
- "PREVIEW": authMiddleware(a.auther, a.previewPostHandler()),
+ "POST": a.postPostHandler(),
+ "DELETE": a.deletePostHandler(),
+ "PREVIEW": a.previewPostHandler(),
}),
))
mux.Handle("/assets/", http.StripPrefix("/assets",
apiutil.MethodMux(map[string]http.Handler{
"GET": a.getPostAssetHandler(),
- "POST": authMiddleware(a.auther, a.postPostAssetHandler()),
- "DELETE": authMiddleware(a.auther, a.deletePostAssetHandler()),
+ "POST": a.postPostAssetHandler(),
+ "DELETE": a.deletePostAssetHandler(),
}),
))
@@ -234,10 +216,28 @@ func (a *api) handler() http.Handler {
mux.Handle("/feed.xml", a.renderFeedHandler())
mux.Handle("/", a.renderIndexHandler())
- var globalHandler http.Handler = mux
- globalHandler = formMiddleware(globalHandler)
- globalHandler = setCSRFMiddleware(globalHandler)
- globalHandler = setLoggerMiddleware(a.params.Logger, globalHandler)
+ globalHandler := http.Handler(mux)
+
+ globalHandler = apiutil.MethodMux(map[string]http.Handler{
+ "GET": applyMiddlewares(
+ globalHandler,
+ logReqMiddleware,
+ setCSRFMiddleware,
+ ),
+ "*": applyMiddlewares(
+ globalHandler,
+ authMiddleware(a.auther),
+ checkCSRFMiddleware,
+ addResponseHeadersMiddleware(map[string]string{
+ "Cache-Control": "no-store, max-age=0",
+ "Pragma": "no-cache",
+ "Expires": "0",
+ }),
+ logReqMiddleware,
+ ),
+ })
+
+ globalHandler = setLoggerMiddleware(a.params.Logger)(globalHandler)
return globalHandler
}
diff --git a/srv/src/http/apiutil/apiutil.go b/srv/src/http/apiutil/apiutil.go
index aa62299..fed6fb5 100644
--- a/srv/src/http/apiutil/apiutil.go
+++ b/srv/src/http/apiutil/apiutil.go
@@ -117,6 +117,9 @@ func RandStr(numBytes int) string {
//
// If no Handler is defined for a method then a 405 Method Not Allowed error is
// returned.
+//
+// If the method "*" is defined then all methods not defined will be directed to
+// that handler, and 405 Method Not Allowed is never returned.
func MethodMux(handlers map[string]http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@@ -128,13 +131,16 @@ func MethodMux(handlers map[string]http.Handler) http.Handler {
method = formMethod
}
- handler, ok := handlers[method]
+ if handler, ok := handlers[method]; ok {
+ handler.ServeHTTP(rw, r)
+ return
+ }
- if !ok {
- http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
+ if handler, ok := handlers["*"]; ok {
+ handler.ServeHTTP(rw, r)
return
}
- handler.ServeHTTP(rw, r)
+ http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
})
}
diff --git a/srv/src/http/auth.go b/srv/src/http/auth.go
index 9527cc8..3ad026a 100644
--- a/srv/src/http/auth.go
+++ b/srv/src/http/auth.go
@@ -65,7 +65,7 @@ func (a *auther) Allowed(ctx context.Context, username, password string) bool {
return err == nil
}
-func authMiddleware(auther Auther, h http.Handler) http.Handler {
+func authMiddleware(auther Auther) middleware {
respondUnauthorized := func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("WWW-Authenticate", `Basic realm="NOPE"`)
@@ -73,20 +73,22 @@ func authMiddleware(auther Auther, h http.Handler) http.Handler {
apiutil.GetRequestLogger(r).WarnString(r.Context(), "unauthorized")
}
- return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ return func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- username, password, ok := r.BasicAuth()
+ username, password, ok := r.BasicAuth()
- if !ok {
- respondUnauthorized(rw, r)
- return
- }
+ if !ok {
+ respondUnauthorized(rw, r)
+ return
+ }
- if !auther.Allowed(r.Context(), username, password) {
- respondUnauthorized(rw, r)
- return
- }
+ if !auther.Allowed(r.Context(), username, password) {
+ respondUnauthorized(rw, r)
+ return
+ }
- h.ServeHTTP(rw, r)
- })
+ h.ServeHTTP(rw, r)
+ })
+ }
}
diff --git a/srv/src/http/middleware.go b/srv/src/http/middleware.go
index 8299a71..02d156b 100644
--- a/srv/src/http/middleware.go
+++ b/srv/src/http/middleware.go
@@ -10,33 +10,46 @@ import (
"github.com/mediocregopher/mediocre-go-lib/v2/mlog"
)
-func addResponseHeaders(headers map[string]string, h http.Handler) http.Handler {
- return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- for k, v := range headers {
- rw.Header().Set(k, v)
- }
- h.ServeHTTP(rw, r)
- })
+type middleware func(http.Handler) http.Handler
+
+func applyMiddlewares(h http.Handler, middlewares ...middleware) http.Handler {
+ for _, m := range middlewares {
+ h = m(h)
+ }
+ return h
}
-func setLoggerMiddleware(logger *mlog.Logger, h http.Handler) http.Handler {
- return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+func addResponseHeadersMiddleware(headers map[string]string) middleware {
+ return func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ for k, v := range headers {
+ rw.Header().Set(k, v)
+ }
+ h.ServeHTTP(rw, r)
+ })
+ }
+}
- type reqInfoKey string
+func setLoggerMiddleware(logger *mlog.Logger) middleware {
+ return func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- ip, _, _ := net.SplitHostPort(r.RemoteAddr)
+ type logCtxKey string
- ctx := r.Context()
- ctx = mctx.Annotate(ctx,
- reqInfoKey("remote_ip"), ip,
- reqInfoKey("url"), r.URL,
- reqInfoKey("method"), r.Method,
- )
+ ip, _, _ := net.SplitHostPort(r.RemoteAddr)
- r = r.WithContext(ctx)
- r = apiutil.SetRequestLogger(r, logger)
- h.ServeHTTP(rw, r)
- })
+ ctx := r.Context()
+ ctx = mctx.Annotate(ctx,
+ logCtxKey("remote_ip"), ip,
+ logCtxKey("url"), r.URL,
+ logCtxKey("method"), r.Method,
+ )
+
+ r = r.WithContext(ctx)
+ r = apiutil.SetRequestLogger(r, logger)
+ h.ServeHTTP(rw, r)
+ })
+ }
}
type logResponseWriter struct {