diff --git a/models/actions/run_job_list.go b/models/actions/run_job_list.go index 6ea6cb9d3..6c5d3b325 100644 --- a/models/actions/run_job_list.go +++ b/models/actions/run_job_list.go @@ -16,14 +16,9 @@ import ( type ActionJobList []*ActionRunJob func (jobs ActionJobList) GetRunIDs() []int64 { - ids := make(container.Set[int64], len(jobs)) - for _, j := range jobs { - if j.RunID == 0 { - continue - } - ids.Add(j.RunID) - } - return ids.Values() + return container.FilterSlice(jobs, func(j *ActionRunJob) (int64, bool) { + return j.RunID, j.RunID != 0 + }) } func (jobs ActionJobList) LoadRuns(ctx context.Context, withRepo bool) error { diff --git a/models/actions/run_list.go b/models/actions/run_list.go index 388bfc4f8..4046c7d36 100644 --- a/models/actions/run_list.go +++ b/models/actions/run_list.go @@ -19,19 +19,15 @@ type RunList []*ActionRun // GetUserIDs returns a slice of user's id func (runs RunList) GetUserIDs() []int64 { - ids := make(container.Set[int64], len(runs)) - for _, run := range runs { - ids.Add(run.TriggerUserID) - } - return ids.Values() + return container.FilterSlice(runs, func(run *ActionRun) (int64, bool) { + return run.TriggerUserID, true + }) } func (runs RunList) GetRepoIDs() []int64 { - ids := make(container.Set[int64], len(runs)) - for _, run := range runs { - ids.Add(run.RepoID) - } - return ids.Values() + return container.FilterSlice(runs, func(run *ActionRun) (int64, bool) { + return run.RepoID, true + }) } func (runs RunList) LoadTriggerUser(ctx context.Context) error { diff --git a/models/actions/runner_list.go b/models/actions/runner_list.go index 87f0886b4..5ce69e07a 100644 --- a/models/actions/runner_list.go +++ b/models/actions/runner_list.go @@ -16,14 +16,9 @@ type RunnerList []*ActionRunner // GetUserIDs returns a slice of user's id func (runners RunnerList) GetUserIDs() []int64 { - ids := make(container.Set[int64], len(runners)) - for _, runner := range runners { - if runner.OwnerID == 0 { - continue - } - ids.Add(runner.OwnerID) - } - return ids.Values() + return container.FilterSlice(runners, func(runner *ActionRunner) (int64, bool) { + return runner.OwnerID, runner.OwnerID != 0 + }) } func (runners RunnerList) LoadOwners(ctx context.Context) error { diff --git a/models/actions/schedule_list.go b/models/actions/schedule_list.go index ecd2c1121..5361b9480 100644 --- a/models/actions/schedule_list.go +++ b/models/actions/schedule_list.go @@ -18,19 +18,15 @@ type ScheduleList []*ActionSchedule // GetUserIDs returns a slice of user's id func (schedules ScheduleList) GetUserIDs() []int64 { - ids := make(container.Set[int64], len(schedules)) - for _, schedule := range schedules { - ids.Add(schedule.TriggerUserID) - } - return ids.Values() + return container.FilterSlice(schedules, func(schedule *ActionSchedule) (int64, bool) { + return schedule.TriggerUserID, true + }) } func (schedules ScheduleList) GetRepoIDs() []int64 { - ids := make(container.Set[int64], len(schedules)) - for _, schedule := range schedules { - ids.Add(schedule.RepoID) - } - return ids.Values() + return container.FilterSlice(schedules, func(schedule *ActionSchedule) (int64, bool) { + return schedule.RepoID, true + }) } func (schedules ScheduleList) LoadTriggerUser(ctx context.Context) error { diff --git a/models/actions/schedule_spec_list.go b/models/actions/schedule_spec_list.go index e9ae268a6..f7dac72f8 100644 --- a/models/actions/schedule_spec_list.go +++ b/models/actions/schedule_spec_list.go @@ -16,11 +16,9 @@ import ( type SpecList []*ActionScheduleSpec func (specs SpecList) GetScheduleIDs() []int64 { - ids := make(container.Set[int64], len(specs)) - for _, spec := range specs { - ids.Add(spec.ScheduleID) - } - return ids.Values() + return container.FilterSlice(specs, func(spec *ActionScheduleSpec) (int64, bool) { + return spec.ScheduleID, true + }) } func (specs SpecList) LoadSchedules(ctx context.Context) error { @@ -46,11 +44,9 @@ func (specs SpecList) LoadSchedules(ctx context.Context) error { } func (specs SpecList) GetRepoIDs() []int64 { - ids := make(container.Set[int64], len(specs)) - for _, spec := range specs { - ids.Add(spec.RepoID) - } - return ids.Values() + return container.FilterSlice(specs, func(spec *ActionScheduleSpec) (int64, bool) { + return spec.RepoID, true + }) } func (specs SpecList) LoadRepos(ctx context.Context) error { diff --git a/models/actions/task_list.go b/models/actions/task_list.go index b07d00b8d..5e17f9144 100644 --- a/models/actions/task_list.go +++ b/models/actions/task_list.go @@ -16,14 +16,9 @@ import ( type TaskList []*ActionTask func (tasks TaskList) GetJobIDs() []int64 { - ids := make(container.Set[int64], len(tasks)) - for _, t := range tasks { - if t.JobID == 0 { - continue - } - ids.Add(t.JobID) - } - return ids.Values() + return container.FilterSlice(tasks, func(t *ActionTask) (int64, bool) { + return t.JobID, t.JobID != 0 + }) } func (tasks TaskList) LoadJobs(ctx context.Context) error { diff --git a/models/activities/action_list.go b/models/activities/action_list.go index fdf0f35d4..6e23b173b 100644 --- a/models/activities/action_list.go +++ b/models/activities/action_list.go @@ -22,11 +22,9 @@ import ( type ActionList []*Action func (actions ActionList) getUserIDs() []int64 { - userIDs := make(container.Set[int64], len(actions)) - for _, action := range actions { - userIDs.Add(action.ActUserID) - } - return userIDs.Values() + return container.FilterSlice(actions, func(action *Action) (int64, bool) { + return action.ActUserID, true + }) } func (actions ActionList) LoadActUsers(ctx context.Context) (map[int64]*user_model.User, error) { @@ -50,11 +48,9 @@ func (actions ActionList) LoadActUsers(ctx context.Context) (map[int64]*user_mod } func (actions ActionList) getRepoIDs() []int64 { - repoIDs := make(container.Set[int64], len(actions)) - for _, action := range actions { - repoIDs.Add(action.RepoID) - } - return repoIDs.Values() + return container.FilterSlice(actions, func(action *Action) (int64, bool) { + return action.RepoID, true + }) } func (actions ActionList) LoadRepositories(ctx context.Context) error { @@ -80,18 +76,16 @@ func (actions ActionList) loadRepoOwner(ctx context.Context, userMap map[int64]* userMap = make(map[int64]*user_model.User) } - userSet := make(container.Set[int64], len(actions)) - for _, action := range actions { + missingUserIDs := container.FilterSlice(actions, func(action *Action) (int64, bool) { if action.Repo == nil { - continue + return 0, false } - if _, ok := userMap[action.Repo.OwnerID]; !ok { - userSet.Add(action.Repo.OwnerID) - } - } + _, alreadyLoaded := userMap[action.Repo.OwnerID] + return action.Repo.OwnerID, !alreadyLoaded + }) if err := db.GetEngine(ctx). - In("id", userSet.Values()). + In("id", missingUserIDs). Find(&userMap); err != nil { return fmt.Errorf("find user: %w", err) } diff --git a/models/git/branch_list.go b/models/git/branch_list.go index 8319e5ecd..980bd7b4c 100644 --- a/models/git/branch_list.go +++ b/models/git/branch_list.go @@ -17,15 +17,12 @@ import ( type BranchList []*Branch func (branches BranchList) LoadDeletedBy(ctx context.Context) error { - ids := container.Set[int64]{} - for _, branch := range branches { - if !branch.IsDeleted { - continue - } - ids.Add(branch.DeletedByID) - } + ids := container.FilterSlice(branches, func(branch *Branch) (int64, bool) { + return branch.DeletedByID, branch.IsDeleted + }) + usersMap := make(map[int64]*user_model.User, len(ids)) - if err := db.GetEngine(ctx).In("id", ids.Values()).Find(&usersMap); err != nil { + if err := db.GetEngine(ctx).In("id", ids).Find(&usersMap); err != nil { return err } for _, branch := range branches { @@ -41,14 +38,13 @@ func (branches BranchList) LoadDeletedBy(ctx context.Context) error { } func (branches BranchList) LoadPusher(ctx context.Context) error { - ids := container.Set[int64]{} - for _, branch := range branches { - if branch.PusherID > 0 { // pusher_id maybe zero because some branches are sync by backend with no pusher - ids.Add(branch.PusherID) - } - } + ids := container.FilterSlice(branches, func(branch *Branch) (int64, bool) { + // pusher_id maybe zero because some branches are sync by backend with no pusher + return branch.PusherID, branch.PusherID > 0 + }) + usersMap := make(map[int64]*user_model.User, len(ids)) - if err := db.GetEngine(ctx).In("id", ids.Values()).Find(&usersMap); err != nil { + if err := db.GetEngine(ctx).In("id", ids).Find(&usersMap); err != nil { return err } for _, branch := range branches { diff --git a/models/issues/comment.go b/models/issues/comment.go index 1e962b52f..e4b5ed12c 100644 --- a/models/issues/comment.go +++ b/models/issues/comment.go @@ -1289,10 +1289,9 @@ func InsertIssueComments(ctx context.Context, comments []*Comment) error { return nil } - issueIDs := make(container.Set[int64]) - for _, comment := range comments { - issueIDs.Add(comment.IssueID) - } + issueIDs := container.FilterSlice(comments, func(comment *Comment) (int64, bool) { + return comment.IssueID, true + }) ctx, committer, err := db.TxContext(ctx) if err != nil { @@ -1315,7 +1314,7 @@ func InsertIssueComments(ctx context.Context, comments []*Comment) error { } } - for issueID := range issueIDs { + for _, issueID := range issueIDs { if _, err := db.Exec(ctx, "UPDATE issue set num_comments = (SELECT count(*) FROM comment WHERE issue_id = ? AND `type`=?) WHERE id = ?", issueID, CommentTypeComment, issueID); err != nil { return err diff --git a/models/issues/comment_list.go b/models/issues/comment_list.go index 347dbe99e..370b5396e 100644 --- a/models/issues/comment_list.go +++ b/models/issues/comment_list.go @@ -17,13 +17,9 @@ import ( type CommentList []*Comment func (comments CommentList) getPosterIDs() []int64 { - posterIDs := make(container.Set[int64], len(comments)) - for _, comment := range comments { - if comment.PosterID > 0 { - posterIDs.Add(comment.PosterID) - } - } - return posterIDs.Values() + return container.FilterSlice(comments, func(c *Comment) (int64, bool) { + return c.PosterID, c.PosterID > 0 + }) } // LoadPosters loads posters @@ -44,13 +40,9 @@ func (comments CommentList) LoadPosters(ctx context.Context) error { } func (comments CommentList) getLabelIDs() []int64 { - ids := make(container.Set[int64], len(comments)) - for _, comment := range comments { - if comment.LabelID > 0 { - ids.Add(comment.LabelID) - } - } - return ids.Values() + return container.FilterSlice(comments, func(comment *Comment) (int64, bool) { + return comment.LabelID, comment.LabelID > 0 + }) } func (comments CommentList) loadLabels(ctx context.Context) error { @@ -94,13 +86,9 @@ func (comments CommentList) loadLabels(ctx context.Context) error { } func (comments CommentList) getMilestoneIDs() []int64 { - ids := make(container.Set[int64], len(comments)) - for _, comment := range comments { - if comment.MilestoneID > 0 { - ids.Add(comment.MilestoneID) - } - } - return ids.Values() + return container.FilterSlice(comments, func(comment *Comment) (int64, bool) { + return comment.MilestoneID, comment.MilestoneID > 0 + }) } func (comments CommentList) loadMilestones(ctx context.Context) error { @@ -137,13 +125,9 @@ func (comments CommentList) loadMilestones(ctx context.Context) error { } func (comments CommentList) getOldMilestoneIDs() []int64 { - ids := make(container.Set[int64], len(comments)) - for _, comment := range comments { - if comment.OldMilestoneID > 0 { - ids.Add(comment.OldMilestoneID) - } - } - return ids.Values() + return container.FilterSlice(comments, func(comment *Comment) (int64, bool) { + return comment.OldMilestoneID, comment.OldMilestoneID > 0 + }) } func (comments CommentList) loadOldMilestones(ctx context.Context) error { @@ -180,13 +164,9 @@ func (comments CommentList) loadOldMilestones(ctx context.Context) error { } func (comments CommentList) getAssigneeIDs() []int64 { - ids := make(container.Set[int64], len(comments)) - for _, comment := range comments { - if comment.AssigneeID > 0 { - ids.Add(comment.AssigneeID) - } - } - return ids.Values() + return container.FilterSlice(comments, func(comment *Comment) (int64, bool) { + return comment.AssigneeID, comment.AssigneeID > 0 + }) } func (comments CommentList) loadAssignees(ctx context.Context) error { @@ -237,14 +217,9 @@ func (comments CommentList) loadAssignees(ctx context.Context) error { // getIssueIDs returns all the issue ids on this comment list which issue hasn't been loaded func (comments CommentList) getIssueIDs() []int64 { - ids := make(container.Set[int64], len(comments)) - for _, comment := range comments { - if comment.Issue != nil { - continue - } - ids.Add(comment.IssueID) - } - return ids.Values() + return container.FilterSlice(comments, func(comment *Comment) (int64, bool) { + return comment.IssueID, comment.Issue == nil + }) } // Issues returns all the issues of comments @@ -311,16 +286,12 @@ func (comments CommentList) LoadIssues(ctx context.Context) error { } func (comments CommentList) getDependentIssueIDs() []int64 { - ids := make(container.Set[int64], len(comments)) - for _, comment := range comments { + return container.FilterSlice(comments, func(comment *Comment) (int64, bool) { if comment.DependentIssue != nil { - continue + return 0, false } - if comment.DependentIssueID > 0 { - ids.Add(comment.DependentIssueID) - } - } - return ids.Values() + return comment.DependentIssueID, comment.DependentIssueID > 0 + }) } func (comments CommentList) loadDependentIssues(ctx context.Context) error { @@ -375,13 +346,9 @@ func (comments CommentList) loadDependentIssues(ctx context.Context) error { // getAttachmentCommentIDs only return the comment ids which possibly has attachments func (comments CommentList) getAttachmentCommentIDs() []int64 { - ids := make(container.Set[int64], len(comments)) - for _, comment := range comments { - if comment.Type.HasAttachmentSupport() { - ids.Add(comment.ID) - } - } - return ids.Values() + return container.FilterSlice(comments, func(comment *Comment) (int64, bool) { + return comment.ID, comment.Type.HasAttachmentSupport() + }) } // LoadAttachmentsByIssue loads attachments by issue id @@ -449,13 +416,9 @@ func (comments CommentList) LoadAttachments(ctx context.Context) (err error) { } func (comments CommentList) getReviewIDs() []int64 { - ids := make(container.Set[int64], len(comments)) - for _, comment := range comments { - if comment.ReviewID > 0 { - ids.Add(comment.ReviewID) - } - } - return ids.Values() + return container.FilterSlice(comments, func(comment *Comment) (int64, bool) { + return comment.ReviewID, comment.ReviewID > 0 + }) } func (comments CommentList) loadReviews(ctx context.Context) error { diff --git a/models/issues/issue_list.go b/models/issues/issue_list.go index da55ff1b0..2235f3d3a 100644 --- a/models/issues/issue_list.go +++ b/models/issues/issue_list.go @@ -74,11 +74,9 @@ func (issues IssueList) LoadRepositories(ctx context.Context) (repo_model.Reposi } func (issues IssueList) getPosterIDs() []int64 { - posterIDs := make(container.Set[int64], len(issues)) - for _, issue := range issues { - posterIDs.Add(issue.PosterID) - } - return posterIDs.Values() + return container.FilterSlice(issues, func(issue *Issue) (int64, bool) { + return issue.PosterID, true + }) } func (issues IssueList) loadPosters(ctx context.Context) error { @@ -193,11 +191,9 @@ func (issues IssueList) loadLabels(ctx context.Context) error { } func (issues IssueList) getMilestoneIDs() []int64 { - ids := make(container.Set[int64], len(issues)) - for _, issue := range issues { - ids.Add(issue.MilestoneID) - } - return ids.Values() + return container.FilterSlice(issues, func(issue *Issue) (int64, bool) { + return issue.MilestoneID, true + }) } func (issues IssueList) loadMilestones(ctx context.Context) error { diff --git a/models/issues/reaction.go b/models/issues/reaction.go index d5448636f..eb7faefc7 100644 --- a/models/issues/reaction.go +++ b/models/issues/reaction.go @@ -305,14 +305,12 @@ func (list ReactionList) GroupByType() map[string]ReactionList { } func (list ReactionList) getUserIDs() []int64 { - userIDs := make(container.Set[int64], len(list)) - for _, reaction := range list { + return container.FilterSlice(list, func(reaction *Reaction) (int64, bool) { if reaction.OriginalAuthor != "" { - continue + return 0, false } - userIDs.Add(reaction.UserID) - } - return userIDs.Values() + return reaction.UserID, true + }) } func valuesUser(m map[int64]*user_model.User) []*user_model.User { diff --git a/models/issues/review_list.go b/models/issues/review_list.go index ec6cb0798..7b8c3d319 100644 --- a/models/issues/review_list.go +++ b/models/issues/review_list.go @@ -38,12 +38,11 @@ func (reviews ReviewList) LoadReviewers(ctx context.Context) error { } func (reviews ReviewList) LoadIssues(ctx context.Context) error { - issueIDs := container.Set[int64]{} - for i := 0; i < len(reviews); i++ { - issueIDs.Add(reviews[i].IssueID) - } + issueIDs := container.FilterSlice(reviews, func(review *Review) (int64, bool) { + return review.IssueID, true + }) - issues, err := GetIssuesByIDs(ctx, issueIDs.Values()) + issues, err := GetIssuesByIDs(ctx, issueIDs) if err != nil { return err } diff --git a/models/repo/repo_list.go b/models/repo/repo_list.go index cb7cd47a8..987c7df9b 100644 --- a/models/repo/repo_list.go +++ b/models/repo/repo_list.go @@ -104,18 +104,19 @@ func (repos RepositoryList) LoadAttributes(ctx context.Context) error { return nil } - set := make(container.Set[int64]) + userIDs := container.FilterSlice(repos, func(repo *Repository) (int64, bool) { + return repo.OwnerID, true + }) repoIDs := make([]int64, len(repos)) for i := range repos { - set.Add(repos[i].OwnerID) repoIDs[i] = repos[i].ID } // Load owners. - users := make(map[int64]*user_model.User, len(set)) + users := make(map[int64]*user_model.User, len(userIDs)) if err := db.GetEngine(ctx). Where("id > 0"). - In("id", set.Values()). + In("id", userIDs). Find(&users); err != nil { return fmt.Errorf("find users: %w", err) } diff --git a/modules/container/filter.go b/modules/container/filter.go new file mode 100644 index 000000000..37ec7c3d5 --- /dev/null +++ b/modules/container/filter.go @@ -0,0 +1,21 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package container + +import "slices" + +// FilterSlice ranges over the slice and calls include() for each element. +// If the second returned value is true, the first returned value will be included in the resulting +// slice (after deduplication). +func FilterSlice[E any, T comparable](s []E, include func(E) (T, bool)) []T { + filtered := make([]T, 0, len(s)) // slice will be clipped before returning + seen := make(map[T]bool, len(s)) + for i := range s { + if v, ok := include(s[i]); ok && !seen[v] { + filtered = append(filtered, v) + seen[v] = true + } + } + return slices.Clip(filtered) +} diff --git a/modules/container/filter_test.go b/modules/container/filter_test.go new file mode 100644 index 000000000..ad304e5ab --- /dev/null +++ b/modules/container/filter_test.go @@ -0,0 +1,28 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package container + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFilterMapUnique(t *testing.T) { + result := FilterSlice([]int{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + }, func(i int) (int, bool) { + switch i { + case 0: + return 0, true // included later + case 1: + return 0, true // duplicate of previous (should be ignored) + case 2: + return 2, false // not included + default: + return i, true + } + }) + assert.Equal(t, []int{0, 3, 4, 5, 6, 7, 8, 9}, result) +}