diff options
author | Brian Picciano <mediocregopher@gmail.com> | 2022-05-20 14:54:26 -0600 |
---|---|---|
committer | Brian Picciano <mediocregopher@gmail.com> | 2022-05-20 14:54:26 -0600 |
commit | 1ffda21ae38d203e381bedbf7bdbbd69c9031062 (patch) | |
tree | 32b28a8fd92341e69f639b6959bfd04347728494 /srv/src/http | |
parent | ae1fa76efc0d771ca50dede367bd228ce9f7b969 (diff) |
Implement ratelimit on authentications
Diffstat (limited to 'srv/src/http')
-rw-r--r-- | srv/src/http/api.go | 33 | ||||
-rw-r--r-- | srv/src/http/auth.go | 30 | ||||
-rw-r--r-- | srv/src/http/auth_test.go | 11 | ||||
-rw-r--r-- | srv/src/http/posts.go | 28 |
4 files changed, 79 insertions, 23 deletions
diff --git a/srv/src/http/api.go b/srv/src/http/api.go index 4d049ed..85a6375 100644 --- a/srv/src/http/api.go +++ b/srv/src/http/api.go @@ -11,6 +11,7 @@ import ( "net" "net/http" "os" + "time" "github.com/mediocregopher/blog.mediocregopher.com/srv/cfg" "github.com/mediocregopher/blog.mediocregopher.com/srv/chat" @@ -52,6 +53,10 @@ type Params struct { // and the values are the password hash which accompanies those users. The // password hash must have been produced by NewPasswordHash. AuthUsers map[string]string + + // AuthRatelimit indicates how much time must pass between subsequent auth + // attempts. + AuthRatelimit time.Duration } // SetupCfg implement the cfg.Cfger interface. @@ -61,10 +66,20 @@ func (p *Params) SetupCfg(cfg *cfg.Cfg) { httpAuthUsersStr := cfg.String("http-auth-users", "{}", "JSON object with usernames as values and password hashes (produced by the hash-password binary) as values. Denotes users which are able to edit server-side data") + httpAuthRatelimitStr := cfg.String("http-auth-ratelimit", "5s", "Minimum duration which must be waited between subsequent auth attempts") + cfg.OnInit(func(context.Context) error { - if err := json.Unmarshal([]byte(*httpAuthUsersStr), &p.AuthUsers); err != nil { + + err := json.Unmarshal([]byte(*httpAuthUsersStr), &p.AuthUsers) + + if err != nil { return fmt.Errorf("unmarshaling -http-auth-users: %w", err) } + + if p.AuthRatelimit, err = time.ParseDuration(*httpAuthRatelimitStr); err != nil { + return fmt.Errorf("unmarshaling -http-auth-ratelimit: %w", err) + } + return nil }) } @@ -73,6 +88,7 @@ func (p *Params) SetupCfg(cfg *cfg.Cfg) { func (p *Params) Annotate(a mctx.Annotations) { a["listenProto"] = p.ListenProto a["listenAddr"] = p.ListenAddr + a["authRatelimit"] = p.AuthRatelimit } // API will listen on the port configured for it, and serve HTTP requests for @@ -86,6 +102,7 @@ type api struct { srv *http.Server redirectTpl *template.Template + auther Auther } // New initializes and returns a new API instance, including setting up all @@ -105,6 +122,7 @@ func New(params Params) (API, error) { a := &api{ params: params, + auther: NewAuther(params.AuthUsers, params.AuthRatelimit), } a.redirectTpl = a.mustParseTpl("redirect.html") @@ -124,6 +142,7 @@ func New(params Params) (API, error) { } func (a *api) Shutdown(ctx context.Context) error { + defer a.auther.Close() if err := a.srv.Shutdown(ctx); err != nil { return err } @@ -149,8 +168,6 @@ func (a *api) handler() http.Handler { return h } - auther := NewAuther(a.params.AuthUsers) - mux := http.NewServeMux() { @@ -179,13 +196,13 @@ func (a *api) handler() http.Handler { apiutil.MethodMux(map[string]http.Handler{ "GET": a.renderPostHandler(), "EDIT": a.editPostHandler(), - "POST": authMiddleware(auther, + "POST": authMiddleware(a.auther, formMiddleware(a.postPostHandler()), ), - "DELETE": authMiddleware(auther, + "DELETE": authMiddleware(a.auther, formMiddleware(a.deletePostHandler()), ), - "PREVIEW": authMiddleware(auther, + "PREVIEW": authMiddleware(a.auther, formMiddleware(a.previewPostHandler()), ), }), @@ -194,10 +211,10 @@ func (a *api) handler() http.Handler { mux.Handle("/assets/", http.StripPrefix("/assets", apiutil.MethodMux(map[string]http.Handler{ "GET": a.getPostAssetHandler(), - "POST": authMiddleware(auther, + "POST": authMiddleware(a.auther, formMiddleware(a.postPostAssetHandler()), ), - "DELETE": authMiddleware(auther, + "DELETE": authMiddleware(a.auther, formMiddleware(a.deletePostAssetHandler()), ), }), diff --git a/srv/src/http/auth.go b/srv/src/http/auth.go index cd247a3..9527cc8 100644 --- a/srv/src/http/auth.go +++ b/srv/src/http/auth.go @@ -1,7 +1,9 @@ package http import ( + "context" "net/http" + "time" "github.com/mediocregopher/blog.mediocregopher.com/srv/http/apiutil" "golang.org/x/crypto/bcrypt" @@ -19,21 +21,37 @@ func NewPasswordHash(plaintext string) string { // Auther determines who can do what. type Auther interface { - Allowed(username, password string) bool + Allowed(ctx context.Context, username, password string) bool + Close() error } type auther struct { - users map[string]string + users map[string]string + ticker *time.Ticker } // NewAuther initializes and returns an Auther will which allow the given // username and password hash combinations. Password hashes must have been // created using NewPasswordHash. -func NewAuther(users map[string]string) Auther { - return &auther{users: users} +func NewAuther(users map[string]string, ratelimit time.Duration) Auther { + return &auther{ + users: users, + ticker: time.NewTicker(ratelimit), + } +} + +func (a *auther) Close() error { + a.ticker.Stop() + return nil } -func (a *auther) Allowed(username, password string) bool { +func (a *auther) Allowed(ctx context.Context, username, password string) bool { + + select { + case <-ctx.Done(): + return false + case <-a.ticker.C: + } hashedPassword, ok := a.users[username] if !ok { @@ -64,7 +82,7 @@ func authMiddleware(auther Auther, h http.Handler) http.Handler { return } - if !auther.Allowed(username, password) { + if !auther.Allowed(r.Context(), username, password) { respondUnauthorized(rw, r) return } diff --git a/srv/src/http/auth_test.go b/srv/src/http/auth_test.go index 2a1e6e9..9e2d440 100644 --- a/srv/src/http/auth_test.go +++ b/srv/src/http/auth_test.go @@ -1,21 +1,24 @@ package http import ( + "context" "testing" + "time" "github.com/stretchr/testify/assert" ) func TestAuther(t *testing.T) { + ctx := context.Background() password := "foo" hashedPassword := NewPasswordHash(password) auther := NewAuther(map[string]string{ "FOO": hashedPassword, - }) + }, 1*time.Millisecond) - assert.False(t, auther.Allowed("BAR", password)) - assert.False(t, auther.Allowed("FOO", "bar")) - assert.True(t, auther.Allowed("FOO", password)) + assert.False(t, auther.Allowed(ctx, "BAR", password)) + assert.False(t, auther.Allowed(ctx, "FOO", "bar")) + assert.True(t, auther.Allowed(ctx, "FOO", password)) } diff --git a/srv/src/http/posts.go b/srv/src/http/posts.go index 0aea3e3..816e361 100644 --- a/srv/src/http/posts.go +++ b/srv/src/http/posts.go @@ -197,7 +197,7 @@ func (a *api) editPostHandler() http.Handler { }) } -func postFromPostReq(r *http.Request) post.Post { +func postFromPostReq(r *http.Request) (post.Post, error) { p := post.Post{ ID: r.PostFormValue("id"), @@ -207,18 +207,30 @@ func postFromPostReq(r *http.Request) post.Post { Series: r.PostFormValue("series"), } - p.Body = strings.TrimSpace(r.PostFormValue("body")) // textareas encode newlines as CRLF for historical reasons p.Body = strings.ReplaceAll(p.Body, "\r\n", "\n") + p.Body = strings.TrimSpace(r.PostFormValue("body")) + + if p.ID == "" || + p.Title == "" || + p.Description == "" || + p.Body == "" || + len(p.Tags) == 0 { + return post.Post{}, errors.New("ID, Title, Description, Tags, and Body are all required") + } - return p + return p, nil } func (a *api) postPostHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - p := postFromPostReq(r) + p, err := postFromPostReq(r) + if err != nil { + apiutil.BadRequest(rw, r, err) + return + } if err := a.params.PostStore.Set(p, time.Now()); err != nil { apiutil.InternalServerError( @@ -267,8 +279,14 @@ func (a *api) previewPostHandler() http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + p, err := postFromPostReq(r) + if err != nil { + apiutil.BadRequest(rw, r, err) + return + } + storedPost := post.StoredPost{ - Post: postFromPostReq(r), + Post: p, PublishedAt: time.Now(), } |