feat: rate limiter, users table, validation, servers split up, graceful quit, user functions
This commit is contained in:
@@ -71,3 +71,8 @@ func (app *application) editConflictResponse(w http.ResponseWriter, r *http.Requ
|
|||||||
message := "unable to update the record due to an edit conflict, please try again"
|
message := "unable to update the record due to an edit conflict, please try again"
|
||||||
app.errorResponse(w, r, http.StatusConflict, message)
|
app.errorResponse(w, r, http.StatusConflict, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (app *application) rateLimitExceededResponse(w http.ResponseWriter, r *http.Request) {
|
||||||
|
message := "rate limit exceeded"
|
||||||
|
app.errorResponse(w, r, http.StatusTooManyRequests, message)
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,9 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -34,6 +31,13 @@ type config struct {
|
|||||||
maxOpenConns int
|
maxOpenConns int
|
||||||
maxIdleConns int
|
maxIdleConns int
|
||||||
maxIdleTime string
|
maxIdleTime string
|
||||||
|
} // Add a new limiter struct containing fields for the requests-per-second and burst
|
||||||
|
// values, and a boolean field which we can use to enable/disable rate limiting
|
||||||
|
// altogether.
|
||||||
|
limiter struct {
|
||||||
|
rps float64
|
||||||
|
burst int
|
||||||
|
enabled bool
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,6 +74,12 @@ func main() {
|
|||||||
flag.IntVar(&cfg.db.maxIdleConns, "db-max-idle-conns", 25, "PostgreSQL max idle connections")
|
flag.IntVar(&cfg.db.maxIdleConns, "db-max-idle-conns", 25, "PostgreSQL max idle connections")
|
||||||
flag.StringVar(&cfg.db.maxIdleTime, "db-max-idle-time", "15m", "PostgreSQL max connection idle time")
|
flag.StringVar(&cfg.db.maxIdleTime, "db-max-idle-time", "15m", "PostgreSQL max connection idle time")
|
||||||
|
|
||||||
|
// Create command line flags to read the setting values into the config struct.
|
||||||
|
// Notice that we use true as the default for the 'enabled' setting?
|
||||||
|
flag.Float64Var(&cfg.limiter.rps, "limiter-rps", 2, "Rate limiter maximum requests per second")
|
||||||
|
flag.IntVar(&cfg.limiter.burst, "limiter-burst", 4, "Rate limiter maximum burst")
|
||||||
|
flag.BoolVar(&cfg.limiter.enabled, "limiter-enabled", true, "Enable rate limiter")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
// Call the openDB() helper function (see below) to create the connection pool,
|
// Call the openDB() helper function (see below) to create the connection pool,
|
||||||
@@ -97,23 +107,11 @@ func main() {
|
|||||||
models: data.NewModels(db),
|
models: data.NewModels(db),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use the httprouter instance returned by app.routes() as the server handler.
|
err = app.serve()
|
||||||
srv := &http.Server{
|
if err != nil {
|
||||||
Addr: fmt.Sprintf(":%d", cfg.port),
|
|
||||||
Handler: app.routes(),
|
|
||||||
IdleTimeout: time.Minute,
|
|
||||||
ErrorLog: log.New(logger, "", 0),
|
|
||||||
ReadTimeout: 10 * time.Second,
|
|
||||||
WriteTimeout: 30 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.PrintInfo("starting server", map[string]string{
|
|
||||||
"addr": srv.Addr,
|
|
||||||
"env": cfg.env,
|
|
||||||
})
|
|
||||||
err = srv.ListenAndServe()
|
|
||||||
logger.PrintFatal(err, nil)
|
logger.PrintFatal(err, nil)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func openDB(cfg config) (*sql.DB, error) {
|
func openDB(cfg config) (*sql.DB, error) {
|
||||||
db, err := sql.Open("postgres", cfg.db.dsn)
|
db, err := sql.Open("postgres", cfg.db.dsn)
|
||||||
|
|||||||
@@ -0,0 +1,109 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (app *application) recoverPanic(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Create a deferred function (which will always be run in the event of a panic
|
||||||
|
// as Go unwinds the stack).
|
||||||
|
defer func() {
|
||||||
|
// Use the builtin recover function to check if there has been a panic or
|
||||||
|
// not.
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
// If there was a panic, set a "Connection: close" header on the
|
||||||
|
// response. This acts as a trigger to make Go's HTTP server
|
||||||
|
// automatically close the current connection after a response has been
|
||||||
|
// sent.
|
||||||
|
w.Header().Set("Connection", "close")
|
||||||
|
// The value returned by recover() has the type interface{}, so we use
|
||||||
|
// fmt.Errorf() to normalize it into an error and call our
|
||||||
|
// serverErrorResponse() helper. In turn, this will log the error using
|
||||||
|
// our custom Logger type at the ERROR level and send the client a 500
|
||||||
|
// Internal Server Error response.
|
||||||
|
app.serverErrorResponse(w, r, fmt.Errorf("%s", err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (app *application) rateLimit(next http.Handler) http.Handler {
|
||||||
|
// Define a client struct to hold the rate limiter and last seen time for each
|
||||||
|
// client.
|
||||||
|
type client struct {
|
||||||
|
limiter *rate.Limiter
|
||||||
|
lastSeen time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
// Update the map so the values are pointers to a client struct.
|
||||||
|
clients = make(map[string]*client)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Launch a background goroutine which removes old entries from the
|
||||||
|
// clients map once
|
||||||
|
// every minute.
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Minute)
|
||||||
|
|
||||||
|
// Lock the mutex to prevent any rate limiter checks from happening while
|
||||||
|
// the cleanup is taking place.
|
||||||
|
mu.Lock()
|
||||||
|
|
||||||
|
// Loop through all clients. If they haven't been seen within the last three
|
||||||
|
// minutes, delete the corresponding entry from the map.
|
||||||
|
for ip, client := range clients {
|
||||||
|
if time.Since(client.lastSeen) > 3*time.Minute {
|
||||||
|
delete(clients, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Importantly, unlock the mutex when the cleanup is complete.
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Only carry out the check if rate limiting is enabled.
|
||||||
|
if app.config.limiter.enabled {
|
||||||
|
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
app.serverErrorResponse(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
|
||||||
|
if _, found := clients[ip]; !found {
|
||||||
|
clients[ip] = &client{
|
||||||
|
// Use the requests-per-second and burst values from the config
|
||||||
|
// struct.
|
||||||
|
limiter: rate.NewLimiter(rate.Limit(app.config.limiter.rps), app.config.limiter.burst),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
clients[ip].lastSeen = time.Now()
|
||||||
|
|
||||||
|
if !clients[ip].limiter.Allow() {
|
||||||
|
mu.Unlock()
|
||||||
|
app.rateLimitExceededResponse(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (app *application) recoverPanic(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Create a deferred function (which will always be run in the event of a panic
|
|
||||||
// as Go unwinds the stack).
|
|
||||||
defer func() {
|
|
||||||
// Use the builtin recover function to check if there has been a panic or
|
|
||||||
// not.
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
// If there was a panic, set a "Connection: close" header on the
|
|
||||||
// response. This acts as a trigger to make Go's HTTP server
|
|
||||||
// automatically close the current connection after a response has been
|
|
||||||
// sent.
|
|
||||||
w.Header().Set("Connection", "close")
|
|
||||||
// The value returned by recover() has the type interface{}, so we use
|
|
||||||
// fmt.Errorf() to normalize it into an error and call our
|
|
||||||
// serverErrorResponse() helper. In turn, this will log the error using
|
|
||||||
// our custom Logger type at the ERROR level and send the client a 500
|
|
||||||
// Internal Server Error response.
|
|
||||||
app.serverErrorResponse(w, r, fmt.Errorf("%s", err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -24,5 +24,6 @@ func (app *application) routes() http.Handler {
|
|||||||
router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.updateMovieHandler)
|
router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.updateMovieHandler)
|
||||||
router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.deleteMovieHandler)
|
router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.deleteMovieHandler)
|
||||||
|
|
||||||
return app.recoverPanic(router)
|
// Wrap the router with the rateLimit() middleware.
|
||||||
|
return app.recoverPanic(app.rateLimit(router))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,79 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (app *application) serve() error {
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: fmt.Sprintf(":%d", app.config.port),
|
||||||
|
Handler: app.routes(),
|
||||||
|
IdleTimeout: time.Minute,
|
||||||
|
ReadTimeout: 10 * time.Second,
|
||||||
|
WriteTimeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a shutdownError channel. We will use this to receive any errors returned
|
||||||
|
// by the graceful Shutdown() function.
|
||||||
|
shutdownError := make(chan error)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
// Intercept the signals, as before.
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
s := <-quit
|
||||||
|
|
||||||
|
// Update the log entry to say "shutting down server" instead of "caught signal".
|
||||||
|
app.logger.PrintInfo("shutting down server", map[string]string{
|
||||||
|
"signal": s.String(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a context with a 5-second timeout.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Call Shutdown() on our server, passing in the context we just made.
|
||||||
|
// Shutdown() will return nil if the graceful shutdown was successful, or an
|
||||||
|
// error (which may happen because of a problem closing the listeners, or
|
||||||
|
// because the shutdown didn't complete before the 5-second context deadline is
|
||||||
|
// hit). We relay this return value to the shutdownError channel.
|
||||||
|
shutdownError <- srv.Shutdown(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
app.logger.PrintInfo("starting server", map[string]string{
|
||||||
|
"addr": srv.Addr,
|
||||||
|
"env": app.config.env,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Calling Shutdown() on our server will cause ListenAndServe() to immediately
|
||||||
|
// return a http.ErrServerClosed error. So if we see this error, it is actually a
|
||||||
|
// good thing and an indication that the graceful shutdown has started. So we check
|
||||||
|
// specifically for this, only returning the error if it is NOT http.ErrServerClosed.
|
||||||
|
err := srv.ListenAndServe()
|
||||||
|
if !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, we wait to receive the return value from Shutdown() on the
|
||||||
|
// shutdownError channel. If return value is an error, we know that there was a
|
||||||
|
// problem with the graceful shutdown and we return the error.
|
||||||
|
err = <-shutdownError
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point we know that the graceful shutdown completed successfully and we
|
||||||
|
// log a "stopped server" message.
|
||||||
|
app.logger.PrintInfo("stopped server", map[string]string{
|
||||||
|
"addr": srv.Addr,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -3,7 +3,9 @@ module greenlight.debuggingjon.dev
|
|||||||
go 1.25.0
|
go 1.25.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/joho/godotenv v1.5.1 // indirect
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/julienschmidt/httprouter v1.3.0 // indirect
|
github.com/julienschmidt/httprouter v1.3.0
|
||||||
github.com/lib/pq v1.10.0 // indirect
|
github.com/lib/pq v1.10.0
|
||||||
|
golang.org/x/crypto v0.49.0
|
||||||
|
golang.org/x/time v0.15.0
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,3 +4,7 @@ github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4d
|
|||||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||||
github.com/lib/pq v1.10.0 h1:Zx5DJFEYQXio93kgXnQ09fXNiUKsqv4OUEu2UtGcB1E=
|
github.com/lib/pq v1.10.0 h1:Zx5DJFEYQXio93kgXnQ09fXNiUKsqv4OUEu2UtGcB1E=
|
||||||
github.com/lib/pq v1.10.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
github.com/lib/pq v1.10.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
|
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||||
|
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||||
|
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||||
|
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ var (
|
|||||||
// like a UserModel and PermissionModel, as our build progresses.
|
// like a UserModel and PermissionModel, as our build progresses.
|
||||||
type Models struct {
|
type Models struct {
|
||||||
Movies MovieModel
|
Movies MovieModel
|
||||||
|
Users UserModel
|
||||||
}
|
}
|
||||||
|
|
||||||
// For ease of use, we also add a New()
|
// For ease of use, we also add a New()
|
||||||
@@ -24,5 +25,6 @@ type Models struct {
|
|||||||
func NewModels(db *sql.DB) Models {
|
func NewModels(db *sql.DB) Models {
|
||||||
return Models{
|
return Models{
|
||||||
Movies: MovieModel{DB: db},
|
Movies: MovieModel{DB: db},
|
||||||
|
Users: UserModel{DB: db},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,215 @@
|
|||||||
|
package data
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"greenlight.debuggingjon.dev/internal/validator"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Define a custom ErrDuplicateEmail error.
|
||||||
|
var (
|
||||||
|
ErrDuplicateEmail = errors.New("duplicate email")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Define a User struct to represent an individual user. Importantly, notice how we are
|
||||||
|
// using the json:"-" struct tag to prevent the Password and Version fields appearing in
|
||||||
|
// any output when we encode it to JSON. Also notice that the Password field uses the
|
||||||
|
// custom password type defined below.
|
||||||
|
type User struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Password password `json:"-"`
|
||||||
|
Activated bool `json:"activated"`
|
||||||
|
Version int `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a custom password type which is a struct containing the plaintext and hashed
|
||||||
|
// versions of the password for a user. The plaintext field is a *pointer* to a string,
|
||||||
|
// so that we're able to distinguish between a plaintext password not being present in
|
||||||
|
// the struct at all, versus a plaintext password which is the empty string "".
|
||||||
|
type password struct {
|
||||||
|
plaintext *string
|
||||||
|
hash []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// The Set() method calculates the bcrypt hash of a plaintext password, and stores both
|
||||||
|
// the hash and the plaintext versions in the struct.
|
||||||
|
func (p *password) Set(plaintextPassword string) error {
|
||||||
|
hash, err := bcrypt.GenerateFromPassword([]byte(plaintextPassword), 12)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.plaintext = &plaintextPassword
|
||||||
|
p.hash = hash
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The Matches() method checks whether the provided plaintext password matches the
|
||||||
|
// hashed password stored in the struct, returning true if it matches and false
|
||||||
|
// otherwise.
|
||||||
|
func (p *password) Matches(plaintextPassword string) (bool, error) {
|
||||||
|
err := bcrypt.CompareHashAndPassword(p.hash, []byte(plaintextPassword))
|
||||||
|
if err != nil {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
|
||||||
|
return false, nil
|
||||||
|
default:
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateEmail(v *validator.Validator, email string) {
|
||||||
|
v.Check(email != "", "email", "must be provided")
|
||||||
|
v.Check(validator.Matches(email, validator.EmailRX), "email", "must be a valid email address")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidatePasswordPlaintext(v *validator.Validator, password string) {
|
||||||
|
v.Check(password != "", "password", "must be provided")
|
||||||
|
v.Check(len(password) >= 8, "password", "must be at least 8 bytes long")
|
||||||
|
v.Check(len(password) <= 72, "password", "must not be more than 72 bytes long")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateUser(v *validator.Validator, user *User) {
|
||||||
|
v.Check(user.Name != "", "name", "must be provided")
|
||||||
|
v.Check(len(user.Name) <= 500, "name", "must not be more than 500 bytes long")
|
||||||
|
|
||||||
|
// Call the standalone ValidateEmail() helper.
|
||||||
|
ValidateEmail(v, user.Email)
|
||||||
|
|
||||||
|
// If the plaintext password is not nil, call the standalone
|
||||||
|
// ValidatePasswordPlaintext() helper.
|
||||||
|
if user.Password.plaintext != nil {
|
||||||
|
ValidatePasswordPlaintext(v, *user.Password.plaintext)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the password hash is ever nil, this will be due to a logic error in our
|
||||||
|
// codebase (probably because we forgot to set a password for the user). It's a
|
||||||
|
// useful sanity check to include here, but it's not a problem with the data
|
||||||
|
// provided by the client. So rather than adding an error to the validation map we
|
||||||
|
// raise a panic instead.
|
||||||
|
if user.Password.hash == nil {
|
||||||
|
panic("missing password hash for user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a UserModel struct which wraps the connection pool.
|
||||||
|
type UserModel struct {
|
||||||
|
DB *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert a new record in the database for the user. Note that the id, created_at and
|
||||||
|
// version fields are all automatically generated by our database, so we use the
|
||||||
|
// RETURNING clause to read them into the User struct after the insert, in the same way
|
||||||
|
// that we did when creating a movie.
|
||||||
|
func (m UserModel) Insert(user *User) error {
|
||||||
|
query := `
|
||||||
|
INSERT INTO users (name, email, password_hash, activated)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
RETURNING id, created_at, version`
|
||||||
|
|
||||||
|
args := []any{user.Name, user.Email, user.Password.hash, user.Activated}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// If the table already contains a record with this email address, then when we try
|
||||||
|
// to perform the insert there will be a violation of the UNIQUE "users_email_key"
|
||||||
|
// constraint that we set up in the previous chapter. We check for this error
|
||||||
|
// specifically, and return custom ErrDuplicateEmail error instead.
|
||||||
|
err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.CreatedAt, &user.Version)
|
||||||
|
if err != nil {
|
||||||
|
switch {
|
||||||
|
case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
|
||||||
|
return ErrDuplicateEmail
|
||||||
|
default:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve the User details from the database based on the user's email address.
|
||||||
|
// Because we have a UNIQUE constraint on the email column, this SQL query will only
|
||||||
|
// return one record (or none at all, in which case we return a ErrRecordNotFound error).
|
||||||
|
func (m UserModel) GetByEmail(email string) (*User, error) {
|
||||||
|
query := `
|
||||||
|
SELECT id, created_at, name, email, password_hash, activated, version
|
||||||
|
FROM users
|
||||||
|
WHERE email = $1`
|
||||||
|
|
||||||
|
var user User
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := m.DB.QueryRowContext(ctx, query, email).Scan(
|
||||||
|
&user.ID,
|
||||||
|
&user.CreatedAt,
|
||||||
|
&user.Name,
|
||||||
|
&user.Email,
|
||||||
|
&user.Password.hash,
|
||||||
|
&user.Activated,
|
||||||
|
&user.Version,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, sql.ErrNoRows):
|
||||||
|
return nil, ErrRecordNotFound
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the details for a specific user. Notice that we check against the version
|
||||||
|
// field to help prevent any race conditions during the request cycle, just like we did
|
||||||
|
// when updating a movie. And we also check for a violation of the "users_email_key"
|
||||||
|
// constraint when performing the update, just like we did when inserting the user
|
||||||
|
// record originally.
|
||||||
|
func (m UserModel) Update(user *User) error {
|
||||||
|
query := `
|
||||||
|
UPDATE users
|
||||||
|
SET name = $1, email = $2, password_hash = $3, activated = $4, version = version + 1
|
||||||
|
WHERE id = $5 AND version = $6
|
||||||
|
RETURNING version`
|
||||||
|
|
||||||
|
args := []any{
|
||||||
|
user.Name,
|
||||||
|
user.Email,
|
||||||
|
user.Password.hash,
|
||||||
|
user.Activated,
|
||||||
|
user.ID,
|
||||||
|
user.Version,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version)
|
||||||
|
if err != nil {
|
||||||
|
switch {
|
||||||
|
case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
|
||||||
|
return ErrDuplicateEmail
|
||||||
|
case errors.Is(err, sql.ErrNoRows):
|
||||||
|
return ErrEditConflict
|
||||||
|
default:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
DROP TABLE IF EXISTS users;
|
||||||
|
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
|
id bigserial PRIMARY KEY,
|
||||||
|
created_at timestamp(0) with time zone NOT NULL DEFAULT NOW(),
|
||||||
|
name text NOT NULL,
|
||||||
|
email citext UNIQUE NOT NULL,
|
||||||
|
password_hash bytea NOT NULL,
|
||||||
|
activated bool NOT NULL,
|
||||||
|
version integer NOT NULL DEFAULT 1
|
||||||
|
);
|
||||||
|
|
||||||
Reference in New Issue
Block a user