Penultimate round of db.DefaultContext refactor (#27414)

Part of #27065

---------

Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
This commit is contained in:
JakobDev 2023-10-11 06:24:07 +02:00 committed by GitHub
parent 50166d1f7c
commit ebe803e514
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
136 changed files with 428 additions and 421 deletions

View file

@ -5,6 +5,7 @@
package auth
import (
"context"
"fmt"
"reflect"
@ -199,8 +200,8 @@ func (source *Source) SkipVerify() bool {
// CreateSource inserts a AuthSource in the DB if not already
// existing with the given name.
func CreateSource(source *Source) error {
has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(Source))
func CreateSource(ctx context.Context, source *Source) error {
has, err := db.GetEngine(ctx).Where("name=?", source.Name).Exist(new(Source))
if err != nil {
return err
} else if has {
@ -211,7 +212,7 @@ func CreateSource(source *Source) error {
source.IsSyncEnabled = false
}
_, err = db.GetEngine(db.DefaultContext).Insert(source)
_, err = db.GetEngine(ctx).Insert(source)
if err != nil {
return err
}
@ -232,7 +233,7 @@ func CreateSource(source *Source) error {
err = registerableSource.RegisterSource()
if err != nil {
// remove the AuthSource in case of errors while registering configuration
if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil {
if _, err := db.GetEngine(ctx).Delete(source); err != nil {
log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
}
}
@ -240,33 +241,33 @@ func CreateSource(source *Source) error {
}
// Sources returns a slice of all login sources found in DB.
func Sources() ([]*Source, error) {
func Sources(ctx context.Context) ([]*Source, error) {
auths := make([]*Source, 0, 6)
return auths, db.GetEngine(db.DefaultContext).Find(&auths)
return auths, db.GetEngine(ctx).Find(&auths)
}
// SourcesByType returns all sources of the specified type
func SourcesByType(loginType Type) ([]*Source, error) {
func SourcesByType(ctx context.Context, loginType Type) ([]*Source, error) {
sources := make([]*Source, 0, 1)
if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil {
if err := db.GetEngine(ctx).Where("type = ?", loginType).Find(&sources); err != nil {
return nil, err
}
return sources, nil
}
// AllActiveSources returns all active sources
func AllActiveSources() ([]*Source, error) {
func AllActiveSources(ctx context.Context) ([]*Source, error) {
sources := make([]*Source, 0, 5)
if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil {
if err := db.GetEngine(ctx).Where("is_active = ?", true).Find(&sources); err != nil {
return nil, err
}
return sources, nil
}
// ActiveSources returns all active sources of the specified type
func ActiveSources(tp Type) ([]*Source, error) {
func ActiveSources(ctx context.Context, tp Type) ([]*Source, error) {
sources := make([]*Source, 0, 1)
if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
if err := db.GetEngine(ctx).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
return nil, err
}
return sources, nil
@ -274,11 +275,11 @@ func ActiveSources(tp Type) ([]*Source, error) {
// IsSSPIEnabled returns true if there is at least one activated login
// source of type LoginSSPI
func IsSSPIEnabled() bool {
func IsSSPIEnabled(ctx context.Context) bool {
if !db.HasEngine {
return false
}
sources, err := ActiveSources(SSPI)
sources, err := ActiveSources(ctx, SSPI)
if err != nil {
log.Error("ActiveSources: %v", err)
return false
@ -287,7 +288,7 @@ func IsSSPIEnabled() bool {
}
// GetSourceByID returns login source by given ID.
func GetSourceByID(id int64) (*Source, error) {
func GetSourceByID(ctx context.Context, id int64) (*Source, error) {
source := new(Source)
if id == 0 {
source.Cfg = registeredConfigs[NoType]()
@ -297,7 +298,7 @@ func GetSourceByID(id int64) (*Source, error) {
return source, nil
}
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source)
has, err := db.GetEngine(ctx).ID(id).Get(source)
if err != nil {
return nil, err
} else if !has {
@ -307,24 +308,24 @@ func GetSourceByID(id int64) (*Source, error) {
}
// UpdateSource updates a Source record in DB.
func UpdateSource(source *Source) error {
func UpdateSource(ctx context.Context, source *Source) error {
var originalSource *Source
if source.IsOAuth2() {
// keep track of the original values so we can restore in case of errors while registering OAuth2 providers
var err error
if originalSource, err = GetSourceByID(source.ID); err != nil {
if originalSource, err = GetSourceByID(ctx, source.ID); err != nil {
return err
}
}
has, err := db.GetEngine(db.DefaultContext).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
has, err := db.GetEngine(ctx).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
if err != nil {
return err
} else if has {
return ErrSourceAlreadyExist{source.Name}
}
_, err = db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source)
_, err = db.GetEngine(ctx).ID(source.ID).AllCols().Update(source)
if err != nil {
return err
}
@ -345,7 +346,7 @@ func UpdateSource(source *Source) error {
err = registerableSource.RegisterSource()
if err != nil {
// restore original values since we cannot update the provider it self
if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalSource); err != nil {
if _, err := db.GetEngine(ctx).ID(source.ID).AllCols().Update(originalSource); err != nil {
log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
}
}
@ -353,8 +354,8 @@ func UpdateSource(source *Source) error {
}
// CountSources returns number of login sources.
func CountSources() int64 {
count, _ := db.GetEngine(db.DefaultContext).Count(new(Source))
func CountSources(ctx context.Context) int64 {
count, _ := db.GetEngine(ctx).Count(new(Source))
return count
}