143 lines
3.5 KiB
Go
143 lines
3.5 KiB
Go
|
package web
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
"net/url"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/google/uuid"
|
||
|
"github.com/gorilla/sessions"
|
||
|
pgx "github.com/jackc/pgx/v5"
|
||
|
echo "github.com/labstack/echo/v4"
|
||
|
"github.com/sirupsen/logrus"
|
||
|
"golang.org/x/oauth2"
|
||
|
|
||
|
"git.janky.solutions/finn/go-project-template/config"
|
||
|
"git.janky.solutions/finn/go-project-template/db"
|
||
|
)
|
||
|
|
||
|
// Keys for the gorilla session and for values stashed on the echo context.
const (
	// sessionName is the name of the session; it is not referenced by the
	// handlers in this file (NOTE(review): presumably used when the session
	// middleware loads the session — confirm).
	sessionName = "session"

	// contextKeySession is the echo context key holding the *sessions.Session.
	contextKeySession = "session"
	// contextKeyUserID is the echo context key authMiddleware sets to the
	// authenticated user's ID.
	contextKeyUserID = "user_id"

	// sessionValueAuthState is the session value holding the OAuth2 state token.
	sessionValueAuthState = "auth_state"
	// sessionValueAuthUser is the session value holding the logged-in user's ID.
	sessionValueAuthUser = "auth_user"
)

// Query parameters the OAuth2 provider appends to the callback URL.
const (
	queryParamState = "state"
	queryParamCode  = "code"
)
|
||
|
|
||
|
// sessionStore is the package-level cookie-backed session store.
// NOTE(review): it is not assigned in the visible code — confirm it is
// initialized before first use.
var sessionStore *sessions.CookieStore
|
||
|
|
||
|
func authBegin(c echo.Context) error {
|
||
|
session := c.Get(contextKeySession).(*sessions.Session)
|
||
|
|
||
|
state, ok := session.Values[sessionValueAuthState]
|
||
|
if !ok {
|
||
|
state = uuid.New().String()
|
||
|
session.Values[sessionValueAuthState] = state
|
||
|
if err := session.Save(c.Request(), c.Response()); err != nil {
|
||
|
return fmt.Errorf("error saving session: %v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
postAuthPath := c.QueryParam("dest")
|
||
|
if postAuthPath == "" {
|
||
|
logrus.WithField("params", c.Request().URL.RawQuery).Debug("no post-auth path")
|
||
|
postAuthPath = "/"
|
||
|
}
|
||
|
|
||
|
return c.Redirect(http.StatusFound, config.C.OAuth2.GetConfig(postAuthPath).AuthCodeURL(state.(string)))
|
||
|
}
|
||
|
|
||
|
func authFinish(c echo.Context) error {
|
||
|
session := c.Get(contextKeySession).(*sessions.Session)
|
||
|
|
||
|
postAuthPath := c.QueryParam("dest")
|
||
|
if postAuthPath == "" {
|
||
|
postAuthPath = "/"
|
||
|
}
|
||
|
|
||
|
if c.QueryParam(queryParamState) != session.Values[sessionValueAuthState] {
|
||
|
logrus.WithFields(logrus.Fields{
|
||
|
"provided_state": c.QueryParam(queryParamState),
|
||
|
"expected_state": session.Values[sessionValueAuthState],
|
||
|
}).Debug("unexpected auth state value, restarting auth process")
|
||
|
|
||
|
query := url.Values{}
|
||
|
query.Set("dest", postAuthPath)
|
||
|
return c.Redirect(http.StatusFound, fmt.Sprintf("/auth/login?%s", query.Encode()))
|
||
|
}
|
||
|
|
||
|
ctx := c.Request().Context()
|
||
|
|
||
|
token, err := config.C.OAuth2.GetConfig(postAuthPath).Exchange(ctx, c.QueryParam(queryParamCode))
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
userInfo, err := config.C.OAuth2.UserInfo(ctx, oauth2.StaticTokenSource(token))
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
var claims map[string]interface{}
|
||
|
if err := userInfo.Claims(&claims); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
queries, dbc, err := db.Get(ctx)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer dbc.Close(ctx)
|
||
|
|
||
|
user, err := queries.GetUserByIDP(ctx, userInfo.Subject)
|
||
|
if err != nil {
|
||
|
if !errors.Is(err, pgx.ErrNoRows) {
|
||
|
return fmt.Errorf("error finding user by subject (sub=%s): %+v", userInfo.Subject, err)
|
||
|
}
|
||
|
|
||
|
user, err = queries.CreateUser(ctx, db.CreateUserParams{
|
||
|
Username: claims["preferred_username"].(string),
|
||
|
IdpSub: userInfo.Subject,
|
||
|
})
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("error adding new user to db: %v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
session.Values[sessionValueAuthUser] = user.ID
|
||
|
if err := session.Save(c.Request(), c.Response()); err != nil {
|
||
|
return fmt.Errorf("error saving session: %v", err)
|
||
|
}
|
||
|
|
||
|
dest := c.QueryParam("dest")
|
||
|
if dest == "" {
|
||
|
dest = "/"
|
||
|
}
|
||
|
|
||
|
return c.Redirect(http.StatusFound, dest)
|
||
|
}
|
||
|
|
||
|
func authMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
|
||
|
return func(c echo.Context) error {
|
||
|
if strings.HasPrefix(c.Request().URL.Path, "/auth") {
|
||
|
return next(c)
|
||
|
}
|
||
|
|
||
|
session := c.Get(contextKeySession).(*sessions.Session)
|
||
|
user, ok := session.Values[sessionValueAuthUser]
|
||
|
if !ok {
|
||
|
query := url.Values{}
|
||
|
query.Set("dest", c.Request().URL.Path)
|
||
|
return c.Redirect(http.StatusFound, fmt.Sprintf("/auth/begin?%s", query.Encode()))
|
||
|
}
|
||
|
|
||
|
c.Set(contextKeyUserID, user)
|
||
|
|
||
|
return next(c)
|
||
|
}
|
||
|
}
|