package web

import (
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strings"

	"github.com/google/uuid"
	"github.com/gorilla/sessions"
	pgx "github.com/jackc/pgx/v5"
	echo "github.com/labstack/echo/v4"
	"github.com/sirupsen/logrus"
	"golang.org/x/oauth2"

	"git.janky.solutions/finn/go-project-template/config"
	"git.janky.solutions/finn/go-project-template/db"
)

const (
	// sessionName is not referenced in this file.
	// NOTE(review): presumably the cookie session name used where the session
	// is loaded — confirm against the caller.
	sessionName = "session"

	// Keys under which values are stored on the echo.Context.
	contextKeySession = "session"
	contextKeyUserID  = "user_id"

	// Keys under which values are stored in the session.
	sessionValueAuthState = "auth_state"
	sessionValueAuthUser  = "auth_user"

	// Query parameters read from the OAuth2 callback request.
	queryParamState = "state"
	queryParamCode  = "code"
)

// sessionStore is declared here but neither assigned nor read in this file.
// NOTE(review): presumably initialised elsewhere in the package — confirm.
var sessionStore *sessions.CookieStore

// sessionFromContext returns the session stored on c under contextKeySession.
// It returns an error instead of panicking when no *sessions.Session is
// present there (for example when the middleware that is expected to load the
// session did not run).
func sessionFromContext(c echo.Context) (*sessions.Session, error) {
	session, ok := c.Get(contextKeySession).(*sessions.Session)
	if !ok || session == nil {
		return nil, errors.New("no session found in request context")
	}
	return session, nil
}

// localRedirectPath returns dest when it is a path on this site and "/"
// otherwise. dest comes from a query parameter, so it must not be allowed to
// send the browser to another origin (open redirect).
func localRedirectPath(dest string) string {
	// Must be an absolute path on this host.
	if dest == "" || dest[0] != '/' {
		return "/"
	}
	// "//host" is scheme-relative, and some browsers treat "/\host" the same.
	if len(dest) > 1 && (dest[1] == '/' || dest[1] == '\\') {
		return "/"
	}
	// Reject anything that does not parse or that carries a scheme or host.
	u, err := url.Parse(dest)
	if err != nil || u.IsAbs() || u.Host != "" {
		return "/"
	}
	return dest
}

// authBegin starts the OAuth2 login flow.
//
// It makes sure the session holds a non-empty string state value, generating
// and saving a fresh UUID when it does not, and then redirects the client to
// the provider's authorization URL for that state. The "dest" query parameter
// selects the post-auth path handed to the OAuth2 config; it falls back to "/"
// when absent or not a local path.
//
// It returns an error when no session is available on the context or when the
// session cannot be saved.
func authBegin(c echo.Context) error {
	session, err := sessionFromContext(c)
	if err != nil {
		return err
	}

	// Checked assertion: a missing, non-string or empty value is replaced
	// rather than causing a panic or an empty state being sent out.
	state, ok := session.Values[sessionValueAuthState].(string)
	if !ok || state == "" {
		state = uuid.New().String()
		session.Values[sessionValueAuthState] = state
		if err := session.Save(c.Request(), c.Response()); err != nil {
			return fmt.Errorf("error saving session: %w", err)
		}
	}

	if c.QueryParam("dest") == "" {
		logrus.WithField("params", c.Request().URL.RawQuery).Debug("no post-auth path")
	}
	postAuthPath := localRedirectPath(c.QueryParam("dest"))

	return c.Redirect(http.StatusFound, config.C.OAuth2.GetConfig(postAuthPath).AuthCodeURL(state))
}

// authFinish handles the OAuth2 callback.
//
// When the "state" query parameter does not match the non-empty state stored
// in the session, the client is redirected to /auth/login to restart the flow.
// Otherwise the "code" query parameter is exchanged for a token, the user info
// and its claims are fetched, and the user is looked up by subject. If no such
// user exists, one is created from the "preferred_username" claim. The user's
// ID is then stored in the session and the client is redirected to the
// post-auth path ("dest", or "/" when absent or not a local path).
//
// It returns an error when no session is available, when any provider or
// database call fails, when a new user is needed but the claims carry no
// string "preferred_username", or when the session cannot be saved.
func authFinish(c echo.Context) error {
	session, err := sessionFromContext(c)
	if err != nil {
		return err
	}

	postAuthPath := localRedirectPath(c.QueryParam("dest"))

	providedState := c.QueryParam(queryParamState)
	expectedState, ok := session.Values[sessionValueAuthState].(string)
	// An absent or empty stored state never matches, so a callback without a
	// state parameter cannot pass against a session that never began a login.
	if !ok || expectedState == "" || providedState != expectedState {
		logrus.WithFields(logrus.Fields{
			"provided_state": providedState,
			"expected_state": session.Values[sessionValueAuthState],
		}).Debug("unexpected auth state value, restarting auth process")

		query := url.Values{}
		query.Set("dest", postAuthPath)
		// NOTE(review): authMiddleware redirects to /auth/begin while this
		// restarts at /auth/login — confirm both routes exist.
		return c.Redirect(http.StatusFound, fmt.Sprintf("/auth/login?%s", query.Encode()))
	}

	ctx := c.Request().Context()

	token, err := config.C.OAuth2.GetConfig(postAuthPath).Exchange(ctx, c.QueryParam(queryParamCode))
	if err != nil {
		return err
	}

	userInfo, err := config.C.OAuth2.UserInfo(ctx, oauth2.StaticTokenSource(token))
	if err != nil {
		return err
	}

	var claims map[string]interface{}
	if err := userInfo.Claims(&claims); err != nil {
		return err
	}

	queries, dbc, err := db.Get(ctx)
	if err != nil {
		return err
	}
	defer dbc.Close(ctx)

	user, err := queries.GetUserByIDP(ctx, userInfo.Subject)
	if err != nil {
		if !errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("error finding user by subject (sub=%s): %w", userInfo.Subject, err)
		}

		// First login for this subject: create the user. The claim is
		// provider-controlled, so check its type instead of asserting.
		username, ok := claims["preferred_username"].(string)
		if !ok {
			return fmt.Errorf("error adding new user to db: no string preferred_username claim (sub=%s)", userInfo.Subject)
		}

		user, err = queries.CreateUser(ctx, db.CreateUserParams{
			Username: username,
			IdpSub:   userInfo.Subject,
		})
		if err != nil {
			return fmt.Errorf("error adding new user to db: %w", err)
		}
	}

	session.Values[sessionValueAuthUser] = user.ID
	if err := session.Save(c.Request(), c.Response()); err != nil {
		return fmt.Errorf("error saving session: %w", err)
	}

	return c.Redirect(http.StatusFound, postAuthPath)
}

// authMiddleware returns a handler that requires a logged-in session for every
// request outside the /auth routes.
//
// Requests to "/auth" or below "/auth/" are passed through untouched. For all
// other requests, a session without a stored user is redirected to
// /auth/begin with the requested path as "dest"; otherwise the stored user
// value is placed on the context under contextKeyUserID and next is called.
func authMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c echo.Context) error {
		path := c.Request().URL.Path

		// Match the /auth subtree only; a bare "/auth" prefix test would also
		// exempt unrelated paths such as "/authors" from authentication.
		if path == "/auth" || strings.HasPrefix(path, "/auth/") {
			return next(c)
		}

		session, err := sessionFromContext(c)
		if err != nil {
			return err
		}

		user, ok := session.Values[sessionValueAuthUser]
		if !ok {
			query := url.Values{}
			query.Set("dest", path)
			return c.Redirect(http.StatusFound, fmt.Sprintf("/auth/begin?%s", query.Encode()))
		}

		c.Set(contextKeyUserID, user)
		return next(c)
	}
}