diff --git a/projects/greenlight/cmd/api/errors.go b/projects/greenlight/cmd/api/errors.go index 72ae823..b49a42b 100644 --- a/projects/greenlight/cmd/api/errors.go +++ b/projects/greenlight/cmd/api/errors.go @@ -71,3 +71,8 @@ func (app *application) editConflictResponse(w http.ResponseWriter, r *http.Requ message := "unable to update the record due to an edit conflict, please try again" app.errorResponse(w, r, http.StatusConflict, message) } + +func (app *application) rateLimitExceededResponse(w http.ResponseWriter, r *http.Request) { + message := "rate limit exceeded" + app.errorResponse(w, r, http.StatusTooManyRequests, message) +} diff --git a/projects/greenlight/cmd/api/main.go b/projects/greenlight/cmd/api/main.go index 4208d51..9cb6763 100644 --- a/projects/greenlight/cmd/api/main.go +++ b/projects/greenlight/cmd/api/main.go @@ -4,9 +4,6 @@ import ( "context" "database/sql" "flag" - "fmt" - "log" - "net/http" "os" "time" @@ -34,6 +31,13 @@ type config struct { maxOpenConns int maxIdleConns int maxIdleTime string + } // Add a new limiter struct containing fields for the requests-per-second and burst + // values, and a boolean field which we can use to enable/disable rate limiting + // altogether. + limiter struct { + rps float64 + burst int + enabled bool } } @@ -70,6 +74,12 @@ func main() { flag.IntVar(&cfg.db.maxIdleConns, "db-max-idle-conns", 25, "PostgreSQL max idle connections") flag.StringVar(&cfg.db.maxIdleTime, "db-max-idle-time", "15m", "PostgreSQL max connection idle time") + // Create command line flags to read the setting values into the config struct. + // Notice that we use true as the default for the 'enabled' setting? + flag.Float64Var(&cfg.limiter.rps, "limiter-rps", 2, "Rate limiter maximum requests per second") + flag.IntVar(&cfg.limiter.burst, "limiter-burst", 4, "Rate limiter maximum burst") + flag.BoolVar(&cfg.limiter.enabled, "limiter-enabled", true, "Enable rate limiter") + flag.Parse() // Call the openDB() helper function (see below) to create the connection pool, @@ -97,22 +107,10 @@ func main() { models: data.NewModels(db), } - // Use the httprouter instance returned by app.routes() as the server handler. - srv := &http.Server{ - Addr: fmt.Sprintf(":%d", cfg.port), - Handler: app.routes(), - IdleTimeout: time.Minute, - ErrorLog: log.New(logger, "", 0), - ReadTimeout: 10 * time.Second, - WriteTimeout: 30 * time.Second, + err = app.serve() + if err != nil { + logger.PrintFatal(err, nil) } - - logger.PrintInfo("starting server", map[string]string{ - "addr": srv.Addr, - "env": cfg.env, - }) - err = srv.ListenAndServe() - logger.PrintFatal(err, nil) } func openDB(cfg config) (*sql.DB, error) { diff --git a/projects/greenlight/cmd/api/middleware.go b/projects/greenlight/cmd/api/middleware.go new file mode 100644 index 0000000..ef30092 --- /dev/null +++ b/projects/greenlight/cmd/api/middleware.go @@ -0,0 +1,109 @@ +package main + +import ( + "fmt" + "net" + "net/http" + "sync" + "time" + + "golang.org/x/time/rate" +) + +func (app *application) recoverPanic(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Create a deferred function (which will always be run in the event of a panic + // as Go unwinds the stack). + defer func() { + // Use the builtin recover function to check if there has been a panic or + // not. + if err := recover(); err != nil { + // If there was a panic, set a "Connection: close" header on the + // response. This acts as a trigger to make Go's HTTP server + // automatically close the current connection after a response has been + // sent. + w.Header().Set("Connection", "close") + // The value returned by recover() has the type interface{}, so we use + // fmt.Errorf() to normalize it into an error and call our + // serverErrorResponse() helper. In turn, this will log the error using + // our custom Logger type at the ERROR level and send the client a 500 + // Internal Server Error response. + app.serverErrorResponse(w, r, fmt.Errorf("%s", err)) + } + }() + + next.ServeHTTP(w, r) + }) +} + +func (app *application) rateLimit(next http.Handler) http.Handler { + // Define a client struct to hold the rate limiter and last seen time for each + // client. + type client struct { + limiter *rate.Limiter + lastSeen time.Time + } + + var ( + mu sync.Mutex + // Update the map so the values are pointers to a client struct. + clients = make(map[string]*client) + ) + + // Launch a background goroutine which removes old entries from the + // clients map once + // every minute. + go func() { + for { + time.Sleep(time.Minute) + + // Lock the mutex to prevent any rate limiter checks from happening while + // the cleanup is taking place. + mu.Lock() + + // Loop through all clients. If they haven't been seen within the last three + // minutes, delete the corresponding entry from the map. + for ip, client := range clients { + if time.Since(client.lastSeen) > 3*time.Minute { + delete(clients, ip) + } + } + + // Importantly, unlock the mutex when the cleanup is complete. + mu.Unlock() + } + }() + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Only carry out the check if rate limiting is enabled. + if app.config.limiter.enabled { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + mu.Lock() + + if _, found := clients[ip]; !found { + clients[ip] = &client{ + // Use the requests-per-second and burst values from the config + // struct. + limiter: rate.NewLimiter(rate.Limit(app.config.limiter.rps), app.config.limiter.burst), + } + } + + clients[ip].lastSeen = time.Now() + + if !clients[ip].limiter.Allow() { + mu.Unlock() + app.rateLimitExceededResponse(w, r) + return + } + + mu.Unlock() + } + + next.ServeHTTP(w, r) + }) +} diff --git a/projects/greenlight/cmd/api/midlleware.go b/projects/greenlight/cmd/api/midlleware.go deleted file mode 100644 index 14ba557..0000000 --- a/projects/greenlight/cmd/api/midlleware.go +++ /dev/null @@ -1,32 +0,0 @@ -package main - -import ( - "fmt" - "net/http" -) - -func (app *application) recoverPanic(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Create a deferred function (which will always be run in the event of a panic - // as Go unwinds the stack). - defer func() { - // Use the builtin recover function to check if there has been a panic or - // not. - if err := recover(); err != nil { - // If there was a panic, set a "Connection: close" header on the - // response. This acts as a trigger to make Go's HTTP server - // automatically close the current connection after a response has been - // sent. - w.Header().Set("Connection", "close") - // The value returned by recover() has the type interface{}, so we use - // fmt.Errorf() to normalize it into an error and call our - // serverErrorResponse() helper. In turn, this will log the error using - // our custom Logger type at the ERROR level and send the client a 500 - // Internal Server Error response. - app.serverErrorResponse(w, r, fmt.Errorf("%s", err)) - } - }() - - next.ServeHTTP(w, r) - }) -} diff --git a/projects/greenlight/cmd/api/routes.go b/projects/greenlight/cmd/api/routes.go index 715e279..6197867 100644 --- a/projects/greenlight/cmd/api/routes.go +++ b/projects/greenlight/cmd/api/routes.go @@ -24,5 +24,6 @@ func (app *application) routes() http.Handler { router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.updateMovieHandler) router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.deleteMovieHandler) - return app.recoverPanic(router) + // Wrap the router with the rateLimit() middleware. + return app.recoverPanic(app.rateLimit(router)) } diff --git a/projects/greenlight/cmd/api/server.go b/projects/greenlight/cmd/api/server.go new file mode 100644 index 0000000..75f13f2 --- /dev/null +++ b/projects/greenlight/cmd/api/server.go @@ -0,0 +1,79 @@ +package main + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "os/signal" + "syscall" + "time" +) + +func (app *application) serve() error { + srv := &http.Server{ + Addr: fmt.Sprintf(":%d", app.config.port), + Handler: app.routes(), + IdleTimeout: time.Minute, + ReadTimeout: 10 * time.Second, + WriteTimeout: 30 * time.Second, + } + + // Create a shutdownError channel. We will use this to receive any errors returned + // by the graceful Shutdown() function. + shutdownError := make(chan error) + + go func() { + quit := make(chan os.Signal, 1) + // Intercept the signals, as before. + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + s := <-quit + + // Update the log entry to say "shutting down server" instead of "caught signal". + app.logger.PrintInfo("shutting down server", map[string]string{ + "signal": s.String(), + }) + + // Create a context with a 5-second timeout. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Call Shutdown() on our server, passing in the context we just made. + // Shutdown() will return nil if the graceful shutdown was successful, or an + // error (which may happen because of a problem closing the listeners, or + // because the shutdown didn't complete before the 5-second context deadline is + // hit). We relay this return value to the shutdownError channel. + shutdownError <- srv.Shutdown(ctx) + }() + + app.logger.PrintInfo("starting server", map[string]string{ + "addr": srv.Addr, + "env": app.config.env, + }) + + // Calling Shutdown() on our server will cause ListenAndServe() to immediately + // return a http.ErrServerClosed error. So if we see this error, it is actually a + // good thing and an indication that the graceful shutdown has started. So we check + // specifically for this, only returning the error if it is NOT http.ErrServerClosed. + err := srv.ListenAndServe() + if !errors.Is(err, http.ErrServerClosed) { + return err + } + + // Otherwise, we wait to receive the return value from Shutdown() on the + // shutdownError channel. If return value is an error, we know that there was a + // problem with the graceful shutdown and we return the error. + err = <-shutdownError + if err != nil { + return err + } + + // At this point we know that the graceful shutdown completed successfully and we + // log a "stopped server" message. + app.logger.PrintInfo("stopped server", map[string]string{ + "addr": srv.Addr, + }) + + return nil +} diff --git a/projects/greenlight/go.mod b/projects/greenlight/go.mod index 02cb525..4d2c305 100644 --- a/projects/greenlight/go.mod +++ b/projects/greenlight/go.mod @@ -3,7 +3,9 @@ module greenlight.debuggingjon.dev go 1.25.0 require ( - github.com/joho/godotenv v1.5.1 // indirect - github.com/julienschmidt/httprouter v1.3.0 // indirect - github.com/lib/pq v1.10.0 // indirect + github.com/joho/godotenv v1.5.1 + github.com/julienschmidt/httprouter v1.3.0 + github.com/lib/pq v1.10.0 + golang.org/x/crypto v0.49.0 + golang.org/x/time v0.15.0 ) diff --git a/projects/greenlight/go.sum b/projects/greenlight/go.sum index de6d4bc..9c96815 100644 --- a/projects/greenlight/go.sum +++ b/projects/greenlight/go.sum @@ -4,3 +4,7 @@ github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4d github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/lib/pq v1.10.0 h1:Zx5DJFEYQXio93kgXnQ09fXNiUKsqv4OUEu2UtGcB1E= github.com/lib/pq v1.10.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= diff --git a/projects/greenlight/internal/data/models.go b/projects/greenlight/internal/data/models.go index 13674e8..ba1a718 100644 --- a/projects/greenlight/internal/data/models.go +++ b/projects/greenlight/internal/data/models.go @@ -16,6 +16,7 @@ var ( // like a UserModel and PermissionModel, as our build progresses. type Models struct { Movies MovieModel + Users UserModel } // For ease of use, we also add a New() @@ -24,5 +25,6 @@ type Models struct { func NewModels(db *sql.DB) Models { return Models{ Movies: MovieModel{DB: db}, + Users: UserModel{DB: db}, } } diff --git a/projects/greenlight/internal/data/users.go b/projects/greenlight/internal/data/users.go new file mode 100644 index 0000000..57ac2ec --- /dev/null +++ b/projects/greenlight/internal/data/users.go @@ -0,0 +1,215 @@ +package data + +import ( + "context" + "database/sql" + "errors" + "time" + + "golang.org/x/crypto/bcrypt" + "greenlight.debuggingjon.dev/internal/validator" +) + +// Define a custom ErrDuplicateEmail error. +var ( + ErrDuplicateEmail = errors.New("duplicate email") +) + +// Define a User struct to represent an individual user. Importantly, notice how we are +// using the json:"-" struct tag to prevent the Password and Version fields appearing in +// any output when we encode it to JSON. Also notice that the Password field uses the +// custom password type defined below. +type User struct { + ID int64 `json:"id"` + CreatedAt time.Time `json:"created_at"` + Name string `json:"name"` + Email string `json:"email"` + Password password `json:"-"` + Activated bool `json:"activated"` + Version int `json:"-"` +} + +// Create a custom password type which is a struct containing the plaintext and hashed +// versions of the password for a user. The plaintext field is a *pointer* to a string, +// so that we're able to distinguish between a plaintext password not being present in +// the struct at all, versus a plaintext password which is the empty string "". +type password struct { + plaintext *string + hash []byte +} + +// The Set() method calculates the bcrypt hash of a plaintext password, and stores both +// the hash and the plaintext versions in the struct. +func (p *password) Set(plaintextPassword string) error { + hash, err := bcrypt.GenerateFromPassword([]byte(plaintextPassword), 12) + if err != nil { + return err + } + + p.plaintext = &plaintextPassword + p.hash = hash + + return nil +} + +// The Matches() method checks whether the provided plaintext password matches the +// hashed password stored in the struct, returning true if it matches and false +// otherwise. +func (p *password) Matches(plaintextPassword string) (bool, error) { + err := bcrypt.CompareHashAndPassword(p.hash, []byte(plaintextPassword)) + if err != nil { + switch { + case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword): + return false, nil + default: + return false, err + } + } + + return true, nil +} + +func ValidateEmail(v *validator.Validator, email string) { + v.Check(email != "", "email", "must be provided") + v.Check(validator.Matches(email, validator.EmailRX), "email", "must be a valid email address") +} + +func ValidatePasswordPlaintext(v *validator.Validator, password string) { + v.Check(password != "", "password", "must be provided") + v.Check(len(password) >= 8, "password", "must be at least 8 bytes long") + v.Check(len(password) <= 72, "password", "must not be more than 72 bytes long") +} + +func ValidateUser(v *validator.Validator, user *User) { + v.Check(user.Name != "", "name", "must be provided") + v.Check(len(user.Name) <= 500, "name", "must not be more than 500 bytes long") + + // Call the standalone ValidateEmail() helper. + ValidateEmail(v, user.Email) + + // If the plaintext password is not nil, call the standalone + // ValidatePasswordPlaintext() helper. + if user.Password.plaintext != nil { + ValidatePasswordPlaintext(v, *user.Password.plaintext) + } + + // If the password hash is ever nil, this will be due to a logic error in our + // codebase (probably because we forgot to set a password for the user). It's a + // useful sanity check to include here, but it's not a problem with the data + // provided by the client. So rather than adding an error to the validation map we + // raise a panic instead. + if user.Password.hash == nil { + panic("missing password hash for user") + } +} + +// Create a UserModel struct which wraps the connection pool. +type UserModel struct { + DB *sql.DB +} + +// Insert a new record in the database for the user. Note that the id, created_at and +// version fields are all automatically generated by our database, so we use the +// RETURNING clause to read them into the User struct after the insert, in the same way +// that we did when creating a movie. +func (m UserModel) Insert(user *User) error { + query := ` + INSERT INTO users (name, email, password_hash, activated) + VALUES ($1, $2, $3, $4) + RETURNING id, created_at, version` + + args := []any{user.Name, user.Email, user.Password.hash, user.Activated} + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // If the table already contains a record with this email address, then when we try + // to perform the insert there will be a violation of the UNIQUE "users_email_key" + // constraint that we set up in the previous chapter. We check for this error + // specifically, and return custom ErrDuplicateEmail error instead. + err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.CreatedAt, &user.Version) + if err != nil { + switch { + case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`: + return ErrDuplicateEmail + default: + return err + } + } + + return nil +} + +// Retrieve the User details from the database based on the user's email address. +// Because we have a UNIQUE constraint on the email column, this SQL query will only +// return one record (or none at all, in which case we return a ErrRecordNotFound error). +func (m UserModel) GetByEmail(email string) (*User, error) { + query := ` + SELECT id, created_at, name, email, password_hash, activated, version + FROM users + WHERE email = $1` + + var user User + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + err := m.DB.QueryRowContext(ctx, query, email).Scan( + &user.ID, + &user.CreatedAt, + &user.Name, + &user.Email, + &user.Password.hash, + &user.Activated, + &user.Version, + ) + if err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, ErrRecordNotFound + default: + return nil, err + } + } + + return &user, nil +} + +// Update the details for a specific user. Notice that we check against the version +// field to help prevent any race conditions during the request cycle, just like we did +// when updating a movie. And we also check for a violation of the "users_email_key" +// constraint when performing the update, just like we did when inserting the user +// record originally. +func (m UserModel) Update(user *User) error { + query := ` + UPDATE users + SET name = $1, email = $2, password_hash = $3, activated = $4, version = version + 1 + WHERE id = $5 AND version = $6 + RETURNING version` + + args := []any{ + user.Name, + user.Email, + user.Password.hash, + user.Activated, + user.ID, + user.Version, + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version) + if err != nil { + switch { + case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`: + return ErrDuplicateEmail + case errors.Is(err, sql.ErrNoRows): + return ErrEditConflict + default: + return err + } + } + + return nil +} diff --git a/projects/greenlight/migrations/000004_create_users_table.down.sql b/projects/greenlight/migrations/000004_create_users_table.down.sql new file mode 100644 index 0000000..2830ac2 --- /dev/null +++ b/projects/greenlight/migrations/000004_create_users_table.down.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS users; + diff --git a/projects/greenlight/migrations/000004_create_users_table.up.sql b/projects/greenlight/migrations/000004_create_users_table.up.sql new file mode 100644 index 0000000..8281daf --- /dev/null +++ b/projects/greenlight/migrations/000004_create_users_table.up.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS users ( + id bigserial PRIMARY KEY, + created_at timestamp(0) with time zone NOT NULL DEFAULT NOW(), + name text NOT NULL, + email citext UNIQUE NOT NULL, + password_hash bytea NOT NULL, + activated bool NOT NULL, + version integer NOT NULL DEFAULT 1 +); +