go-project-template/web/auth.go

143 lines
3.5 KiB
Go
Raw Permalink Normal View History

2024-07-25 05:13:23 +00:00
package web
import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/google/uuid"
"github.com/gorilla/sessions"
pgx "github.com/jackc/pgx/v5"
echo "github.com/labstack/echo/v4"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"git.janky.solutions/finn/go-project-template/config"
"git.janky.solutions/finn/go-project-template/db"
)
// Session identifiers.
const (
	// sessionName is not referenced in this part of the file; presumably it is
	// the name of the gorilla session/cookie — TODO confirm.
	sessionName = "session"
)

// Keys used with echo.Context Get/Set.
const (
	// contextKeySession holds the *sessions.Session for the current request.
	contextKeySession = "session"
	// contextKeyUserID holds the authenticated user's ID, set by authMiddleware.
	contextKeyUserID = "user_id"
)

// Keys stored in session.Values.
const (
	// sessionValueAuthState holds the OAuth2 state generated by authBegin.
	sessionValueAuthState = "auth_state"
	// sessionValueAuthUser holds the logged-in user's ID, set by authFinish.
	sessionValueAuthUser = "auth_user"
)

// Query parameters read from the OAuth2 callback request.
const (
	queryParamState = "state"
	queryParamCode  = "code"
)
// sessionStore is the package-level gorilla cookie session store. It is
// declared nil here; presumably it is assigned during server setup — TODO
// confirm.
var (
	sessionStore *sessions.CookieStore
)
// authBegin starts the OAuth2 login flow.
//
// The first time it runs for a session, it generates a random state value
// (a UUID), stores it under sessionValueAuthState and saves the session.
// authFinish later compares this value with the callback's "state" parameter.
// It then redirects the browser (302) to the authorization URL built from
// config.C.OAuth2.GetConfig(dest). dest is the "dest" query parameter,
// defaulting to "/".
//
// Returns an error if the session cannot be saved.
func authBegin(c echo.Context) error {
	sess := c.Get(contextKeySession).(*sessions.Session)

	storedState, found := sess.Values[sessionValueAuthState]
	if !found {
		storedState = uuid.New().String()
		sess.Values[sessionValueAuthState] = storedState
		if saveErr := sess.Save(c.Request(), c.Response()); saveErr != nil {
			return fmt.Errorf("error saving session: %v", saveErr)
		}
	}

	// NOTE(review): dest is not validated here; any value is forwarded to
	// GetConfig, and authFinish reads the same "dest" parameter later.
	dest := c.QueryParam("dest")
	if len(dest) == 0 {
		logrus.WithField("params", c.Request().URL.RawQuery).Debug("no post-auth path")
		dest = "/"
	}

	oauthConfig := config.C.OAuth2.GetConfig(dest)
	authURL := oauthConfig.AuthCodeURL(storedState.(string))
	return c.Redirect(http.StatusFound, authURL)
}
// authFinish handles the OAuth2 callback.
//
// It runs the following steps in order:
//  1. Verifies the "state" query parameter against the value authBegin stored
//     in the session. On mismatch it restarts the login flow.
//  2. Exchanges the "code" parameter for a token and fetches the user's claims.
//  3. Looks up the user by IdP subject, creating the user on first login.
//  4. Records the user ID in the session and consumes the stored state.
//  5. Redirects to the "dest" query parameter (default "/"). Only same-site
//     paths are accepted for this final redirect.
//
// Returns an error if the code exchange, user-info fetch, claim decoding,
// database access or session save fails. It also returns an error if a new
// user's identity carries no string "preferred_username" claim.
func authFinish(c echo.Context) error {
	session := c.Get(contextKeySession).(*sessions.Session)

	// The raw value is passed to GetConfig exactly as authBegin does.
	// Presumably GetConfig derives the OAuth2 redirect URL from it, and that URL
	// must match between the authorization request and the exchange — TODO
	// confirm. It is sanitized only for the final browser redirect.
	postAuthPath := c.QueryParam("dest")
	if postAuthPath == "" {
		postAuthPath = "/"
	}

	if c.QueryParam(queryParamState) != session.Values[sessionValueAuthState] {
		logrus.WithFields(logrus.Fields{
			"provided_state": c.QueryParam(queryParamState),
			"expected_state": session.Values[sessionValueAuthState],
		}).Debug("unexpected auth state value, restarting auth process")
		// NOTE(review): authMiddleware sends unauthenticated users to
		// /auth/begin, but this branch restarts at /auth/login — confirm that
		// both routes are registered.
		query := url.Values{}
		query.Set("dest", postAuthPath)
		return c.Redirect(http.StatusFound, fmt.Sprintf("/auth/login?%s", query.Encode()))
	}

	ctx := c.Request().Context()
	token, err := config.C.OAuth2.GetConfig(postAuthPath).Exchange(ctx, c.QueryParam(queryParamCode))
	if err != nil {
		return err
	}

	userInfo, err := config.C.OAuth2.UserInfo(ctx, oauth2.StaticTokenSource(token))
	if err != nil {
		return err
	}

	var claims map[string]interface{}
	if err := userInfo.Claims(&claims); err != nil {
		return err
	}

	queries, dbc, err := db.Get(ctx)
	if err != nil {
		return err
	}
	defer dbc.Close(ctx)

	user, err := queries.GetUserByIDP(ctx, userInfo.Subject)
	if err != nil {
		if !errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("error finding user by subject (sub=%s): %w", userInfo.Subject, err)
		}

		// First login for this subject: provision a user. A missing or
		// non-string claim is reported as an error instead of panicking on an
		// unchecked type assertion.
		username, ok := claims["preferred_username"].(string)
		if !ok {
			return fmt.Errorf("identity provider returned no string preferred_username claim (sub=%s)", userInfo.Subject)
		}

		user, err = queries.CreateUser(ctx, db.CreateUserParams{
			Username: username,
			IdpSub:   userInfo.Subject,
		})
		if err != nil {
			return fmt.Errorf("error adding new user to db: %w", err)
		}
	}

	session.Values[sessionValueAuthUser] = user.ID
	// The state has now been used. Dropping it makes a replayed callback fail
	// the state check above, and makes authBegin generate a fresh state for
	// the next login.
	delete(session.Values, sessionValueAuthState)
	if err := session.Save(c.Request(), c.Response()); err != nil {
		return fmt.Errorf("error saving session: %w", err)
	}

	return c.Redirect(http.StatusFound, sanitizePostAuthPath(postAuthPath))
}

// sanitizePostAuthPath returns dest if it is a path on this site, and "/"
// otherwise.
//
// It rejects absolute URLs ("https://evil.example") and scheme-relative forms
// ("//evil.example", "/\evil.example"), which browsers follow off-site. It
// also rejects anything url.Parse refuses, such as strings with control
// characters. This keeps the untrusted "dest" parameter from turning
// authFinish into an open redirect.
func sanitizePostAuthPath(dest string) string {
	if !strings.HasPrefix(dest, "/") || strings.HasPrefix(dest, "//") || strings.HasPrefix(dest, "/\\") {
		return "/"
	}
	u, err := url.Parse(dest)
	if err != nil || u.Scheme != "" || u.Host != "" {
		return "/"
	}
	return dest
}
// authMiddleware requires an authenticated session for every request outside
// the /auth routes.
//
// Requests to "/auth" or "/auth/..." pass through untouched. If the session
// holds no user ID under sessionValueAuthUser, the request is redirected
// (302) to /auth/begin. The originally requested URI, including its query
// string, is passed in the "dest" query parameter. Otherwise the user ID is
// stored in the echo context under contextKeyUserID and the next handler runs.
func authMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c echo.Context) error {
		path := c.Request().URL.Path
		// Match the /auth route tree only. A bare HasPrefix(path, "/auth")
		// would also exempt unrelated routes such as "/authors" from
		// authentication.
		if path == "/auth" || strings.HasPrefix(path, "/auth/") {
			return next(c)
		}

		session := c.Get(contextKeySession).(*sessions.Session)
		user, ok := session.Values[sessionValueAuthUser]
		if !ok {
			// RequestURI keeps the query string, so the user returns to the
			// exact page they asked for once login completes.
			query := url.Values{}
			query.Set("dest", c.Request().URL.RequestURI())
			return c.Redirect(http.StatusFound, fmt.Sprintf("/auth/begin?%s", query.Encode()))
		}

		c.Set(contextKeyUserID, user)
		return next(c)
	}
}