summaryrefslogtreecommitdiff
path: root/src/http
diff options
context:
space:
mode:
Diffstat (limited to 'src/http')
-rw-r--r--src/http/apiutil/apiutil.go59
-rw-r--r--src/http/assets.go36
-rw-r--r--src/http/auth.go6
-rw-r--r--src/http/csrf.go4
-rw-r--r--src/http/drafts.go9
-rw-r--r--src/http/feed.go16
-rw-r--r--src/http/http.go13
-rw-r--r--src/http/middleware.go58
-rw-r--r--src/http/posts.go29
-rw-r--r--src/http/tpl.go21
10 files changed, 144 insertions, 107 deletions
diff --git a/src/http/apiutil/apiutil.go b/src/http/apiutil/apiutil.go
index 2151d83..a27d0d5 100644
--- a/src/http/apiutil/apiutil.go
+++ b/src/http/apiutil/apiutil.go
@@ -16,35 +16,13 @@ import (
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
)
-// TODO I don't think Set/GetRequestLogger are necessary? Seems sufficient to
-// just annotate the request's context
-
-type loggerCtxKey int
-
-// SetRequestLogger sets the given Logger onto the given Request's Context,
-// returning a copy.
-func SetRequestLogger(r *http.Request, logger *mlog.Logger) *http.Request {
- ctx := r.Context()
- ctx = context.WithValue(ctx, loggerCtxKey(0), logger)
- return r.WithContext(ctx)
-}
-
-// GetRequestLogger returns the Logger which was set by SetRequestLogger onto
-// this Request, or nil.
-func GetRequestLogger(r *http.Request) *mlog.Logger {
- ctx := r.Context()
- logger, _ := ctx.Value(loggerCtxKey(0)).(*mlog.Logger)
- if logger == nil {
- logger = mlog.Null
- }
- return logger
-}
-
// JSONResult writes the JSON encoding of the given value as the response body.
-func JSONResult(rw http.ResponseWriter, r *http.Request, v interface{}) {
+func JSONResult(
+ ctx context.Context, logger *mlog.Logger, rw http.ResponseWriter, v any,
+) {
b, err := json.Marshal(v)
if err != nil {
- InternalServerError(rw, r, err)
+ InternalServerError(ctx, logger, rw, "%w", err)
return
}
b = append(b, '\n')
@@ -54,12 +32,20 @@ func JSONResult(rw http.ResponseWriter, r *http.Request, v interface{}) {
}
// BadRequest writes a 400 status and a JSON encoded error struct containing the
-// given error as the response body.
-func BadRequest(rw http.ResponseWriter, r *http.Request, err error) {
- GetRequestLogger(r).Warn(r.Context(), "bad request", err)
+// given error (created using `fmt.Errorf(fmtStr, args...)`) as the response
+// body.
+func BadRequest(
+ ctx context.Context,
+ logger *mlog.Logger,
+ rw http.ResponseWriter,
+ fmtStr string,
+ args ...any,
+) {
+ err := fmt.Errorf(fmtStr, args...)
+ logger.Warn(ctx, "bad request", err)
rw.WriteHeader(400)
- JSONResult(rw, r, struct {
+ JSONResult(ctx, logger, rw, struct {
Error string `json:"error"`
}{
Error: err.Error(),
@@ -69,11 +55,18 @@ func BadRequest(rw http.ResponseWriter, r *http.Request, err error) {
// InternalServerError writes a 500 status and a JSON encoded error struct
// containing a generic error as the response body (though it will log the given
// one).
-func InternalServerError(rw http.ResponseWriter, r *http.Request, err error) {
- GetRequestLogger(r).Error(r.Context(), "internal server error", err)
+func InternalServerError(
+ ctx context.Context,
+ logger *mlog.Logger,
+ rw http.ResponseWriter,
+ fmtStr string,
+ args ...any,
+) {
+ err := fmt.Errorf(fmtStr, args...)
+ logger.Error(ctx, "internal server error", err)
rw.WriteHeader(500)
- JSONResult(rw, r, struct {
+ JSONResult(ctx, logger, rw, struct {
Error string `json:"error"`
}{
Error: "internal server error",
diff --git a/src/http/assets.go b/src/http/assets.go
index 8f43074..09cbf06 100644
--- a/src/http/assets.go
+++ b/src/http/assets.go
@@ -3,7 +3,6 @@ package http
import (
"bytes"
"errors"
- "fmt"
"net/http"
"path/filepath"
"strings"
@@ -23,10 +22,16 @@ func (a *api) managePostAssetsHandler() http.Handler {
func (a *api) getPostAssetHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ var (
+ ctx = r.Context()
+ logger = a.params.Logger
+ )
maxWidth, err := apiutil.StrToInt(r.FormValue("w"), 0)
if err != nil {
- apiutil.BadRequest(rw, r, fmt.Errorf("invalid w parameter: %w", err))
+ apiutil.BadRequest(
+ ctx, logger, rw, "invalid w parameter: %w", err,
+ )
return
}
@@ -53,7 +58,7 @@ func (a *api) getPostAssetHandler() http.Handler {
} else if err != nil {
apiutil.InternalServerError(
- rw, r, fmt.Errorf("fetching asset at path %q: %w", path, err),
+ ctx, logger, rw, "fetching asset at path %q: %w", path, err,
)
return
}
@@ -67,22 +72,30 @@ func (a *api) getPostAssetHandler() http.Handler {
func (a *api) postPostAssetHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ var (
+ ctx = r.Context()
+ logger = a.params.Logger
+ )
id := r.PostFormValue("id")
if id == "" {
- apiutil.BadRequest(rw, r, errors.New("id is required"))
+ apiutil.BadRequest(ctx, logger, rw, "id is required")
return
}
file, _, err := r.FormFile("file")
if err != nil {
- apiutil.BadRequest(rw, r, fmt.Errorf("reading multipart file: %w", err))
+ apiutil.BadRequest(
+ ctx, logger, rw, "reading multipart file: %w", err,
+ )
return
}
defer file.Close()
if err := a.params.PostAssetStore.Set(id, file); err != nil {
- apiutil.InternalServerError(rw, r, fmt.Errorf("storing file: %w", err))
+ apiutil.InternalServerError(
+ ctx, logger, rw, "storing file: %w", err,
+ )
return
}
@@ -95,11 +108,14 @@ func (a *api) postPostAssetHandler() http.Handler {
func (a *api) deletePostAssetHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
-
- id := filepath.Base(r.URL.Path)
+ var (
+ ctx = r.Context()
+ logger = a.params.Logger
+ id = filepath.Base(r.URL.Path)
+ )
if id == "/" {
- apiutil.BadRequest(rw, r, errors.New("id is required"))
+ apiutil.BadRequest(ctx, logger, rw, "id is required")
return
}
@@ -110,7 +126,7 @@ func (a *api) deletePostAssetHandler() http.Handler {
return
} else if err != nil {
apiutil.InternalServerError(
- rw, r, fmt.Errorf("deleting asset with id %q: %w", id, err),
+ ctx, logger, rw, "deleting asset with id %q: %w", id, err,
)
return
}
diff --git a/src/http/auth.go b/src/http/auth.go
index eac73b7..70d33fb 100644
--- a/src/http/auth.go
+++ b/src/http/auth.go
@@ -5,7 +5,7 @@ import (
"net/http"
"time"
- "dev.mediocregopher.com/mediocre-blog.git/src/http/apiutil"
+ "dev.mediocregopher.com/mediocre-go-lib.git/mlog"
"golang.org/x/crypto/bcrypt"
)
@@ -65,12 +65,12 @@ func (a *auther) Allowed(ctx context.Context, username, password string) bool {
return err == nil
}
-func authMiddleware(auther Auther) middleware {
+func authMiddleware(logger *mlog.Logger, auther Auther) middleware {
respondUnauthorized := func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("WWW-Authenticate", `Basic realm="NOPE"`)
rw.WriteHeader(http.StatusUnauthorized)
- apiutil.GetRequestLogger(r).WarnString(r.Context(), "unauthorized")
+ logger.WarnString(r.Context(), "unauthorized")
}
return func(h http.Handler) http.Handler {
diff --git a/src/http/csrf.go b/src/http/csrf.go
index 707aac4..a4f7e73 100644
--- a/src/http/csrf.go
+++ b/src/http/csrf.go
@@ -31,10 +31,12 @@ func checkCSRF(r *http.Request, publicURL *url.URL) error {
}
func (a *api) checkCSRFMiddleware(h http.Handler) http.Handler {
+ logger := a.params.Logger.WithNamespace("csrf")
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
if err := checkCSRF(r, a.params.PublicURL); err != nil {
- apiutil.BadRequest(rw, r, errors.New("invalid Referer"))
+ apiutil.BadRequest(ctx, logger, rw, "invalid Referer")
return
}
diff --git a/src/http/drafts.go b/src/http/drafts.go
index a20464b..aedd8c0 100644
--- a/src/http/drafts.go
+++ b/src/http/drafts.go
@@ -1,7 +1,6 @@
package http
import (
- "fmt"
"net/http"
"dev.mediocregopher.com/mediocre-blog.git/src/http/apiutil"
@@ -16,16 +15,20 @@ func (a *api) manageDraftPostsHandler() http.Handler {
func (a *api) postDraftPostHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ var (
+ ctx = r.Context()
+ logger = a.params.Logger
+ )
p, err := postFromPostReq(r)
if err != nil {
- apiutil.BadRequest(rw, r, err)
+ apiutil.BadRequest(ctx, logger, rw, "%w", err)
return
}
if err := a.params.PostDraftStore.Set(p); err != nil {
apiutil.InternalServerError(
- rw, r, fmt.Errorf("storing post with id %q: %w", p.ID, err),
+ ctx, logger, rw, "storing post with id %q: %w", p.ID, err,
)
return
}
diff --git a/src/http/feed.go b/src/http/feed.go
index 676d376..f38da91 100644
--- a/src/http/feed.go
+++ b/src/http/feed.go
@@ -1,7 +1,6 @@
package http
import (
- "fmt"
"net/http"
"dev.mediocregopher.com/mediocre-blog.git/src/http/apiutil"
@@ -12,8 +11,11 @@ import (
func (a *api) renderFeedHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
-
- tag := r.FormValue("tag")
+ var (
+ ctx = r.Context()
+ logger = a.params.Logger
+ tag = r.FormValue("tag")
+ )
var (
posts []post.StoredPost
@@ -27,7 +29,9 @@ func (a *api) renderFeedHandler() http.Handler {
}
if err != nil {
- apiutil.InternalServerError(rw, r, fmt.Errorf("fetching recent posts: %w", err))
+ apiutil.InternalServerError(
+ ctx, logger, rw, "fetching recent posts: %w", err,
+ )
return
}
@@ -68,7 +72,9 @@ func (a *api) renderFeedHandler() http.Handler {
}
if err := feed.WriteAtom(rw); err != nil {
- apiutil.InternalServerError(rw, r, fmt.Errorf("writing atom feed: %w", err))
+ apiutil.InternalServerError(
+ ctx, logger, rw, "writing atom feed: %w", err,
+ )
return
}
})
diff --git a/src/http/http.go b/src/http/http.go
index 4403a69..11b4976 100644
--- a/src/http/http.go
+++ b/src/http/http.go
@@ -209,8 +209,9 @@ func (a *api) blogHandler() http.Handler {
mux.Handle("/drafts/", http.StripPrefix("/drafts",
// everything to do with drafts is protected
- authMiddleware(a.auther)(
-
+ authMiddleware(
+ a.params.Logger.WithNamespace("drafts-auther"), a.auther,
+ )(
apiutil.MethodMux(map[string]http.Handler{
"EDIT": a.editPostHandler(true),
"MANAGE": a.manageDraftPostsHandler(),
@@ -228,13 +229,13 @@ func (a *api) blogHandler() http.Handler {
mux.Handle("/", a.renderIndexHandler())
readOnlyMiddlewares := []middleware{
- logReqMiddleware, // only log GETs on cache miss
+ logReqMiddleware(a.params.Logger), // only log GETs on cache miss
cacheMiddleware(a.params.Cache, a.params.PublicURL),
}
readWriteMiddlewares := []middleware{
- purgeCacheOnOKMiddleware(a.params.Cache),
- authMiddleware(a.auther),
+ purgeCacheOnOKMiddleware(a.params.Logger, a.params.Cache),
+ authMiddleware(a.params.Logger.WithNamespace("rw-auther"), a.auther),
}
h := apiutil.MethodMux(map[string]http.Handler{
@@ -270,7 +271,7 @@ func (a *api) handler() http.Handler {
noCacheMiddleware,
),
}),
- setLoggerMiddleware(a.params.Logger),
+ setLogCtxMiddleware(),
)
return h
diff --git a/src/http/middleware.go b/src/http/middleware.go
index 7a34e83..d8f2e6e 100644
--- a/src/http/middleware.go
+++ b/src/http/middleware.go
@@ -10,7 +10,6 @@ import (
"time"
"dev.mediocregopher.com/mediocre-blog.git/src/cache"
- "dev.mediocregopher.com/mediocre-blog.git/src/http/apiutil"
"dev.mediocregopher.com/mediocre-go-lib.git/mctx"
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
)
@@ -35,7 +34,7 @@ func addResponseHeadersMiddleware(headers map[string]string) middleware {
}
}
-func setLoggerMiddleware(logger *mlog.Logger) middleware {
+func setLogCtxMiddleware() middleware {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@@ -60,7 +59,6 @@ func setLoggerMiddleware(logger *mlog.Logger) middleware {
}
r = r.WithContext(ctx)
- r = apiutil.SetRequestLogger(r, logger)
h.ServeHTTP(rw, r)
})
}
@@ -86,40 +84,44 @@ func (rw *wrappedResponseWriter) WriteHeader(statusCode int) {
rw.ResponseWriter.WriteHeader(statusCode)
}
-func logReqMiddleware(h http.Handler) http.Handler {
+func logReqMiddleware(logger *mlog.Logger) func(h http.Handler) http.Handler {
type logCtxKey string
- return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ return func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- wrw := newWrappedResponseWriter(rw)
+ wrw := newWrappedResponseWriter(rw)
- started := time.Now()
- h.ServeHTTP(wrw, r)
- took := time.Since(started)
+ started := time.Now()
+ h.ServeHTTP(wrw, r)
+ took := time.Since(started)
- ctx := r.Context()
- ctx = mctx.Annotate(ctx,
- logCtxKey("took"), took.String(),
- logCtxKey("response_code"), wrw.statusCode,
- )
+ ctx := r.Context()
+ ctx = mctx.Annotate(ctx,
+ logCtxKey("took"), took.String(),
+ logCtxKey("response_code"), wrw.statusCode,
+ )
- apiutil.GetRequestLogger(r).Info(ctx, "handled HTTP request")
- })
+ logger.Info(ctx, "handled HTTP request")
+ })
+ }
}
-func disallowGetMiddleware(h http.Handler) http.Handler {
- return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+func disallowGetMiddleware(logger *mlog.Logger) func(h http.Handler) http.Handler {
+ return func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- // we allow websockets to be GETs because, well, they must be
- if r.Method != "GET" || r.Header.Get("Upgrade") == "websocket" {
- h.ServeHTTP(rw, r)
- return
- }
+ // we allow websockets to be GETs because, well, they must be
+ if r.Method != "GET" || r.Header.Get("Upgrade") == "websocket" {
+ h.ServeHTTP(rw, r)
+ return
+ }
- apiutil.GetRequestLogger(r).WarnString(r.Context(), "method not allowed")
- rw.WriteHeader(405)
- })
+ logger.WarnString(r.Context(), "method not allowed")
+ rw.WriteHeader(405)
+ })
+ }
}
type cacheResponseWriter struct {
@@ -188,7 +190,7 @@ func cacheMiddleware(cache cache.Cache, publicURL *url.URL) middleware {
}
}
-func purgeCacheOnOKMiddleware(cache cache.Cache) middleware {
+func purgeCacheOnOKMiddleware(logger *mlog.Logger, cache cache.Cache) middleware {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@@ -196,7 +198,7 @@ func purgeCacheOnOKMiddleware(cache cache.Cache) middleware {
h.ServeHTTP(wrw, r)
if wrw.statusCode == 200 {
- apiutil.GetRequestLogger(r).Info(r.Context(), "purging cache!")
+ logger.Info(r.Context(), "purging cache!")
cache.Purge()
}
})
diff --git a/src/http/posts.go b/src/http/posts.go
index 1bc65c8..5a295a7 100644
--- a/src/http/posts.go
+++ b/src/http/posts.go
@@ -165,12 +165,14 @@ func (a *api) publishPost(ctx context.Context, p post.Post) error {
func (a *api) postPostHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
-
- ctx := r.Context()
+ var (
+ ctx = r.Context()
+ logger = a.params.Logger
+ )
p, err := postFromPostReq(r)
if err != nil {
- apiutil.BadRequest(rw, r, err)
+ apiutil.BadRequest(ctx, logger, rw, "%w", err)
return
}
@@ -178,7 +180,7 @@ func (a *api) postPostHandler() http.Handler {
if err := a.publishPost(ctx, p); err != nil {
apiutil.InternalServerError(
- rw, r, fmt.Errorf("publishing post with id %q: %w", p.ID, err),
+ ctx, logger, rw, "publishing post with id %q: %w", p.ID, err,
)
return
}
@@ -192,11 +194,14 @@ func (a *api) postPostHandler() http.Handler {
func (a *api) deletePostHandler(isDraft bool) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
-
- id := filepath.Base(r.URL.Path)
+ var (
+ ctx = r.Context()
+ logger = a.params.Logger
+ id = filepath.Base(r.URL.Path)
+ )
if id == "/" {
- apiutil.BadRequest(rw, r, errors.New("id is required"))
+ apiutil.BadRequest(ctx, logger, rw, "id is required")
return
}
@@ -213,7 +218,7 @@ func (a *api) deletePostHandler(isDraft bool) http.Handler {
return
} else if err != nil {
apiutil.InternalServerError(
- rw, r, fmt.Errorf("deleting post with id %q: %w", id, err),
+ ctx, logger, rw, "deleting post with id %q: %w", id, err,
)
return
}
@@ -235,10 +240,14 @@ func (a *api) previewPostHandler() http.Handler {
tpl := a.mustParseBasedTpl("post.html")
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ var (
+ ctx = r.Context()
+ logger = a.params.Logger
+ )
p, err := postFromPostReq(r)
if err != nil {
- apiutil.BadRequest(rw, r, err)
+ apiutil.BadRequest(ctx, logger, rw, "%w", err)
return
}
@@ -247,7 +256,7 @@ func (a *api) previewPostHandler() http.Handler {
PublishedAt: time.Now(),
}
- r = r.WithContext(render.WithPost(r.Context(), storedPost))
+ r = r.WithContext(render.WithPost(ctx, storedPost))
a.executeTemplate(rw, r, tpl, nil)
})
}
diff --git a/src/http/tpl.go b/src/http/tpl.go
index 42341a2..452f444 100644
--- a/src/http/tpl.go
+++ b/src/http/tpl.go
@@ -80,10 +80,12 @@ func (a *api) executeTemplate(
tpl *template.Template,
payload interface{},
) {
-
- tplData := a.newTPLData(r, payload)
-
- buf := new(bytes.Buffer)
+ var (
+ ctx = r.Context()
+ logger = a.params.Logger
+ tplData = a.newTPLData(r, payload)
+ buf = new(bytes.Buffer)
+ )
err := tpl.Execute(buf, tplData)
if errors.Is(err, post.ErrPostNotFound) {
@@ -91,7 +93,7 @@ func (a *api) executeTemplate(
return
} else if err != nil {
apiutil.InternalServerError(
- rw, r, fmt.Errorf("rendering template: %w", err),
+ ctx, logger, rw, "rendering template: %w", err,
)
return
}
@@ -114,12 +116,15 @@ func (a *api) renderDumbTplHandler(tplName string) http.Handler {
tpl := a.mustParseBasedTpl(tplName)
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
-
- tplData := a.newTPLData(r, nil)
+ var (
+ ctx = r.Context()
+ logger = a.params.Logger
+ tplData = a.newTPLData(r, nil)
+ )
if err := tpl.Execute(rw, tplData); err != nil {
apiutil.InternalServerError(
- rw, r, fmt.Errorf("rendering %q: %w", tplName, err),
+ ctx, logger, rw, "rendering %q: %w", tplName, err,
)
return
}