diff --git a/cmd/go-project-template/main.go b/cmd/go-project-template/main.go index d47cec1..c69306e 100644 --- a/cmd/go-project-template/main.go +++ b/cmd/go-project-template/main.go @@ -9,7 +9,7 @@ import ( "git.janky.solutions/finn/go-project-template/config" "git.janky.solutions/finn/go-project-template/db" - "git.janky.solutions/finn/go-project-template/httpserver" + "git.janky.solutions/finn/go-project-template/web" ) func main() { @@ -17,10 +17,9 @@ func main() { } func run() { - if err := config.Load(); err != nil { + if err := config.Load(context.Background()); err != nil { logrus.WithError(err).Fatal("error loading config") } - logrus.SetLevel(logrus.DebugLevel) ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) defer cancel() @@ -29,11 +28,11 @@ func run() { logrus.WithError(err).Fatal("error migrating database") } - go httpserver.ListenAndServe() + go web.ListenAndServe() <-ctx.Done() - if err := httpserver.Shutdown(ctx); err != nil { + if err := web.Shutdown(ctx); err != nil { logrus.WithError(err).Error("error shutting down web server") } } diff --git a/config/config.go b/config/config.go index c51ac05..6373902 100644 --- a/config/config.go +++ b/config/config.go @@ -1,22 +1,42 @@ package config import ( + "context" "encoding/json" + "errors" + "fmt" "os" + "time" "github.com/sirupsen/logrus" + "golang.org/x/exp/rand" ) +func init() { + rand.Seed(uint64(time.Now().UnixNano())) +} + type Config struct { Database string - HTTPBind string + Web WebConfig + OAuth2 OAuth2 + LogLevel string +} + +type WebConfig struct { + Bind string + SessionKey string + BaseURL string } var C = Config{ - HTTPBind: ":8080", + LogLevel: "INFO", + Web: WebConfig{ + Bind: ":8080", + }, } -func Load() error { +func Load(ctx context.Context) error { logrus.Info("loading config file") f, err := os.Open("go-project-template.json") @@ -30,5 +50,29 @@ func Load() error { return err } + level, err := logrus.ParseLevel(C.LogLevel) + if err != nil { + return fmt.Errorf("error parsing requested log level %s: %v", C.LogLevel, err) + } + logrus.SetLevel(level) + + if C.Web.SessionKey == "" { + logrus.Info("No session key specified. Here's a good one if you need it: ", generateSessionKey()) + return errors.New("session key may not be empty") + } + + if err := C.OAuth2.Load(ctx); err != nil { + return err + } return nil } + +var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890") + +func generateSessionKey() string { + b := make([]rune, 64) + for i := range b { + b[i] = letterRunes[rand.Intn(len(letterRunes))] + } + return string(b) +} diff --git a/config/oauth.go b/config/oauth.go new file mode 100644 index 0000000..d6fbc58 --- /dev/null +++ b/config/oauth.go @@ -0,0 +1,60 @@ +package config + +import ( + "context" + "fmt" + "net/url" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +type OAuth2 struct { + ClientID string + ClientSecret string + ProviderURL string + Scopes []string + + provider *oidc.Provider +} + +func (o *OAuth2) Load(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() + + provider, err := oidc.NewProvider(ctx, o.ProviderURL) + if err != nil { + return err + } + + hasOpenIDScope := false + for _, scope := range o.Scopes { + if scope == oidc.ScopeOpenID { + hasOpenIDScope = true + break + } + } + if !hasOpenIDScope { + o.Scopes = append(o.Scopes, oidc.ScopeOpenID) + } + + o.provider = provider + return nil +} + +func (o OAuth2) GetConfig(postAuthPath string) *oauth2.Config { + params := url.Values{} + params.Add("dest", postAuthPath) + return &oauth2.Config{ + ClientID: o.ClientID, + ClientSecret: o.ClientSecret, + RedirectURL: fmt.Sprintf("%s/auth/finish", C.Web.BaseURL), + Endpoint: o.provider.Endpoint(), + Scopes: append([]string{oidc.ScopeOpenID}, o.Scopes...), + } +} + +func (o OAuth2) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) (*oidc.UserInfo, error) { + return o.provider.UserInfo(ctx, tokenSource) +} diff --git a/db/example.sql.go b/db/example.sql.go deleted file mode 100644 index a3943c6..0000000 --- a/db/example.sql.go +++ /dev/null @@ -1,63 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.20.0 -// source: example.sql - -package db - -import ( - "context" -) - -const exampleCreate = `-- name: ExampleCreate :exec -INSERT INTO example (name) VALUES ($1) -` - -func (q *Queries) ExampleCreate(ctx context.Context, name string) error { - _, err := q.db.Exec(ctx, exampleCreate, name) - return err -} - -const exampleDelete = `-- name: ExampleDelete :exec -DELETE FROM example WHERE id = $1 -` - -func (q *Queries) ExampleDelete(ctx context.Context, id int64) error { - _, err := q.db.Exec(ctx, exampleDelete, id) - return err -} - -const exampleGetAll = `-- name: ExampleGetAll :many -SELECT id, name FROM example -` - -func (q *Queries) ExampleGetAll(ctx context.Context) ([]Example, error) { - rows, err := q.db.Query(ctx, exampleGetAll) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Example - for rows.Next() { - var i Example - if err := rows.Scan(&i.ID, &i.Name); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const exampleGetByID = `-- name: ExampleGetByID :one -SELECT id, name FROM example WHERE id = $1 -` - -func (q *Queries) ExampleGetByID(ctx context.Context, id int64) (Example, error) { - row := q.db.QueryRow(ctx, exampleGetByID, id) - var i Example - err := row.Scan(&i.ID, &i.Name) - return i, err -} diff --git a/db/migrations/001_init.sql b/db/migrations/001_init.sql index ba72f16..0603684 100644 --- a/db/migrations/001_init.sql +++ b/db/migrations/001_init.sql @@ -1,12 +1,13 @@ -- +goose Up -- +goose StatementBegin -CREATE TABLE example ( - id BIGSERIAL PRIMARY KEY, - name TEXT NOT NULL -) +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + idp_sub VARCHAR(36) UNIQUE NOT NULL, -- is this long enough for other IDPs? Keycloak gives a UUID + username VARCHAR(50) NOT NULL +); -- +goose StatementEnd -- +goose Down -- +goose StatementBegin -DROP TABLE example; +DROP TABLE users; -- +goose StatementEnd diff --git a/db/models.go b/db/models.go index e0f1ae7..d5c4d95 100644 --- a/db/models.go +++ b/db/models.go @@ -6,7 +6,8 @@ package db import () -type Example struct { - ID int64 - Name string +type User struct { + ID int32 + IdpSub string + Username string } diff --git a/db/queries/example.sql b/db/queries/example.sql deleted file mode 100644 index b989a31..0000000 --- a/db/queries/example.sql +++ /dev/null @@ -1,11 +0,0 @@ --- name: ExampleCreate :exec -INSERT INTO example (name) VALUES ($1); - --- name: ExampleGetByID :one -SELECT * FROM example WHERE id = $1; - --- name: ExampleGetAll :many -SELECT * FROM example; - --- name: ExampleDelete :exec -DELETE FROM example WHERE id = $1; diff --git a/db/queries/users.sql b/db/queries/users.sql new file mode 100644 index 0000000..5568f0e --- /dev/null +++ b/db/queries/users.sql @@ -0,0 +1,8 @@ +-- name: GetUser :one +SELECT * FROM users WHERE id = $1 LIMIT 1; + +-- name: GetUserByIDP :one +SELECT * FROM users WHERE idp_sub = $1 LIMIT 1; + +-- name: CreateUser :one +INSERT INTO users (username, idp_sub) VALUES ($1, $2) RETURNING *; diff --git a/db/users.sql.go b/db/users.sql.go new file mode 100644 index 0000000..3b8acff --- /dev/null +++ b/db/users.sql.go @@ -0,0 +1,48 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 +// source: users.sql + +package db + +import ( + "context" +) + +const createUser = `-- name: CreateUser :one +INSERT INTO users (username, idp_sub) VALUES ($1, $2) RETURNING id, idp_sub, username +` + +type CreateUserParams struct { + Username string + IdpSub string +} + +func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { + row := q.db.QueryRow(ctx, createUser, arg.Username, arg.IdpSub) + var i User + err := row.Scan(&i.ID, &i.IdpSub, &i.Username) + return i, err +} + +const getUser = `-- name: GetUser :one +SELECT id, idp_sub, username FROM users WHERE id = $1 LIMIT 1 +` + +func (q *Queries) GetUser(ctx context.Context, id int32) (User, error) { + row := q.db.QueryRow(ctx, getUser, id) + var i User + err := row.Scan(&i.ID, &i.IdpSub, &i.Username) + return i, err +} + +const getUserByIDP = `-- name: GetUserByIDP :one +SELECT id, idp_sub, username FROM users WHERE idp_sub = $1 LIMIT 1 +` + +func (q *Queries) GetUserByIDP(ctx context.Context, idpSub string) (User, error) { + row := q.db.QueryRow(ctx, getUserByIDP, idpSub) + var i User + err := row.Scan(&i.ID, &i.IdpSub, &i.Username) + return i, err +} diff --git a/go-project-template.sample.json b/go-project-template.sample.json index a47587e..bfbde13 100644 --- a/go-project-template.sample.json +++ b/go-project-template.sample.json @@ -1,3 +1,13 @@ { - "database": "postgresql://postgres:password@localhost:5432/postgres" + "loglevel": "debug", + "database": "postgresql://postgres:password@localhost:5432/postgres", + "web": { + "baseurl": "http://localhost:8080", + "sessionKey": "" + }, + "oauth2": { + "ClientID": "", + "ClientSecret": "", + "ProviderURL": "https://my.keycloak.instance/realms/example" + } } diff --git a/go.mod b/go.mod index 370605f..295220c 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,20 @@ module git.janky.solutions/finn/go-project-template go 1.21.8 require ( + github.com/coreos/go-oidc/v3 v3.11.0 + github.com/google/uuid v1.6.0 + github.com/gorilla/sessions v1.3.0 github.com/jackc/pgx/v5 v5.5.5 github.com/labstack/echo/v4 v4.12.0 github.com/pressly/goose/v3 v3.20.0 github.com/sirupsen/logrus v1.9.3 + golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 + golang.org/x/oauth2 v0.21.0 ) require ( + github.com/go-jose/go-jose/v4 v4.0.2 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect @@ -21,9 +28,9 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.22.0 // indirect - golang.org/x/net v0.24.0 // indirect + golang.org/x/crypto v0.25.0 // indirect + golang.org/x/net v0.27.0 // indirect golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.19.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/sys v0.22.0 // indirect + golang.org/x/text v0.16.0 // indirect ) diff --git a/go.sum b/go.sum index dd7d6ab..4ce19d9 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,22 @@ +github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI= +github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk= +github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= +github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg= +github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -49,19 +61,23 @@ github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQ github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= -golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= -golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= -golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw= +golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= +golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= -golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/web/auth.go b/web/auth.go new file mode 100644 index 0000000..f03f53a --- /dev/null +++ b/web/auth.go @@ -0,0 +1,142 @@ +package web + +import ( + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/google/uuid" + "github.com/gorilla/sessions" + pgx "github.com/jackc/pgx/v5" + echo "github.com/labstack/echo/v4" + "github.com/sirupsen/logrus" + "golang.org/x/oauth2" + + "git.janky.solutions/finn/go-project-template/config" + "git.janky.solutions/finn/go-project-template/db" +) + +const ( + sessionName = "session" + contextKeySession = "session" + contextKeyUserID = "user_id" + sessionValueAuthState = "auth_state" + sessionValueAuthUser = "auth_user" + queryParamState = "state" + queryParamCode = "code" +) + +var sessionStore *sessions.CookieStore + +func authBegin(c echo.Context) error { + session := c.Get(contextKeySession).(*sessions.Session) + + state, ok := session.Values[sessionValueAuthState] + if !ok { + state = uuid.New().String() + session.Values[sessionValueAuthState] = state + if err := session.Save(c.Request(), c.Response()); err != nil { + return fmt.Errorf("error saving session: %v", err) + } + } + + postAuthPath := c.QueryParam("dest") + if postAuthPath == "" { + logrus.WithField("params", c.Request().URL.RawQuery).Debug("no post-auth path") + postAuthPath = "/" + } + + return c.Redirect(http.StatusFound, config.C.OAuth2.GetConfig(postAuthPath).AuthCodeURL(state.(string))) +} + +func authFinish(c echo.Context) error { + session := c.Get(contextKeySession).(*sessions.Session) + + postAuthPath := c.QueryParam("dest") + if postAuthPath == "" { + postAuthPath = "/" + } + + if c.QueryParam(queryParamState) != session.Values[sessionValueAuthState] { + logrus.WithFields(logrus.Fields{ + "provided_state": c.QueryParam(queryParamState), + "expected_state": session.Values[sessionValueAuthState], + }).Debug("unexpected auth state value, restarting auth process") + + query := url.Values{} + query.Set("dest", postAuthPath) + return c.Redirect(http.StatusFound, fmt.Sprintf("/auth/login?%s", query.Encode())) + } + + ctx := c.Request().Context() + + token, err := config.C.OAuth2.GetConfig(postAuthPath).Exchange(ctx, c.QueryParam(queryParamCode)) + if err != nil { + return err + } + + userInfo, err := config.C.OAuth2.UserInfo(ctx, oauth2.StaticTokenSource(token)) + if err != nil { + return err + } + + var claims map[string]interface{} + if err := userInfo.Claims(&claims); err != nil { + return err + } + + queries, dbc, err := db.Get(ctx) + if err != nil { + return err + } + defer dbc.Close(ctx) + + user, err := queries.GetUserByIDP(ctx, userInfo.Subject) + if err != nil { + if !errors.Is(err, pgx.ErrNoRows) { + return fmt.Errorf("error finding user by subject (sub=%s): %+v", userInfo.Subject, err) + } + + user, err = queries.CreateUser(ctx, db.CreateUserParams{ + Username: claims["preferred_username"].(string), + IdpSub: userInfo.Subject, + }) + if err != nil { + return fmt.Errorf("error adding new user to db: %v", err) + } + } + + session.Values[sessionValueAuthUser] = user.ID + if err := session.Save(c.Request(), c.Response()); err != nil { + return fmt.Errorf("error saving session: %v", err) + } + + dest := c.QueryParam("dest") + if dest == "" { + dest = "/" + } + + return c.Redirect(http.StatusFound, dest) +} + +func authMiddleware(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if strings.HasPrefix(c.Request().URL.Path, "/auth") { + return next(c) + } + + session := c.Get(contextKeySession).(*sessions.Session) + user, ok := session.Values[sessionValueAuthUser] + if !ok { + query := url.Values{} + query.Set("dest", c.Request().URL.Path) + return c.Redirect(http.StatusFound, fmt.Sprintf("/auth/begin?%s", query.Encode())) + } + + c.Set(contextKeyUserID, user) + + return next(c) + } +} diff --git a/httpserver/index.go b/web/index.go similarity index 89% rename from httpserver/index.go rename to web/index.go index 0c929a4..554f96a 100644 --- a/httpserver/index.go +++ b/web/index.go @@ -1,4 +1,4 @@ -package httpserver +package web import ( "net/http" diff --git a/httpserver/server.go b/web/server.go similarity index 72% rename from httpserver/server.go rename to web/server.go index 72e3be9..a89b457 100644 --- a/httpserver/server.go +++ b/web/server.go @@ -1,4 +1,4 @@ -package httpserver +package web import ( "context" @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/gorilla/sessions" echo "github.com/labstack/echo/v4" "github.com/sirupsen/logrus" @@ -16,20 +17,25 @@ import ( var server *echo.Echo func ListenAndServe() { + sessionStore = sessions.NewCookieStore([]byte(config.C.Web.SessionKey)) + server = echo.New() server.HideBanner = true server.HidePort = true server.HTTPErrorHandler = handleError server.Renderer = &Template{} - server.Use(accessLogMiddleware) + server.Use(accessLogMiddleware, sessionMiddleware) server.RouteNotFound("/*", notFoundHandler) - server.GET("/", index) + server.GET("/auth/begin", authBegin) + server.GET("/auth/finish", authFinish) + + server.GET("/", index, authMiddleware) server.StaticFS("/static", Static) - logrus.WithField("address", config.C.HTTPBind).Info("starting http server") - err := server.Start(config.C.HTTPBind) + logrus.WithField("address", config.C.Web.Bind).Info("starting http server") + err := server.Start(config.C.Web.Bind) if err != http.ErrServerClosed { logrus.WithError(err).Fatal("error starting http server") } @@ -85,3 +91,15 @@ func handleError(err error, c echo.Context) { func notFoundHandler(c echo.Context) error { return c.Render(http.StatusNotFound, "404.html", nil) } + +func sessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + session, err := sessionStore.Get(c.Request(), sessionName) + if err != nil { + return err + } + + c.Set(contextKeySession, session) + return next(c) + } +} diff --git a/httpserver/static/main.css b/web/static/main.css similarity index 100% rename from httpserver/static/main.css rename to web/static/main.css diff --git a/httpserver/templates.go b/web/templates.go similarity index 98% rename from httpserver/templates.go rename to web/templates.go index 3f54903..32e4c17 100644 --- a/httpserver/templates.go +++ b/web/templates.go @@ -1,4 +1,4 @@ -package httpserver +package web import ( "embed" diff --git a/httpserver/templates/404.html b/web/templates/404.html similarity index 100% rename from httpserver/templates/404.html rename to web/templates/404.html diff --git a/httpserver/templates/500.html b/web/templates/500.html similarity index 100% rename from httpserver/templates/500.html rename to web/templates/500.html diff --git a/httpserver/templates/base.html b/web/templates/base.html similarity index 100% rename from httpserver/templates/base.html rename to web/templates/base.html diff --git a/httpserver/templates/sample-page.html b/web/templates/sample-page.html similarity index 100% rename from httpserver/templates/sample-page.html rename to web/templates/sample-page.html