From faa296075f5ea2d8e01004b46b036997f9529d99 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Sun, 26 May 2024 22:06:44 +0200 Subject: Clean out Get/SetRequestLogger from apiutil --- src/http/apiutil/apiutil.go | 59 ++++++++++++++++++++------------------------- src/http/assets.go | 36 +++++++++++++++++++-------- src/http/auth.go | 6 ++--- src/http/csrf.go | 4 ++- src/http/drafts.go | 9 ++++--- src/http/feed.go | 16 ++++++++---- src/http/http.go | 13 +++++----- src/http/middleware.go | 58 +++++++++++++++++++++++--------------------- src/http/posts.go | 29 ++++++++++++++-------- src/http/tpl.go | 21 ++++++++++------ 10 files changed, 144 insertions(+), 107 deletions(-) (limited to 'src/http') diff --git a/src/http/apiutil/apiutil.go b/src/http/apiutil/apiutil.go index 2151d83..a27d0d5 100644 --- a/src/http/apiutil/apiutil.go +++ b/src/http/apiutil/apiutil.go @@ -16,35 +16,13 @@ import ( "dev.mediocregopher.com/mediocre-go-lib.git/mlog" ) -// TODO I don't think Set/GetRequestLogger are necessary? Seems sufficient to -// just annotate the request's context - -type loggerCtxKey int - -// SetRequestLogger sets the given Logger onto the given Request's Context, -// returning a copy. -func SetRequestLogger(r *http.Request, logger *mlog.Logger) *http.Request { - ctx := r.Context() - ctx = context.WithValue(ctx, loggerCtxKey(0), logger) - return r.WithContext(ctx) -} - -// GetRequestLogger returns the Logger which was set by SetRequestLogger onto -// this Request, or nil. -func GetRequestLogger(r *http.Request) *mlog.Logger { - ctx := r.Context() - logger, _ := ctx.Value(loggerCtxKey(0)).(*mlog.Logger) - if logger == nil { - logger = mlog.Null - } - return logger -} - // JSONResult writes the JSON encoding of the given value as the response body. -func JSONResult(rw http.ResponseWriter, r *http.Request, v interface{}) { +func JSONResult( + ctx context.Context, logger *mlog.Logger, rw http.ResponseWriter, v any, +) { b, err := json.Marshal(v) if err != nil { - InternalServerError(rw, r, err) + InternalServerError(ctx, logger, rw, "%w", err) return } b = append(b, '\n') @@ -54,12 +32,20 @@ func JSONResult(rw http.ResponseWriter, r *http.Request, v interface{}) { } // BadRequest writes a 400 status and a JSON encoded error struct containing the -// given error as the response body. -func BadRequest(rw http.ResponseWriter, r *http.Request, err error) { - GetRequestLogger(r).Warn(r.Context(), "bad request", err) +// given error (created using `fmt.Errorf(fmtStr, args...)`) as the response +// body. +func BadRequest( + ctx context.Context, + logger *mlog.Logger, + rw http.ResponseWriter, + fmtStr string, + args ...any, +) { + err := fmt.Errorf(fmtStr, args...) + logger.Warn(ctx, "bad request", err) rw.WriteHeader(400) - JSONResult(rw, r, struct { + JSONResult(ctx, logger, rw, struct { Error string `json:"error"` }{ Error: err.Error(), @@ -69,11 +55,18 @@ func BadRequest(rw http.ResponseWriter, r *http.Request, err error) { // InternalServerError writes a 500 status and a JSON encoded error struct // containing a generic error as the response body (though it will log the given // one). -func InternalServerError(rw http.ResponseWriter, r *http.Request, err error) { - GetRequestLogger(r).Error(r.Context(), "internal server error", err) +func InternalServerError( + ctx context.Context, + logger *mlog.Logger, + rw http.ResponseWriter, + fmtStr string, + args ...any, +) { + err := fmt.Errorf(fmtStr, args...) + logger.Error(ctx, "internal server error", err) rw.WriteHeader(500) - JSONResult(rw, r, struct { + JSONResult(ctx, logger, rw, struct { Error string `json:"error"` }{ Error: "internal server error", diff --git a/src/http/assets.go b/src/http/assets.go index 8f43074..09cbf06 100644 --- a/src/http/assets.go +++ b/src/http/assets.go @@ -3,7 +3,6 @@ package http import ( "bytes" "errors" - "fmt" "net/http" "path/filepath" "strings" @@ -23,10 +22,16 @@ func (a *api) managePostAssetsHandler() http.Handler { func (a *api) getPostAssetHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + logger = a.params.Logger + ) maxWidth, err := apiutil.StrToInt(r.FormValue("w"), 0) if err != nil { - apiutil.BadRequest(rw, r, fmt.Errorf("invalid w parameter: %w", err)) + apiutil.BadRequest( + ctx, logger, rw, "invalid w parameter: %w", err, + ) return } @@ -53,7 +58,7 @@ func (a *api) getPostAssetHandler() http.Handler { } else if err != nil { apiutil.InternalServerError( - rw, r, fmt.Errorf("fetching asset at path %q: %w", path, err), + ctx, logger, rw, "fetching asset at path %q: %w", path, err, ) return } @@ -67,22 +72,30 @@ func (a *api) getPostAssetHandler() http.Handler { func (a *api) postPostAssetHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + logger = a.params.Logger + ) id := r.PostFormValue("id") if id == "" { - apiutil.BadRequest(rw, r, errors.New("id is required")) + apiutil.BadRequest(ctx, logger, rw, "id is required") return } file, _, err := r.FormFile("file") if err != nil { - apiutil.BadRequest(rw, r, fmt.Errorf("reading multipart file: %w", err)) + apiutil.BadRequest( + ctx, logger, rw, "reading multipart file: %w", err, + ) return } defer file.Close() if err := a.params.PostAssetStore.Set(id, file); err != nil { - apiutil.InternalServerError(rw, r, fmt.Errorf("storing file: %w", err)) + apiutil.InternalServerError( + ctx, logger, rw, "storing file: %w", err, + ) return } @@ -95,11 +108,14 @@ func (a *api) postPostAssetHandler() http.Handler { func (a *api) deletePostAssetHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - - id := filepath.Base(r.URL.Path) + var ( + ctx = r.Context() + logger = a.params.Logger + id = filepath.Base(r.URL.Path) + ) if id == "/" { - apiutil.BadRequest(rw, r, errors.New("id is required")) + apiutil.BadRequest(ctx, logger, rw, "id is required") return } @@ -110,7 +126,7 @@ func (a *api) deletePostAssetHandler() http.Handler { return } else if err != nil { apiutil.InternalServerError( - rw, r, fmt.Errorf("deleting asset with id %q: %w", id, err), + ctx, logger, rw, "deleting asset with id %q: %w", id, err, ) return } diff --git a/src/http/auth.go b/src/http/auth.go index eac73b7..70d33fb 100644 --- a/src/http/auth.go +++ b/src/http/auth.go @@ -5,7 +5,7 @@ import ( "net/http" "time" - "dev.mediocregopher.com/mediocre-blog.git/src/http/apiutil" + "dev.mediocregopher.com/mediocre-go-lib.git/mlog" "golang.org/x/crypto/bcrypt" ) @@ -65,12 +65,12 @@ func (a *auther) Allowed(ctx context.Context, username, password string) bool { return err == nil } -func authMiddleware(auther Auther) middleware { +func authMiddleware(logger *mlog.Logger, auther Auther) middleware { respondUnauthorized := func(rw http.ResponseWriter, r *http.Request) { rw.Header().Set("WWW-Authenticate", `Basic realm="NOPE"`) rw.WriteHeader(http.StatusUnauthorized) - apiutil.GetRequestLogger(r).WarnString(r.Context(), "unauthorized") + logger.WarnString(r.Context(), "unauthorized") } return func(h http.Handler) http.Handler { diff --git a/src/http/csrf.go b/src/http/csrf.go index 707aac4..a4f7e73 100644 --- a/src/http/csrf.go +++ b/src/http/csrf.go @@ -31,10 +31,12 @@ func checkCSRF(r *http.Request, publicURL *url.URL) error { } func (a *api) checkCSRFMiddleware(h http.Handler) http.Handler { + logger := a.params.Logger.WithNamespace("csrf") return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() if err := checkCSRF(r, a.params.PublicURL); err != nil { - apiutil.BadRequest(rw, r, errors.New("invalid Referer")) + apiutil.BadRequest(ctx, logger, rw, "invalid Referer") return } diff --git a/src/http/drafts.go b/src/http/drafts.go index a20464b..aedd8c0 100644 --- a/src/http/drafts.go +++ b/src/http/drafts.go @@ -1,7 +1,6 @@ package http import ( - "fmt" "net/http" "dev.mediocregopher.com/mediocre-blog.git/src/http/apiutil" @@ -16,16 +15,20 @@ func (a *api) manageDraftPostsHandler() http.Handler { func (a *api) postDraftPostHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + logger = a.params.Logger + ) p, err := postFromPostReq(r) if err != nil { - apiutil.BadRequest(rw, r, err) + apiutil.BadRequest(ctx, logger, rw, "%w", err) return } if err := a.params.PostDraftStore.Set(p); err != nil { apiutil.InternalServerError( - rw, r, fmt.Errorf("storing post with id %q: %w", p.ID, err), + ctx, logger, rw, "storing post with id %q: %w", p.ID, err, ) return } diff --git a/src/http/feed.go b/src/http/feed.go index 676d376..f38da91 100644 --- a/src/http/feed.go +++ b/src/http/feed.go @@ -1,7 +1,6 @@ package http import ( - "fmt" "net/http" "dev.mediocregopher.com/mediocre-blog.git/src/http/apiutil" @@ -12,8 +11,11 @@ import ( func (a *api) renderFeedHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - - tag := r.FormValue("tag") + var ( + ctx = r.Context() + logger = a.params.Logger + tag = r.FormValue("tag") + ) var ( posts []post.StoredPost @@ -27,7 +29,9 @@ func (a *api) renderFeedHandler() http.Handler { } if err != nil { - apiutil.InternalServerError(rw, r, fmt.Errorf("fetching recent posts: %w", err)) + apiutil.InternalServerError( + ctx, logger, rw, "fetching recent posts: %w", err, + ) return } @@ -68,7 +72,9 @@ func (a *api) renderFeedHandler() http.Handler { } if err := feed.WriteAtom(rw); err != nil { - apiutil.InternalServerError(rw, r, fmt.Errorf("writing atom feed: %w", err)) + apiutil.InternalServerError( + ctx, logger, rw, "writing atom feed: %w", err, + ) return } }) diff --git a/src/http/http.go b/src/http/http.go index 4403a69..11b4976 100644 --- a/src/http/http.go +++ b/src/http/http.go @@ -209,8 +209,9 @@ func (a *api) blogHandler() http.Handler { mux.Handle("/drafts/", http.StripPrefix("/drafts", // everything to do with drafts is protected - authMiddleware(a.auther)( - + authMiddleware( + a.params.Logger.WithNamespace("drafts-auther"), a.auther, + )( apiutil.MethodMux(map[string]http.Handler{ "EDIT": a.editPostHandler(true), "MANAGE": a.manageDraftPostsHandler(), @@ -228,13 +229,13 @@ func (a *api) blogHandler() http.Handler { mux.Handle("/", a.renderIndexHandler()) readOnlyMiddlewares := []middleware{ - logReqMiddleware, // only log GETs on cache miss + logReqMiddleware(a.params.Logger), // only log GETs on cache miss cacheMiddleware(a.params.Cache, a.params.PublicURL), } readWriteMiddlewares := []middleware{ - purgeCacheOnOKMiddleware(a.params.Cache), - authMiddleware(a.auther), + purgeCacheOnOKMiddleware(a.params.Logger, a.params.Cache), + authMiddleware(a.params.Logger.WithNamespace("rw-auther"), a.auther), } h := apiutil.MethodMux(map[string]http.Handler{ @@ -270,7 +271,7 @@ func (a *api) handler() http.Handler { noCacheMiddleware, ), }), - setLoggerMiddleware(a.params.Logger), + setLogCtxMiddleware(), ) return h diff --git a/src/http/middleware.go b/src/http/middleware.go index 7a34e83..d8f2e6e 100644 --- a/src/http/middleware.go +++ b/src/http/middleware.go @@ -10,7 +10,6 @@ import ( "time" "dev.mediocregopher.com/mediocre-blog.git/src/cache" - "dev.mediocregopher.com/mediocre-blog.git/src/http/apiutil" "dev.mediocregopher.com/mediocre-go-lib.git/mctx" "dev.mediocregopher.com/mediocre-go-lib.git/mlog" ) @@ -35,7 +34,7 @@ func addResponseHeadersMiddleware(headers map[string]string) middleware { } } -func setLoggerMiddleware(logger *mlog.Logger) middleware { +func setLogCtxMiddleware() middleware { return func(h http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -60,7 +59,6 @@ func setLoggerMiddleware(logger *mlog.Logger) middleware { } r = r.WithContext(ctx) - r = apiutil.SetRequestLogger(r, logger) h.ServeHTTP(rw, r) }) } @@ -86,40 +84,44 @@ func (rw *wrappedResponseWriter) WriteHeader(statusCode int) { rw.ResponseWriter.WriteHeader(statusCode) } -func logReqMiddleware(h http.Handler) http.Handler { +func logReqMiddleware(logger *mlog.Logger) func(h http.Handler) http.Handler { type logCtxKey string - return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - wrw := newWrappedResponseWriter(rw) + wrw := newWrappedResponseWriter(rw) - started := time.Now() - h.ServeHTTP(wrw, r) - took := time.Since(started) + started := time.Now() + h.ServeHTTP(wrw, r) + took := time.Since(started) - ctx := r.Context() - ctx = mctx.Annotate(ctx, - logCtxKey("took"), took.String(), - logCtxKey("response_code"), wrw.statusCode, - ) + ctx := r.Context() + ctx = mctx.Annotate(ctx, + logCtxKey("took"), took.String(), + logCtxKey("response_code"), wrw.statusCode, + ) - apiutil.GetRequestLogger(r).Info(ctx, "handled HTTP request") - }) + logger.Info(ctx, "handled HTTP request") + }) + } } -func disallowGetMiddleware(h http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { +func disallowGetMiddleware(logger *mlog.Logger) func(h http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - // we allow websockets to be GETs because, well, they must be - if r.Method != "GET" || r.Header.Get("Upgrade") == "websocket" { - h.ServeHTTP(rw, r) - return - } + // we allow websockets to be GETs because, well, they must be + if r.Method != "GET" || r.Header.Get("Upgrade") == "websocket" { + h.ServeHTTP(rw, r) + return + } - apiutil.GetRequestLogger(r).WarnString(r.Context(), "method not allowed") - rw.WriteHeader(405) - }) + logger.WarnString(r.Context(), "method not allowed") + rw.WriteHeader(405) + }) + } } type cacheResponseWriter struct { @@ -188,7 +190,7 @@ func cacheMiddleware(cache cache.Cache, publicURL *url.URL) middleware { } } -func purgeCacheOnOKMiddleware(cache cache.Cache) middleware { +func purgeCacheOnOKMiddleware(logger *mlog.Logger, cache cache.Cache) middleware { return func(h http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -196,7 +198,7 @@ func purgeCacheOnOKMiddleware(cache cache.Cache) middleware { h.ServeHTTP(wrw, r) if wrw.statusCode == 200 { - apiutil.GetRequestLogger(r).Info(r.Context(), "purging cache!") + logger.Info(r.Context(), "purging cache!") cache.Purge() } }) diff --git a/src/http/posts.go b/src/http/posts.go index 1bc65c8..5a295a7 100644 --- a/src/http/posts.go +++ b/src/http/posts.go @@ -165,12 +165,14 @@ func (a *api) publishPost(ctx context.Context, p post.Post) error { func (a *api) postPostHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - - ctx := r.Context() + var ( + ctx = r.Context() + logger = a.params.Logger + ) p, err := postFromPostReq(r) if err != nil { - apiutil.BadRequest(rw, r, err) + apiutil.BadRequest(ctx, logger, rw, "%w", err) return } @@ -178,7 +180,7 @@ func (a *api) postPostHandler() http.Handler { if err := a.publishPost(ctx, p); err != nil { apiutil.InternalServerError( - rw, r, fmt.Errorf("publishing post with id %q: %w", p.ID, err), + ctx, logger, rw, "publishing post with id %q: %w", p.ID, err, ) return } @@ -192,11 +194,14 @@ func (a *api) postPostHandler() http.Handler { func (a *api) deletePostHandler(isDraft bool) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - - id := filepath.Base(r.URL.Path) + var ( + ctx = r.Context() + logger = a.params.Logger + id = filepath.Base(r.URL.Path) + ) if id == "/" { - apiutil.BadRequest(rw, r, errors.New("id is required")) + apiutil.BadRequest(ctx, logger, rw, "id is required") return } @@ -213,7 +218,7 @@ func (a *api) deletePostHandler(isDraft bool) http.Handler { return } else if err != nil { apiutil.InternalServerError( - rw, r, fmt.Errorf("deleting post with id %q: %w", id, err), + ctx, logger, rw, "deleting post with id %q: %w", id, err, ) return } @@ -235,10 +240,14 @@ func (a *api) previewPostHandler() http.Handler { tpl := a.mustParseBasedTpl("post.html") return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + logger = a.params.Logger + ) p, err := postFromPostReq(r) if err != nil { - apiutil.BadRequest(rw, r, err) + apiutil.BadRequest(ctx, logger, rw, "%w", err) return } @@ -247,7 +256,7 @@ func (a *api) previewPostHandler() http.Handler { PublishedAt: time.Now(), } - r = r.WithContext(render.WithPost(r.Context(), storedPost)) + r = r.WithContext(render.WithPost(ctx, storedPost)) a.executeTemplate(rw, r, tpl, nil) }) } diff --git a/src/http/tpl.go b/src/http/tpl.go index 42341a2..452f444 100644 --- a/src/http/tpl.go +++ b/src/http/tpl.go @@ -80,10 +80,12 @@ func (a *api) executeTemplate( tpl *template.Template, payload interface{}, ) { - - tplData := a.newTPLData(r, payload) - - buf := new(bytes.Buffer) + var ( + ctx = r.Context() + logger = a.params.Logger + tplData = a.newTPLData(r, payload) + buf = new(bytes.Buffer) + ) err := tpl.Execute(buf, tplData) if errors.Is(err, post.ErrPostNotFound) { @@ -91,7 +93,7 @@ func (a *api) executeTemplate( return } else if err != nil { apiutil.InternalServerError( - rw, r, fmt.Errorf("rendering template: %w", err), + ctx, logger, rw, "rendering template: %w", err, ) return } @@ -114,12 +116,15 @@ func (a *api) renderDumbTplHandler(tplName string) http.Handler { tpl := a.mustParseBasedTpl(tplName) return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - - tplData := a.newTPLData(r, nil) + var ( + ctx = r.Context() + logger = a.params.Logger + tplData = a.newTPLData(r, nil) + ) if err := tpl.Execute(rw, tplData); err != nil { apiutil.InternalServerError( - rw, r, fmt.Errorf("rendering %q: %w", tplName, err), + ctx, logger, rw, "rendering %q: %w", tplName, err, ) return } -- cgit v1.2.3