feat: rate limiter, users table, validation, servers split up, graceful quit, user functions
This commit is contained in:
@@ -71,3 +71,8 @@ func (app *application) editConflictResponse(w http.ResponseWriter, r *http.Requ
|
||||
message := "unable to update the record due to an edit conflict, please try again"
|
||||
app.errorResponse(w, r, http.StatusConflict, message)
|
||||
}
|
||||
|
||||
func (app *application) rateLimitExceededResponse(w http.ResponseWriter, r *http.Request) {
|
||||
message := "rate limit exceeded"
|
||||
app.errorResponse(w, r, http.StatusTooManyRequests, message)
|
||||
}
|
||||
|
||||
@@ -4,9 +4,6 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -34,6 +31,13 @@ type config struct {
|
||||
maxOpenConns int
|
||||
maxIdleConns int
|
||||
maxIdleTime string
|
||||
} // Add a new limiter struct containing fields for the requests-per-second and burst
|
||||
// values, and a boolean field which we can use to enable/disable rate limiting
|
||||
// altogether.
|
||||
limiter struct {
|
||||
rps float64
|
||||
burst int
|
||||
enabled bool
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,6 +74,12 @@ func main() {
|
||||
flag.IntVar(&cfg.db.maxIdleConns, "db-max-idle-conns", 25, "PostgreSQL max idle connections")
|
||||
flag.StringVar(&cfg.db.maxIdleTime, "db-max-idle-time", "15m", "PostgreSQL max connection idle time")
|
||||
|
||||
// Create command line flags to read the setting values into the config struct.
|
||||
// Notice that we use true as the default for the 'enabled' setting?
|
||||
flag.Float64Var(&cfg.limiter.rps, "limiter-rps", 2, "Rate limiter maximum requests per second")
|
||||
flag.IntVar(&cfg.limiter.burst, "limiter-burst", 4, "Rate limiter maximum burst")
|
||||
flag.BoolVar(&cfg.limiter.enabled, "limiter-enabled", true, "Enable rate limiter")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
// Call the openDB() helper function (see below) to create the connection pool,
|
||||
@@ -97,22 +107,10 @@ func main() {
|
||||
models: data.NewModels(db),
|
||||
}
|
||||
|
||||
// Use the httprouter instance returned by app.routes() as the server handler.
|
||||
srv := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.port),
|
||||
Handler: app.routes(),
|
||||
IdleTimeout: time.Minute,
|
||||
ErrorLog: log.New(logger, "", 0),
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
err = app.serve()
|
||||
if err != nil {
|
||||
logger.PrintFatal(err, nil)
|
||||
}
|
||||
|
||||
logger.PrintInfo("starting server", map[string]string{
|
||||
"addr": srv.Addr,
|
||||
"env": cfg.env,
|
||||
})
|
||||
err = srv.ListenAndServe()
|
||||
logger.PrintFatal(err, nil)
|
||||
}
|
||||
|
||||
func openDB(cfg config) (*sql.DB, error) {
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func (app *application) recoverPanic(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create a deferred function (which will always be run in the event of a panic
|
||||
// as Go unwinds the stack).
|
||||
defer func() {
|
||||
// Use the builtin recover function to check if there has been a panic or
|
||||
// not.
|
||||
if err := recover(); err != nil {
|
||||
// If there was a panic, set a "Connection: close" header on the
|
||||
// response. This acts as a trigger to make Go's HTTP server
|
||||
// automatically close the current connection after a response has been
|
||||
// sent.
|
||||
w.Header().Set("Connection", "close")
|
||||
// The value returned by recover() has the type interface{}, so we use
|
||||
// fmt.Errorf() to normalize it into an error and call our
|
||||
// serverErrorResponse() helper. In turn, this will log the error using
|
||||
// our custom Logger type at the ERROR level and send the client a 500
|
||||
// Internal Server Error response.
|
||||
app.serverErrorResponse(w, r, fmt.Errorf("%s", err))
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (app *application) rateLimit(next http.Handler) http.Handler {
|
||||
// Define a client struct to hold the rate limiter and last seen time for each
|
||||
// client.
|
||||
type client struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
// Update the map so the values are pointers to a client struct.
|
||||
clients = make(map[string]*client)
|
||||
)
|
||||
|
||||
// Launch a background goroutine which removes old entries from the
|
||||
// clients map once
|
||||
// every minute.
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Minute)
|
||||
|
||||
// Lock the mutex to prevent any rate limiter checks from happening while
|
||||
// the cleanup is taking place.
|
||||
mu.Lock()
|
||||
|
||||
// Loop through all clients. If they haven't been seen within the last three
|
||||
// minutes, delete the corresponding entry from the map.
|
||||
for ip, client := range clients {
|
||||
if time.Since(client.lastSeen) > 3*time.Minute {
|
||||
delete(clients, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// Importantly, unlock the mutex when the cleanup is complete.
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only carry out the check if rate limiting is enabled.
|
||||
if app.config.limiter.enabled {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
app.serverErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
|
||||
if _, found := clients[ip]; !found {
|
||||
clients[ip] = &client{
|
||||
// Use the requests-per-second and burst values from the config
|
||||
// struct.
|
||||
limiter: rate.NewLimiter(rate.Limit(app.config.limiter.rps), app.config.limiter.burst),
|
||||
}
|
||||
}
|
||||
|
||||
clients[ip].lastSeen = time.Now()
|
||||
|
||||
if !clients[ip].limiter.Allow() {
|
||||
mu.Unlock()
|
||||
app.rateLimitExceededResponse(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (app *application) recoverPanic(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create a deferred function (which will always be run in the event of a panic
|
||||
// as Go unwinds the stack).
|
||||
defer func() {
|
||||
// Use the builtin recover function to check if there has been a panic or
|
||||
// not.
|
||||
if err := recover(); err != nil {
|
||||
// If there was a panic, set a "Connection: close" header on the
|
||||
// response. This acts as a trigger to make Go's HTTP server
|
||||
// automatically close the current connection after a response has been
|
||||
// sent.
|
||||
w.Header().Set("Connection", "close")
|
||||
// The value returned by recover() has the type interface{}, so we use
|
||||
// fmt.Errorf() to normalize it into an error and call our
|
||||
// serverErrorResponse() helper. In turn, this will log the error using
|
||||
// our custom Logger type at the ERROR level and send the client a 500
|
||||
// Internal Server Error response.
|
||||
app.serverErrorResponse(w, r, fmt.Errorf("%s", err))
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -24,5 +24,6 @@ func (app *application) routes() http.Handler {
|
||||
router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.updateMovieHandler)
|
||||
router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.deleteMovieHandler)
|
||||
|
||||
return app.recoverPanic(router)
|
||||
// Wrap the router with the rateLimit() middleware.
|
||||
return app.recoverPanic(app.rateLimit(router))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (app *application) serve() error {
|
||||
srv := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", app.config.port),
|
||||
Handler: app.routes(),
|
||||
IdleTimeout: time.Minute,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Create a shutdownError channel. We will use this to receive any errors returned
|
||||
// by the graceful Shutdown() function.
|
||||
shutdownError := make(chan error)
|
||||
|
||||
go func() {
|
||||
quit := make(chan os.Signal, 1)
|
||||
// Intercept the signals, as before.
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
s := <-quit
|
||||
|
||||
// Update the log entry to say "shutting down server" instead of "caught signal".
|
||||
app.logger.PrintInfo("shutting down server", map[string]string{
|
||||
"signal": s.String(),
|
||||
})
|
||||
|
||||
// Create a context with a 5-second timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Call Shutdown() on our server, passing in the context we just made.
|
||||
// Shutdown() will return nil if the graceful shutdown was successful, or an
|
||||
// error (which may happen because of a problem closing the listeners, or
|
||||
// because the shutdown didn't complete before the 5-second context deadline is
|
||||
// hit). We relay this return value to the shutdownError channel.
|
||||
shutdownError <- srv.Shutdown(ctx)
|
||||
}()
|
||||
|
||||
app.logger.PrintInfo("starting server", map[string]string{
|
||||
"addr": srv.Addr,
|
||||
"env": app.config.env,
|
||||
})
|
||||
|
||||
// Calling Shutdown() on our server will cause ListenAndServe() to immediately
|
||||
// return a http.ErrServerClosed error. So if we see this error, it is actually a
|
||||
// good thing and an indication that the graceful shutdown has started. So we check
|
||||
// specifically for this, only returning the error if it is NOT http.ErrServerClosed.
|
||||
err := srv.ListenAndServe()
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Otherwise, we wait to receive the return value from Shutdown() on the
|
||||
// shutdownError channel. If return value is an error, we know that there was a
|
||||
// problem with the graceful shutdown and we return the error.
|
||||
err = <-shutdownError
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// At this point we know that the graceful shutdown completed successfully and we
|
||||
// log a "stopped server" message.
|
||||
app.logger.PrintInfo("stopped server", map[string]string{
|
||||
"addr": srv.Addr,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user