Add context parameter to some database functions (#26055)
To avoid deadlock problem, almost database related functions should be have ctx as the first parameter. This PR do a refactor for some of these functions.
This commit is contained in:
parent
c42b71877e
commit
b167f35113
50 changed files with 209 additions and 237 deletions
|
@ -391,10 +391,10 @@ func (a *Action) GetIssueInfos() []string {
|
|||
}
|
||||
|
||||
// GetIssueTitle returns the title of first issue associated
|
||||
// with the action.
|
||||
// with the action. This function will be invoked in template so keep db.DefaultContext here
|
||||
func (a *Action) GetIssueTitle() string {
|
||||
index, _ := strconv.ParseInt(a.GetIssueInfos()[0], 10, 64)
|
||||
issue, err := issues_model.GetIssueByIndex(a.RepoID, index)
|
||||
issue, err := issues_model.GetIssueByIndex(db.DefaultContext, a.RepoID, index)
|
||||
if err != nil {
|
||||
log.Error("GetIssueByIndex: %v", err)
|
||||
return "500 when get issue"
|
||||
|
@ -404,9 +404,9 @@ func (a *Action) GetIssueTitle() string {
|
|||
|
||||
// GetIssueContent returns the content of first issue associated with
|
||||
// this action.
|
||||
func (a *Action) GetIssueContent() string {
|
||||
func (a *Action) GetIssueContent(ctx context.Context) string {
|
||||
index, _ := strconv.ParseInt(a.GetIssueInfos()[0], 10, 64)
|
||||
issue, err := issues_model.GetIssueByIndex(a.RepoID, index)
|
||||
issue, err := issues_model.GetIssueByIndex(ctx, a.RepoID, index)
|
||||
if err != nil {
|
||||
log.Error("GetIssueByIndex: %v", err)
|
||||
return "500 when get issue"
|
||||
|
|
|
@ -47,21 +47,21 @@ type ActivityStats struct {
|
|||
func GetActivityStats(ctx context.Context, repo *repo_model.Repository, timeFrom time.Time, releases, issues, prs, code bool) (*ActivityStats, error) {
|
||||
stats := &ActivityStats{Code: &git.CodeActivityStats{}}
|
||||
if releases {
|
||||
if err := stats.FillReleases(repo.ID, timeFrom); err != nil {
|
||||
if err := stats.FillReleases(ctx, repo.ID, timeFrom); err != nil {
|
||||
return nil, fmt.Errorf("FillReleases: %w", err)
|
||||
}
|
||||
}
|
||||
if prs {
|
||||
if err := stats.FillPullRequests(repo.ID, timeFrom); err != nil {
|
||||
if err := stats.FillPullRequests(ctx, repo.ID, timeFrom); err != nil {
|
||||
return nil, fmt.Errorf("FillPullRequests: %w", err)
|
||||
}
|
||||
}
|
||||
if issues {
|
||||
if err := stats.FillIssues(repo.ID, timeFrom); err != nil {
|
||||
if err := stats.FillIssues(ctx, repo.ID, timeFrom); err != nil {
|
||||
return nil, fmt.Errorf("FillIssues: %w", err)
|
||||
}
|
||||
}
|
||||
if err := stats.FillUnresolvedIssues(repo.ID, timeFrom, issues, prs); err != nil {
|
||||
if err := stats.FillUnresolvedIssues(ctx, repo.ID, timeFrom, issues, prs); err != nil {
|
||||
return nil, fmt.Errorf("FillUnresolvedIssues: %w", err)
|
||||
}
|
||||
if code {
|
||||
|
@ -205,41 +205,41 @@ func (stats *ActivityStats) PublishedReleaseCount() int {
|
|||
}
|
||||
|
||||
// FillPullRequests returns pull request information for activity page
|
||||
func (stats *ActivityStats) FillPullRequests(repoID int64, fromTime time.Time) error {
|
||||
func (stats *ActivityStats) FillPullRequests(ctx context.Context, repoID int64, fromTime time.Time) error {
|
||||
var err error
|
||||
var count int64
|
||||
|
||||
// Merged pull requests
|
||||
sess := pullRequestsForActivityStatement(repoID, fromTime, true)
|
||||
sess := pullRequestsForActivityStatement(ctx, repoID, fromTime, true)
|
||||
sess.OrderBy("pull_request.merged_unix DESC")
|
||||
stats.MergedPRs = make(issues_model.PullRequestList, 0)
|
||||
if err = sess.Find(&stats.MergedPRs); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = stats.MergedPRs.LoadAttributes(); err != nil {
|
||||
if err = stats.MergedPRs.LoadAttributes(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Merged pull request authors
|
||||
sess = pullRequestsForActivityStatement(repoID, fromTime, true)
|
||||
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, true)
|
||||
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.MergedPRAuthorCount = count
|
||||
|
||||
// Opened pull requests
|
||||
sess = pullRequestsForActivityStatement(repoID, fromTime, false)
|
||||
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false)
|
||||
sess.OrderBy("issue.created_unix ASC")
|
||||
stats.OpenedPRs = make(issues_model.PullRequestList, 0)
|
||||
if err = sess.Find(&stats.OpenedPRs); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = stats.OpenedPRs.LoadAttributes(); err != nil {
|
||||
if err = stats.OpenedPRs.LoadAttributes(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Opened pull request authors
|
||||
sess = pullRequestsForActivityStatement(repoID, fromTime, false)
|
||||
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false)
|
||||
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -248,8 +248,8 @@ func (stats *ActivityStats) FillPullRequests(repoID int64, fromTime time.Time) e
|
|||
return nil
|
||||
}
|
||||
|
||||
func pullRequestsForActivityStatement(repoID int64, fromTime time.Time, merged bool) *xorm.Session {
|
||||
sess := db.GetEngine(db.DefaultContext).Where("pull_request.base_repo_id=?", repoID).
|
||||
func pullRequestsForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, merged bool) *xorm.Session {
|
||||
sess := db.GetEngine(ctx).Where("pull_request.base_repo_id=?", repoID).
|
||||
Join("INNER", "issue", "pull_request.issue_id = issue.id")
|
||||
|
||||
if merged {
|
||||
|
@ -264,12 +264,12 @@ func pullRequestsForActivityStatement(repoID int64, fromTime time.Time, merged b
|
|||
}
|
||||
|
||||
// FillIssues returns issue information for activity page
|
||||
func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error {
|
||||
func (stats *ActivityStats) FillIssues(ctx context.Context, repoID int64, fromTime time.Time) error {
|
||||
var err error
|
||||
var count int64
|
||||
|
||||
// Closed issues
|
||||
sess := issuesForActivityStatement(repoID, fromTime, true, false)
|
||||
sess := issuesForActivityStatement(ctx, repoID, fromTime, true, false)
|
||||
sess.OrderBy("issue.closed_unix DESC")
|
||||
stats.ClosedIssues = make(issues_model.IssueList, 0)
|
||||
if err = sess.Find(&stats.ClosedIssues); err != nil {
|
||||
|
@ -277,14 +277,14 @@ func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error {
|
|||
}
|
||||
|
||||
// Closed issue authors
|
||||
sess = issuesForActivityStatement(repoID, fromTime, true, false)
|
||||
sess = issuesForActivityStatement(ctx, repoID, fromTime, true, false)
|
||||
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.ClosedIssueAuthorCount = count
|
||||
|
||||
// New issues
|
||||
sess = issuesForActivityStatement(repoID, fromTime, false, false)
|
||||
sess = issuesForActivityStatement(ctx, repoID, fromTime, false, false)
|
||||
sess.OrderBy("issue.created_unix ASC")
|
||||
stats.OpenedIssues = make(issues_model.IssueList, 0)
|
||||
if err = sess.Find(&stats.OpenedIssues); err != nil {
|
||||
|
@ -292,7 +292,7 @@ func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error {
|
|||
}
|
||||
|
||||
// Opened issue authors
|
||||
sess = issuesForActivityStatement(repoID, fromTime, false, false)
|
||||
sess = issuesForActivityStatement(ctx, repoID, fromTime, false, false)
|
||||
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -302,12 +302,12 @@ func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error {
|
|||
}
|
||||
|
||||
// FillUnresolvedIssues returns unresolved issue and pull request information for activity page
|
||||
func (stats *ActivityStats) FillUnresolvedIssues(repoID int64, fromTime time.Time, issues, prs bool) error {
|
||||
func (stats *ActivityStats) FillUnresolvedIssues(ctx context.Context, repoID int64, fromTime time.Time, issues, prs bool) error {
|
||||
// Check if we need to select anything
|
||||
if !issues && !prs {
|
||||
return nil
|
||||
}
|
||||
sess := issuesForActivityStatement(repoID, fromTime, false, true)
|
||||
sess := issuesForActivityStatement(ctx, repoID, fromTime, false, true)
|
||||
if !issues || !prs {
|
||||
sess.And("issue.is_pull = ?", prs)
|
||||
}
|
||||
|
@ -316,8 +316,8 @@ func (stats *ActivityStats) FillUnresolvedIssues(repoID int64, fromTime time.Tim
|
|||
return sess.Find(&stats.UnresolvedIssues)
|
||||
}
|
||||
|
||||
func issuesForActivityStatement(repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session {
|
||||
sess := db.GetEngine(db.DefaultContext).Where("issue.repo_id = ?", repoID).
|
||||
func issuesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session {
|
||||
sess := db.GetEngine(ctx).Where("issue.repo_id = ?", repoID).
|
||||
And("issue.is_closed = ?", closed)
|
||||
|
||||
if !unresolved {
|
||||
|
@ -336,12 +336,12 @@ func issuesForActivityStatement(repoID int64, fromTime time.Time, closed, unreso
|
|||
}
|
||||
|
||||
// FillReleases returns release information for activity page
|
||||
func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error {
|
||||
func (stats *ActivityStats) FillReleases(ctx context.Context, repoID int64, fromTime time.Time) error {
|
||||
var err error
|
||||
var count int64
|
||||
|
||||
// Published releases list
|
||||
sess := releasesForActivityStatement(repoID, fromTime)
|
||||
sess := releasesForActivityStatement(ctx, repoID, fromTime)
|
||||
sess.OrderBy("release.created_unix DESC")
|
||||
stats.PublishedReleases = make([]*repo_model.Release, 0)
|
||||
if err = sess.Find(&stats.PublishedReleases); err != nil {
|
||||
|
@ -349,7 +349,7 @@ func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error
|
|||
}
|
||||
|
||||
// Published releases authors
|
||||
sess = releasesForActivityStatement(repoID, fromTime)
|
||||
sess = releasesForActivityStatement(ctx, repoID, fromTime)
|
||||
if _, err = sess.Select("count(distinct release.publisher_id) as `count`").Table("release").Get(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -358,8 +358,8 @@ func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func releasesForActivityStatement(repoID int64, fromTime time.Time) *xorm.Session {
|
||||
return db.GetEngine(db.DefaultContext).Where("release.repo_id = ?", repoID).
|
||||
func releasesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time) *xorm.Session {
|
||||
return db.GetEngine(ctx).Where("release.repo_id = ?", repoID).
|
||||
And("release.is_draft = ?", false).
|
||||
And("release.created_unix >= ?", fromTime.Unix())
|
||||
}
|
||||
|
|
|
@ -465,8 +465,9 @@ func (comments CommentList) loadReviews(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// loadAttributes loads all attributes
|
||||
func (comments CommentList) loadAttributes(ctx context.Context) (err error) {
|
||||
// LoadAttributes loads attributes of the comments, except for attachments and
|
||||
// comments
|
||||
func (comments CommentList) LoadAttributes(ctx context.Context) (err error) {
|
||||
if err = comments.LoadPosters(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -501,9 +502,3 @@ func (comments CommentList) loadAttributes(ctx context.Context) (err error) {
|
|||
|
||||
return comments.loadDependentIssues(ctx)
|
||||
}
|
||||
|
||||
// LoadAttributes loads attributes of the comments, except for attachments and
|
||||
// comments
|
||||
func (comments CommentList) LoadAttributes() error {
|
||||
return comments.loadAttributes(db.DefaultContext)
|
||||
}
|
||||
|
|
|
@ -354,7 +354,7 @@ func (issue *Issue) LoadAttributes(ctx context.Context) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
if err = issue.Comments.loadAttributes(ctx); err != nil {
|
||||
if err = issue.Comments.LoadAttributes(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if issue.IsTimetrackerEnabled(ctx) {
|
||||
|
@ -502,7 +502,7 @@ func (issue *Issue) GetLastEventLabelFake() string {
|
|||
}
|
||||
|
||||
// GetIssueByIndex returns raw issue without loading attributes by index in a repository.
|
||||
func GetIssueByIndex(repoID, index int64) (*Issue, error) {
|
||||
func GetIssueByIndex(ctx context.Context, repoID, index int64) (*Issue, error) {
|
||||
if index < 1 {
|
||||
return nil, ErrIssueNotExist{}
|
||||
}
|
||||
|
@ -510,7 +510,7 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) {
|
|||
RepoID: repoID,
|
||||
Index: index,
|
||||
}
|
||||
has, err := db.GetEngine(db.DefaultContext).Get(issue)
|
||||
has, err := db.GetEngine(ctx).Get(issue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
|
@ -520,12 +520,12 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) {
|
|||
}
|
||||
|
||||
// GetIssueWithAttrsByIndex returns issue by index in a repository.
|
||||
func GetIssueWithAttrsByIndex(repoID, index int64) (*Issue, error) {
|
||||
issue, err := GetIssueByIndex(repoID, index)
|
||||
func GetIssueWithAttrsByIndex(ctx context.Context, repoID, index int64) (*Issue, error) {
|
||||
issue, err := GetIssueByIndex(ctx, repoID, index)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return issue, issue.LoadAttributes(db.DefaultContext)
|
||||
return issue, issue.LoadAttributes(ctx)
|
||||
}
|
||||
|
||||
// GetIssueByID returns an issue by given ID.
|
||||
|
@ -846,7 +846,7 @@ func GetPinnedIssues(ctx context.Context, repoID int64, isPull bool) ([]*Issue,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
err = IssueList(issues).LoadAttributes()
|
||||
err = IssueList(issues).LoadAttributes(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -526,7 +526,7 @@ func (issues IssueList) loadTotalTrackedTimes(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
// loadAttributes loads all attributes, expect for attachments and comments
|
||||
func (issues IssueList) loadAttributes(ctx context.Context) error {
|
||||
func (issues IssueList) LoadAttributes(ctx context.Context) error {
|
||||
if _, err := issues.LoadRepositories(ctx); err != nil {
|
||||
return fmt.Errorf("issue.loadAttributes: LoadRepositories: %w", err)
|
||||
}
|
||||
|
@ -562,12 +562,6 @@ func (issues IssueList) loadAttributes(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// LoadAttributes loads attributes of the issues, except for attachments and
|
||||
// comments
|
||||
func (issues IssueList) LoadAttributes() error {
|
||||
return issues.loadAttributes(db.DefaultContext)
|
||||
}
|
||||
|
||||
// LoadComments loads comments
|
||||
func (issues IssueList) LoadComments(ctx context.Context) error {
|
||||
return issues.loadComments(ctx, builder.NewCond())
|
||||
|
|
|
@ -39,7 +39,7 @@ func TestIssueList_LoadAttributes(t *testing.T) {
|
|||
unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 4}),
|
||||
}
|
||||
|
||||
assert.NoError(t, issueList.LoadAttributes())
|
||||
assert.NoError(t, issueList.LoadAttributes(db.DefaultContext))
|
||||
for _, issue := range issueList {
|
||||
assert.EqualValues(t, issue.RepoID, issue.Repo.ID)
|
||||
for _, label := range issue.Labels {
|
||||
|
|
|
@ -440,7 +440,7 @@ func Issues(ctx context.Context, opts *IssuesOptions) ([]*Issue, error) {
|
|||
return nil, fmt.Errorf("unable to query Issues: %w", err)
|
||||
}
|
||||
|
||||
if err := issues.LoadAttributes(); err != nil {
|
||||
if err := issues.LoadAttributes(ctx); err != nil {
|
||||
return nil, fmt.Errorf("unable to LoadAttributes for Issues: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -51,16 +51,16 @@ func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xor
|
|||
}
|
||||
|
||||
// GetUnmergedPullRequestsByHeadInfo returns all pull requests that are open and has not been merged
|
||||
func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequest, error) {
|
||||
func GetUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) {
|
||||
prs := make([]*PullRequest, 0, 2)
|
||||
sess := db.GetEngine(db.DefaultContext).
|
||||
sess := db.GetEngine(ctx).
|
||||
Join("INNER", "issue", "issue.id = pull_request.issue_id").
|
||||
Where("head_repo_id = ? AND head_branch = ? AND has_merged = ? AND issue.is_closed = ? AND flow = ?", repoID, branch, false, false, PullRequestFlowGithub)
|
||||
return prs, sess.Find(&prs)
|
||||
}
|
||||
|
||||
// CanMaintainerWriteToBranch check whether user is a maintainer and could write to the branch
|
||||
func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user *user_model.User) bool {
|
||||
func CanMaintainerWriteToBranch(ctx context.Context, p access_model.Permission, branch string, user *user_model.User) bool {
|
||||
if p.CanWrite(unit.TypeCode) {
|
||||
return true
|
||||
}
|
||||
|
@ -69,18 +69,18 @@ func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user *
|
|||
return false
|
||||
}
|
||||
|
||||
prs, err := GetUnmergedPullRequestsByHeadInfo(p.Units[0].RepoID, branch)
|
||||
prs, err := GetUnmergedPullRequestsByHeadInfo(ctx, p.Units[0].RepoID, branch)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, pr := range prs {
|
||||
if pr.AllowMaintainerEdit {
|
||||
err = pr.LoadBaseRepo(db.DefaultContext)
|
||||
err = pr.LoadBaseRepo(ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
prPerm, err := access_model.GetUserRepoPermission(db.DefaultContext, pr.BaseRepo, user)
|
||||
prPerm, err := access_model.GetUserRepoPermission(ctx, pr.BaseRepo, user)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
@ -104,9 +104,9 @@ func HasUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch
|
|||
|
||||
// GetUnmergedPullRequestsByBaseInfo returns all pull requests that are open and has not been merged
|
||||
// by given base information (repo and branch).
|
||||
func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequest, error) {
|
||||
func GetUnmergedPullRequestsByBaseInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) {
|
||||
prs := make([]*PullRequest, 0, 2)
|
||||
return prs, db.GetEngine(db.DefaultContext).
|
||||
return prs, db.GetEngine(ctx).
|
||||
Where("base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?",
|
||||
repoID, branch, false, false).
|
||||
OrderBy("issue.updated_unix DESC").
|
||||
|
@ -154,7 +154,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest,
|
|||
// PullRequestList defines a list of pull requests
|
||||
type PullRequestList []*PullRequest
|
||||
|
||||
func (prs PullRequestList) loadAttributes(ctx context.Context) error {
|
||||
func (prs PullRequestList) LoadAttributes(ctx context.Context) error {
|
||||
if len(prs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
@ -199,8 +199,3 @@ func (prs PullRequestList) GetIssueIDs() []int64 {
|
|||
}
|
||||
return issueIDs
|
||||
}
|
||||
|
||||
// LoadAttributes load all the prs attributes
|
||||
func (prs PullRequestList) LoadAttributes() error {
|
||||
return prs.loadAttributes(db.DefaultContext)
|
||||
}
|
||||
|
|
|
@ -148,7 +148,7 @@ func TestHasUnmergedPullRequestsByHeadInfo(t *testing.T) {
|
|||
|
||||
func TestGetUnmergedPullRequestsByHeadInfo(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
prs, err := issues_model.GetUnmergedPullRequestsByHeadInfo(1, "branch2")
|
||||
prs, err := issues_model.GetUnmergedPullRequestsByHeadInfo(db.DefaultContext, 1, "branch2")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, prs, 1)
|
||||
for _, pr := range prs {
|
||||
|
@ -159,7 +159,7 @@ func TestGetUnmergedPullRequestsByHeadInfo(t *testing.T) {
|
|||
|
||||
func TestGetUnmergedPullRequestsByBaseInfo(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
prs, err := issues_model.GetUnmergedPullRequestsByBaseInfo(1, "master")
|
||||
prs, err := issues_model.GetUnmergedPullRequestsByBaseInfo(db.DefaultContext, 1, "master")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, prs, 1)
|
||||
pr := prs[0]
|
||||
|
@ -242,13 +242,13 @@ func TestPullRequestList_LoadAttributes(t *testing.T) {
|
|||
unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1}),
|
||||
unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 2}),
|
||||
}
|
||||
assert.NoError(t, issues_model.PullRequestList(prs).LoadAttributes())
|
||||
assert.NoError(t, issues_model.PullRequestList(prs).LoadAttributes(db.DefaultContext))
|
||||
for _, pr := range prs {
|
||||
assert.NotNil(t, pr.Issue)
|
||||
assert.Equal(t, pr.IssueID, pr.Issue.ID)
|
||||
}
|
||||
|
||||
assert.NoError(t, issues_model.PullRequestList([]*issues_model.PullRequest{}).LoadAttributes())
|
||||
assert.NoError(t, issues_model.PullRequestList([]*issues_model.PullRequest{}).LoadAttributes(db.DefaultContext))
|
||||
}
|
||||
|
||||
// TODO TestAddTestPullRequestTask
|
||||
|
|
|
@ -43,11 +43,7 @@ func (t *TrackedTime) AfterLoad() {
|
|||
}
|
||||
|
||||
// LoadAttributes load Issue, User
|
||||
func (t *TrackedTime) LoadAttributes() (err error) {
|
||||
return t.loadAttributes(db.DefaultContext)
|
||||
}
|
||||
|
||||
func (t *TrackedTime) loadAttributes(ctx context.Context) (err error) {
|
||||
func (t *TrackedTime) LoadAttributes(ctx context.Context) (err error) {
|
||||
// Load the issue
|
||||
if t.Issue == nil {
|
||||
t.Issue, err = GetIssueByID(ctx, t.IssueID)
|
||||
|
@ -76,9 +72,9 @@ func (t *TrackedTime) loadAttributes(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
// LoadAttributes load Issue, User
|
||||
func (tl TrackedTimeList) LoadAttributes() error {
|
||||
func (tl TrackedTimeList) LoadAttributes(ctx context.Context) error {
|
||||
for _, t := range tl {
|
||||
if err := t.LoadAttributes(); err != nil {
|
||||
if err := t.LoadAttributes(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -143,8 +139,8 @@ func GetTrackedTimes(ctx context.Context, options *FindTrackedTimesOptions) (tra
|
|||
}
|
||||
|
||||
// CountTrackedTimes returns count of tracked times that fit to the given options.
|
||||
func CountTrackedTimes(opts *FindTrackedTimesOptions) (int64, error) {
|
||||
sess := db.GetEngine(db.DefaultContext).Where(opts.toCond())
|
||||
func CountTrackedTimes(ctx context.Context, opts *FindTrackedTimesOptions) (int64, error) {
|
||||
sess := db.GetEngine(ctx).Where(opts.toCond())
|
||||
if opts.RepositoryID > 0 || opts.MilestoneID > 0 {
|
||||
sess = sess.Join("INNER", "issue", "issue.id = tracked_time.issue_id")
|
||||
}
|
||||
|
@ -157,8 +153,8 @@ func GetTrackedSeconds(ctx context.Context, opts FindTrackedTimesOptions) (track
|
|||
}
|
||||
|
||||
// AddTime will add the given time (in seconds) to the issue
|
||||
func AddTime(user *user_model.User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) {
|
||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
||||
func AddTime(ctx context.Context, user *user_model.User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) {
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -276,7 +272,7 @@ func DeleteTime(t *TrackedTime) error {
|
|||
}
|
||||
defer committer.Close()
|
||||
|
||||
if err := t.loadAttributes(ctx); err != nil {
|
||||
if err := t.LoadAttributes(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ func TestAddTime(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
// 3661 = 1h 1min 1s
|
||||
trackedTime, err := issues_model.AddTime(user3, issue1, 3661, time.Now())
|
||||
trackedTime, err := issues_model.AddTime(db.DefaultContext, user3, issue1, 3661, time.Now())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(3), trackedTime.UserID)
|
||||
assert.Equal(t, int64(1), trackedTime.IssueID)
|
||||
|
|
|
@ -128,8 +128,8 @@ func InsertIssueComments(comments []*issues_model.Comment) error {
|
|||
}
|
||||
|
||||
// InsertPullRequests inserted pull requests
|
||||
func InsertPullRequests(prs ...*issues_model.PullRequest) error {
|
||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
||||
func InsertPullRequests(ctx context.Context, prs ...*issues_model.PullRequest) error {
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -122,7 +122,7 @@ func TestMigrate_InsertPullRequests(t *testing.T) {
|
|||
Issue: i,
|
||||
}
|
||||
|
||||
err := InsertPullRequests(p)
|
||||
err := InsertPullRequests(db.DefaultContext, p)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_ = unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{IssueID: i.ID})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue