From 4f01edb9230f58ff84b0dd892c931ec8ac9aad55 Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Tue, 13 Sep 2022 12:56:08 +0200 Subject: move src out of srv, clean up default.nix and Makefile --- src/http/middleware.go | 199 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 src/http/middleware.go (limited to 'src/http/middleware.go') diff --git a/src/http/middleware.go b/src/http/middleware.go new file mode 100644 index 0000000..b82fc29 --- /dev/null +++ b/src/http/middleware.go @@ -0,0 +1,199 @@ +package http + +import ( + "bytes" + "net" + "net/http" + "path/filepath" + "sync" + "time" + + lru "github.com/hashicorp/golang-lru" + "github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil" + "github.com/mediocregopher/mediocre-go-lib/v2/mctx" + "github.com/mediocregopher/mediocre-go-lib/v2/mlog" +) + +type middleware func(http.Handler) http.Handler + +func applyMiddlewares(h http.Handler, middlewares ...middleware) http.Handler { + for _, m := range middlewares { + h = m(h) + } + return h +} + +func addResponseHeadersMiddleware(headers map[string]string) middleware { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + for k, v := range headers { + rw.Header().Set(k, v) + } + h.ServeHTTP(rw, r) + }) + } +} + +func setLoggerMiddleware(logger *mlog.Logger) middleware { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + type logCtxKey string + + ctx := r.Context() + ctx = mctx.Annotate(ctx, + logCtxKey("url"), r.URL, + logCtxKey("method"), r.Method, + ) + + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + ctx = mctx.Annotate(ctx, logCtxKey("x_forwarded_for"), xff) + } + + if xrip := r.Header.Get("X-Real-IP"); xrip != "" { + ctx = mctx.Annotate(ctx, logCtxKey("x_real_ip"), xrip) + } + + if ip, _, _ := net.SplitHostPort(r.RemoteAddr); ip != "" { + ctx = mctx.Annotate(ctx, logCtxKey("remote_ip"), ip) + } + + r = r.WithContext(ctx) + r = apiutil.SetRequestLogger(r, logger) + h.ServeHTTP(rw, r) + }) + } +} + +type wrappedResponseWriter struct { + http.ResponseWriter + http.Hijacker + statusCode int +} + +func newWrappedResponseWriter(rw http.ResponseWriter) *wrappedResponseWriter { + h, _ := rw.(http.Hijacker) + return &wrappedResponseWriter{ + ResponseWriter: rw, + Hijacker: h, + statusCode: 200, + } +} + +func (rw *wrappedResponseWriter) WriteHeader(statusCode int) { + rw.statusCode = statusCode + rw.ResponseWriter.WriteHeader(statusCode) +} + +func logReqMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + wrw := newWrappedResponseWriter(rw) + + started := time.Now() + h.ServeHTTP(wrw, r) + took := time.Since(started) + + type logCtxKey string + + ctx := r.Context() + ctx = mctx.Annotate(ctx, + logCtxKey("took"), took.String(), + logCtxKey("response_code"), wrw.statusCode, + ) + + apiutil.GetRequestLogger(r).Info(ctx, "handled HTTP request") + }) +} + +func disallowGetMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + // we allow websockets to be GETs because, well, they must be + if r.Method != "GET" || r.Header.Get("Upgrade") == "websocket" { + h.ServeHTTP(rw, r) + return + } + + apiutil.GetRequestLogger(r).WarnString(r.Context(), "method not allowed") + rw.WriteHeader(405) + }) +} + +type cacheResponseWriter struct { + *wrappedResponseWriter + buf *bytes.Buffer +} + +func newCacheResponseWriter(rw http.ResponseWriter) *cacheResponseWriter { + return &cacheResponseWriter{ + wrappedResponseWriter: newWrappedResponseWriter(rw), + buf: new(bytes.Buffer), + } +} + +func (rw *cacheResponseWriter) Write(b []byte) (int, error) { + if _, err := rw.buf.Write(b); err != nil { + panic(err) + } + return rw.wrappedResponseWriter.Write(b) +} + +func cacheMiddleware(cache *lru.Cache) middleware { + + type entry struct { + body []byte + createdAt time.Time + } + + pool := sync.Pool{ + New: func() interface{} { return new(bytes.Reader) }, + } + + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + id := r.URL.RequestURI() + + if val, ok := cache.Get(id); ok { + + entry := val.(entry) + + reader := pool.Get().(*bytes.Reader) + defer pool.Put(reader) + + reader.Reset(entry.body) + + http.ServeContent( + rw, r, filepath.Base(r.URL.Path), entry.createdAt, reader, + ) + return + } + + cacheRW := newCacheResponseWriter(rw) + h.ServeHTTP(cacheRW, r) + + if cacheRW.statusCode == 200 { + cache.Add(id, entry{ + body: cacheRW.buf.Bytes(), + createdAt: time.Now(), + }) + } + }) + } +} + +func purgeCacheOnOKMiddleware(cache *lru.Cache) middleware { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + + wrw := newWrappedResponseWriter(rw) + h.ServeHTTP(wrw, r) + + if wrw.statusCode == 200 { + apiutil.GetRequestLogger(r).Info(r.Context(), "purging cache!") + cache.Purge() + } + }) + } +} -- cgit v1.2.3