add auth
and some other improvements
This commit is contained in:
parent
ef43c3af29
commit
4ebe6c7752
21 changed files with 390 additions and 110 deletions
|
@ -9,7 +9,7 @@ import (
|
|||
|
||||
"git.janky.solutions/finn/go-project-template/config"
|
||||
"git.janky.solutions/finn/go-project-template/db"
|
||||
"git.janky.solutions/finn/go-project-template/httpserver"
|
||||
"git.janky.solutions/finn/go-project-template/web"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@ -17,10 +17,9 @@ func main() {
|
|||
}
|
||||
|
||||
func run() {
|
||||
if err := config.Load(); err != nil {
|
||||
if err := config.Load(context.Background()); err != nil {
|
||||
logrus.WithError(err).Fatal("error loading config")
|
||||
}
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
@ -29,11 +28,11 @@ func run() {
|
|||
logrus.WithError(err).Fatal("error migrating database")
|
||||
}
|
||||
|
||||
go httpserver.ListenAndServe()
|
||||
go web.ListenAndServe()
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
if err := httpserver.Shutdown(ctx); err != nil {
|
||||
if err := web.Shutdown(ctx); err != nil {
|
||||
logrus.WithError(err).Error("error shutting down web server")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,22 +1,42 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/rand"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Database string
|
||||
HTTPBind string
|
||||
Web WebConfig
|
||||
OAuth2 OAuth2
|
||||
LogLevel string
|
||||
}
|
||||
|
||||
type WebConfig struct {
|
||||
Bind string
|
||||
SessionKey string
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
var C = Config{
|
||||
HTTPBind: ":8080",
|
||||
LogLevel: "INFO",
|
||||
Web: WebConfig{
|
||||
Bind: ":8080",
|
||||
},
|
||||
}
|
||||
|
||||
func Load() error {
|
||||
func Load(ctx context.Context) error {
|
||||
logrus.Info("loading config file")
|
||||
|
||||
f, err := os.Open("go-project-template.json")
|
||||
|
@ -30,5 +50,29 @@ func Load() error {
|
|||
return err
|
||||
}
|
||||
|
||||
level, err := logrus.ParseLevel(C.LogLevel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing requested log level %s: %v", C.LogLevel, err)
|
||||
}
|
||||
logrus.SetLevel(level)
|
||||
|
||||
if C.Web.SessionKey == "" {
|
||||
logrus.Info("No session key specified. Here's a good one if you need it: ", generateSessionKey())
|
||||
return errors.New("session key may not be empty")
|
||||
}
|
||||
|
||||
if err := C.OAuth2.Load(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")
|
||||
|
||||
func generateSessionKey() string {
|
||||
b := make([]rune, 64)
|
||||
for i := range b {
|
||||
b[i] = letterRunes[rand.Intn(len(letterRunes))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
|
60
config/oauth.go
Normal file
60
config/oauth.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type OAuth2 struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ProviderURL string
|
||||
Scopes []string
|
||||
|
||||
provider *oidc.Provider
|
||||
}
|
||||
|
||||
func (o *OAuth2) Load(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
provider, err := oidc.NewProvider(ctx, o.ProviderURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hasOpenIDScope := false
|
||||
for _, scope := range o.Scopes {
|
||||
if scope == oidc.ScopeOpenID {
|
||||
hasOpenIDScope = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenIDScope {
|
||||
o.Scopes = append(o.Scopes, oidc.ScopeOpenID)
|
||||
}
|
||||
|
||||
o.provider = provider
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o OAuth2) GetConfig(postAuthPath string) *oauth2.Config {
|
||||
params := url.Values{}
|
||||
params.Add("dest", postAuthPath)
|
||||
return &oauth2.Config{
|
||||
ClientID: o.ClientID,
|
||||
ClientSecret: o.ClientSecret,
|
||||
RedirectURL: fmt.Sprintf("%s/auth/finish", C.Web.BaseURL),
|
||||
Endpoint: o.provider.Endpoint(),
|
||||
Scopes: append([]string{oidc.ScopeOpenID}, o.Scopes...),
|
||||
}
|
||||
}
|
||||
|
||||
func (o OAuth2) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) (*oidc.UserInfo, error) {
|
||||
return o.provider.UserInfo(ctx, tokenSource)
|
||||
}
|
|
@ -1,63 +0,0 @@
|
|||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.20.0
|
||||
// source: example.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const exampleCreate = `-- name: ExampleCreate :exec
|
||||
INSERT INTO example (name) VALUES ($1)
|
||||
`
|
||||
|
||||
func (q *Queries) ExampleCreate(ctx context.Context, name string) error {
|
||||
_, err := q.db.Exec(ctx, exampleCreate, name)
|
||||
return err
|
||||
}
|
||||
|
||||
const exampleDelete = `-- name: ExampleDelete :exec
|
||||
DELETE FROM example WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) ExampleDelete(ctx context.Context, id int64) error {
|
||||
_, err := q.db.Exec(ctx, exampleDelete, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const exampleGetAll = `-- name: ExampleGetAll :many
|
||||
SELECT id, name FROM example
|
||||
`
|
||||
|
||||
func (q *Queries) ExampleGetAll(ctx context.Context) ([]Example, error) {
|
||||
rows, err := q.db.Query(ctx, exampleGetAll)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Example
|
||||
for rows.Next() {
|
||||
var i Example
|
||||
if err := rows.Scan(&i.ID, &i.Name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const exampleGetByID = `-- name: ExampleGetByID :one
|
||||
SELECT id, name FROM example WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) ExampleGetByID(ctx context.Context, id int64) (Example, error) {
|
||||
row := q.db.QueryRow(ctx, exampleGetByID, id)
|
||||
var i Example
|
||||
err := row.Scan(&i.ID, &i.Name)
|
||||
return i, err
|
||||
}
|
|
@ -1,12 +1,13 @@
|
|||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
CREATE TABLE example (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL
|
||||
)
|
||||
CREATE TABLE users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
idp_sub VARCHAR(36) UNIQUE NOT NULL, -- is this long enough for other IDPs? Keycloak gives a UUID
|
||||
username VARCHAR(50) NOT NULL
|
||||
);
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
DROP TABLE example;
|
||||
DROP TABLE users;
|
||||
-- +goose StatementEnd
|
||||
|
|
|
@ -6,7 +6,8 @@ package db
|
|||
|
||||
import ()
|
||||
|
||||
type Example struct {
|
||||
ID int64
|
||||
Name string
|
||||
type User struct {
|
||||
ID int32
|
||||
IdpSub string
|
||||
Username string
|
||||
}
|
||||
|
|
|
@ -1,11 +0,0 @@
|
|||
-- name: ExampleCreate :exec
|
||||
INSERT INTO example (name) VALUES ($1);
|
||||
|
||||
-- name: ExampleGetByID :one
|
||||
SELECT * FROM example WHERE id = $1;
|
||||
|
||||
-- name: ExampleGetAll :many
|
||||
SELECT * FROM example;
|
||||
|
||||
-- name: ExampleDelete :exec
|
||||
DELETE FROM example WHERE id = $1;
|
8
db/queries/users.sql
Normal file
8
db/queries/users.sql
Normal file
|
@ -0,0 +1,8 @@
|
|||
-- name: GetUser :one
|
||||
SELECT * FROM users WHERE id = $1 LIMIT 1;
|
||||
|
||||
-- name: GetUserByIDP :one
|
||||
SELECT * FROM users WHERE idp_sub = $1 LIMIT 1;
|
||||
|
||||
-- name: CreateUser :one
|
||||
INSERT INTO users (username, idp_sub) VALUES ($1, $2) RETURNING *;
|
48
db/users.sql.go
Normal file
48
db/users.sql.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.20.0
|
||||
// source: users.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const createUser = `-- name: CreateUser :one
|
||||
INSERT INTO users (username, idp_sub) VALUES ($1, $2) RETURNING id, idp_sub, username
|
||||
`
|
||||
|
||||
type CreateUserParams struct {
|
||||
Username string
|
||||
IdpSub string
|
||||
}
|
||||
|
||||
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
|
||||
row := q.db.QueryRow(ctx, createUser, arg.Username, arg.IdpSub)
|
||||
var i User
|
||||
err := row.Scan(&i.ID, &i.IdpSub, &i.Username)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUser = `-- name: GetUser :one
|
||||
SELECT id, idp_sub, username FROM users WHERE id = $1 LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUser(ctx context.Context, id int32) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getUser, id)
|
||||
var i User
|
||||
err := row.Scan(&i.ID, &i.IdpSub, &i.Username)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserByIDP = `-- name: GetUserByIDP :one
|
||||
SELECT id, idp_sub, username FROM users WHERE idp_sub = $1 LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserByIDP(ctx context.Context, idpSub string) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByIDP, idpSub)
|
||||
var i User
|
||||
err := row.Scan(&i.ID, &i.IdpSub, &i.Username)
|
||||
return i, err
|
||||
}
|
|
@ -1,3 +1,13 @@
|
|||
{
|
||||
"database": "postgresql://postgres:password@localhost:5432/postgres"
|
||||
"loglevel": "debug",
|
||||
"database": "postgresql://postgres:password@localhost:5432/postgres",
|
||||
"web": {
|
||||
"baseurl": "http://localhost:8080",
|
||||
"sessionKey": ""
|
||||
},
|
||||
"oauth2": {
|
||||
"ClientID": "",
|
||||
"ClientSecret": "",
|
||||
"ProviderURL": "https://my.keycloak.instance/realms/example"
|
||||
}
|
||||
}
|
||||
|
|
15
go.mod
15
go.mod
|
@ -3,13 +3,20 @@ module git.janky.solutions/finn/go-project-template
|
|||
go 1.21.8
|
||||
|
||||
require (
|
||||
github.com/coreos/go-oidc/v3 v3.11.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/jackc/pgx/v5 v5.5.5
|
||||
github.com/labstack/echo/v4 v4.12.0
|
||||
github.com/pressly/goose/v3 v3.20.0
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8
|
||||
golang.org/x/oauth2 v0.21.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/go-jose/go-jose/v4 v4.0.2 // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
|
@ -21,9 +28,9 @@ require (
|
|||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasttemplate v1.2.2 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/crypto v0.22.0 // indirect
|
||||
golang.org/x/net v0.24.0 // indirect
|
||||
golang.org/x/crypto v0.25.0 // indirect
|
||||
golang.org/x/net v0.27.0 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/sys v0.19.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
golang.org/x/sys v0.22.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
)
|
||||
|
|
32
go.sum
32
go.sum
|
@ -1,10 +1,22 @@
|
|||
github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI=
|
||||
github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk=
|
||||
github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
|
@ -49,19 +61,23 @@ github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQ
|
|||
github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
|
||||
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
|
||||
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
|
||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
|
||||
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
|
||||
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
|
||||
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
|
||||
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
|
||||
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
|
||||
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
|
||||
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
|
142
web/auth.go
Normal file
142
web/auth.go
Normal file
|
@ -0,0 +1,142 @@
|
|||
package web
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
pgx "github.com/jackc/pgx/v5"
|
||||
echo "github.com/labstack/echo/v4"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"git.janky.solutions/finn/go-project-template/config"
|
||||
"git.janky.solutions/finn/go-project-template/db"
|
||||
)
|
||||
|
||||
const (
|
||||
sessionName = "session"
|
||||
contextKeySession = "session"
|
||||
contextKeyUserID = "user_id"
|
||||
sessionValueAuthState = "auth_state"
|
||||
sessionValueAuthUser = "auth_user"
|
||||
queryParamState = "state"
|
||||
queryParamCode = "code"
|
||||
)
|
||||
|
||||
var sessionStore *sessions.CookieStore
|
||||
|
||||
func authBegin(c echo.Context) error {
|
||||
session := c.Get(contextKeySession).(*sessions.Session)
|
||||
|
||||
state, ok := session.Values[sessionValueAuthState]
|
||||
if !ok {
|
||||
state = uuid.New().String()
|
||||
session.Values[sessionValueAuthState] = state
|
||||
if err := session.Save(c.Request(), c.Response()); err != nil {
|
||||
return fmt.Errorf("error saving session: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
postAuthPath := c.QueryParam("dest")
|
||||
if postAuthPath == "" {
|
||||
logrus.WithField("params", c.Request().URL.RawQuery).Debug("no post-auth path")
|
||||
postAuthPath = "/"
|
||||
}
|
||||
|
||||
return c.Redirect(http.StatusFound, config.C.OAuth2.GetConfig(postAuthPath).AuthCodeURL(state.(string)))
|
||||
}
|
||||
|
||||
func authFinish(c echo.Context) error {
|
||||
session := c.Get(contextKeySession).(*sessions.Session)
|
||||
|
||||
postAuthPath := c.QueryParam("dest")
|
||||
if postAuthPath == "" {
|
||||
postAuthPath = "/"
|
||||
}
|
||||
|
||||
if c.QueryParam(queryParamState) != session.Values[sessionValueAuthState] {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"provided_state": c.QueryParam(queryParamState),
|
||||
"expected_state": session.Values[sessionValueAuthState],
|
||||
}).Debug("unexpected auth state value, restarting auth process")
|
||||
|
||||
query := url.Values{}
|
||||
query.Set("dest", postAuthPath)
|
||||
return c.Redirect(http.StatusFound, fmt.Sprintf("/auth/login?%s", query.Encode()))
|
||||
}
|
||||
|
||||
ctx := c.Request().Context()
|
||||
|
||||
token, err := config.C.OAuth2.GetConfig(postAuthPath).Exchange(ctx, c.QueryParam(queryParamCode))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userInfo, err := config.C.OAuth2.UserInfo(ctx, oauth2.StaticTokenSource(token))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := userInfo.Claims(&claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queries, dbc, err := db.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dbc.Close(ctx)
|
||||
|
||||
user, err := queries.GetUserByIDP(ctx, userInfo.Subject)
|
||||
if err != nil {
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return fmt.Errorf("error finding user by subject (sub=%s): %+v", userInfo.Subject, err)
|
||||
}
|
||||
|
||||
user, err = queries.CreateUser(ctx, db.CreateUserParams{
|
||||
Username: claims["preferred_username"].(string),
|
||||
IdpSub: userInfo.Subject,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error adding new user to db: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
session.Values[sessionValueAuthUser] = user.ID
|
||||
if err := session.Save(c.Request(), c.Response()); err != nil {
|
||||
return fmt.Errorf("error saving session: %v", err)
|
||||
}
|
||||
|
||||
dest := c.QueryParam("dest")
|
||||
if dest == "" {
|
||||
dest = "/"
|
||||
}
|
||||
|
||||
return c.Redirect(http.StatusFound, dest)
|
||||
}
|
||||
|
||||
func authMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if strings.HasPrefix(c.Request().URL.Path, "/auth") {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
session := c.Get(contextKeySession).(*sessions.Session)
|
||||
user, ok := session.Values[sessionValueAuthUser]
|
||||
if !ok {
|
||||
query := url.Values{}
|
||||
query.Set("dest", c.Request().URL.Path)
|
||||
return c.Redirect(http.StatusFound, fmt.Sprintf("/auth/begin?%s", query.Encode()))
|
||||
}
|
||||
|
||||
c.Set(contextKeyUserID, user)
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package httpserver
|
||||
package web
|
||||
|
||||
import (
|
||||
"net/http"
|
|
@ -1,4 +1,4 @@
|
|||
package httpserver
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -7,6 +7,7 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
echo "github.com/labstack/echo/v4"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
|
@ -16,20 +17,25 @@ import (
|
|||
var server *echo.Echo
|
||||
|
||||
func ListenAndServe() {
|
||||
sessionStore = sessions.NewCookieStore([]byte(config.C.Web.SessionKey))
|
||||
|
||||
server = echo.New()
|
||||
server.HideBanner = true
|
||||
server.HidePort = true
|
||||
server.HTTPErrorHandler = handleError
|
||||
server.Renderer = &Template{}
|
||||
server.Use(accessLogMiddleware)
|
||||
server.Use(accessLogMiddleware, sessionMiddleware)
|
||||
server.RouteNotFound("/*", notFoundHandler)
|
||||
|
||||
server.GET("/", index)
|
||||
server.GET("/auth/begin", authBegin)
|
||||
server.GET("/auth/finish", authFinish)
|
||||
|
||||
server.GET("/", index, authMiddleware)
|
||||
|
||||
server.StaticFS("/static", Static)
|
||||
|
||||
logrus.WithField("address", config.C.HTTPBind).Info("starting http server")
|
||||
err := server.Start(config.C.HTTPBind)
|
||||
logrus.WithField("address", config.C.Web.Bind).Info("starting http server")
|
||||
err := server.Start(config.C.Web.Bind)
|
||||
if err != http.ErrServerClosed {
|
||||
logrus.WithError(err).Fatal("error starting http server")
|
||||
}
|
||||
|
@ -85,3 +91,15 @@ func handleError(err error, c echo.Context) {
|
|||
func notFoundHandler(c echo.Context) error {
|
||||
return c.Render(http.StatusNotFound, "404.html", nil)
|
||||
}
|
||||
|
||||
func sessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
session, err := sessionStore.Get(c.Request(), sessionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Set(contextKeySession, session)
|
||||
return next(c)
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package httpserver
|
||||
package web
|
||||
|
||||
import (
|
||||
"embed"
|
Loading…
Reference in a new issue