diff --git a/internal/database/comment.go b/internal/database/comment.go index f0155c849..a28246fcc 100644 --- a/internal/database/comment.go +++ b/internal/database/comment.go @@ -7,11 +7,10 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/unknwon/com" - log "unknwon.dev/clog/v2" - "xorm.io/xorm" - api "github.com/gogs/go-gogs-client" + "github.com/unknwon/com" + "gorm.io/gorm" + log "unknwon.dev/clog/v2" "gogs.io/gogs/internal/errutil" "gogs.io/gogs/internal/markup" @@ -50,47 +49,50 @@ type Comment struct { ID int64 Type CommentType PosterID int64 - Poster *User `xorm:"-" json:"-" gorm:"-"` - IssueID int64 `xorm:"INDEX"` - Issue *Issue `xorm:"-" json:"-" gorm:"-"` + Poster *User `gorm:"-" json:"-"` + IssueID int64 `gorm:"index"` + Issue *Issue `gorm:"-" json:"-"` CommitID int64 Line int64 - Content string `xorm:"TEXT"` - RenderedContent string `xorm:"-" json:"-" gorm:"-"` + Content string `gorm:"type:text"` + RenderedContent string `gorm:"-" json:"-"` - Created time.Time `xorm:"-" json:"-" gorm:"-"` + Created time.Time `gorm:"-" json:"-"` CreatedUnix int64 - Updated time.Time `xorm:"-" json:"-" gorm:"-"` + Updated time.Time `gorm:"-" json:"-"` UpdatedUnix int64 // Reference issue in commit message - CommitSHA string `xorm:"VARCHAR(40)"` + CommitSHA string `gorm:"type:varchar(40)"` - Attachments []*Attachment `xorm:"-" json:"-" gorm:"-"` + Attachments []*Attachment `gorm:"-" json:"-"` // For view issue page. - ShowTag CommentTag `xorm:"-" json:"-" gorm:"-"` + ShowTag CommentTag `gorm:"-" json:"-"` } -func (c *Comment) BeforeInsert() { - c.CreatedUnix = time.Now().Unix() - c.UpdatedUnix = c.CreatedUnix -} - -func (c *Comment) BeforeUpdate() { - c.UpdatedUnix = time.Now().Unix() -} - -func (c *Comment) AfterSet(colName string, _ xorm.Cell) { - switch colName { - case "created_unix": - c.Created = time.Unix(c.CreatedUnix, 0).Local() - case "updated_unix": - c.Updated = time.Unix(c.UpdatedUnix, 0).Local() +func (c *Comment) BeforeCreate(tx *gorm.DB) error { + if c.CreatedUnix == 0 { + c.CreatedUnix = tx.NowFunc().Unix() } + if c.UpdatedUnix == 0 { + c.UpdatedUnix = c.CreatedUnix + } + return nil } -func (c *Comment) loadAttributes(e Engine) (err error) { +func (c *Comment) BeforeUpdate(tx *gorm.DB) error { + c.UpdatedUnix = tx.NowFunc().Unix() + return nil +} + +func (c *Comment) AfterFind(tx *gorm.DB) error { + c.Created = time.Unix(c.CreatedUnix, 0).Local() + c.Updated = time.Unix(c.UpdatedUnix, 0).Local() + return nil +} + +func (c *Comment) loadAttributes(tx *gorm.DB) (err error) { if c.Poster == nil { c.Poster, err = Handle.Users().GetByID(context.TODO(), c.PosterID) if err != nil { @@ -104,12 +106,12 @@ func (c *Comment) loadAttributes(e Engine) (err error) { } if c.Issue == nil { - c.Issue, err = getRawIssueByID(e, c.IssueID) + c.Issue, err = getRawIssueByID(tx, c.IssueID) if err != nil { return errors.Newf("getIssueByID [%d]: %v", c.IssueID, err) } if c.Issue.Repo == nil { - c.Issue.Repo, err = getRepositoryByID(e, c.Issue.RepoID) + c.Issue.Repo, err = getRepositoryByID(tx, c.Issue.RepoID) if err != nil { return errors.Newf("getRepositoryByID [%d]: %v", c.Issue.RepoID, err) } @@ -117,7 +119,7 @@ func (c *Comment) loadAttributes(e Engine) (err error) { } if c.Attachments == nil { - c.Attachments, err = getAttachmentsByCommentID(e, c.ID) + c.Attachments, err = getAttachmentsByCommentID(tx, c.ID) if err != nil { return errors.Newf("getAttachmentsByCommentID [%d]: %v", c.ID, err) } @@ -127,7 +129,7 @@ func (c *Comment) loadAttributes(e Engine) (err error) { } func (c *Comment) LoadAttributes() error { - return c.loadAttributes(x) + return c.loadAttributes(db) } func (c *Comment) HTMLURL() string { @@ -163,9 +165,9 @@ func (c *Comment) EventTag() string { // mailParticipants sends new comment emails to repository watchers // and mentioned people. -func (c *Comment) mailParticipants(e Engine, opType ActionType, issue *Issue) (err error) { +func (c *Comment) mailParticipants(tx *gorm.DB, opType ActionType, issue *Issue) (err error) { mentions := markup.FindAllMentions(c.Content) - if err = updateIssueMentions(e, c.IssueID, mentions); err != nil { + if err = updateIssueMentions(tx, c.IssueID, mentions); err != nil { return errors.Newf("UpdateIssueMentions [%d]: %v", c.IssueID, err) } @@ -184,7 +186,7 @@ func (c *Comment) mailParticipants(e Engine, opType ActionType, issue *Issue) (e return nil } -func createComment(e *xorm.Session, opts *CreateCommentOptions) (_ *Comment, err error) { +func createComment(tx *gorm.DB, opts *CreateCommentOptions) (_ *Comment, err error) { comment := &Comment{ Type: opts.Type, PosterID: opts.Doer.ID, @@ -195,7 +197,7 @@ func createComment(e *xorm.Session, opts *CreateCommentOptions) (_ *Comment, err Line: opts.LineNum, Content: opts.Content, } - if _, err = e.Insert(comment); err != nil { + if err = tx.Create(comment).Error; err != nil { return nil, err } @@ -216,14 +218,14 @@ func createComment(e *xorm.Session, opts *CreateCommentOptions) (_ *Comment, err case CommentTypeComment: act.OpType = ActionCommentIssue - if _, err = e.Exec("UPDATE `issue` SET num_comments=num_comments+1 WHERE id=?", opts.Issue.ID); err != nil { + if err = tx.Exec("UPDATE `issue` SET num_comments=num_comments+1 WHERE id=?", opts.Issue.ID).Error; err != nil { return nil, err } // Check attachments attachments := make([]*Attachment, 0, len(opts.Attachments)) for _, uuid := range opts.Attachments { - attach, err := getAttachmentByUUID(e, uuid) + attach, err := getAttachmentByUUID(tx, uuid) if err != nil { if IsErrAttachmentNotExist(err) { continue @@ -236,8 +238,10 @@ func createComment(e *xorm.Session, opts *CreateCommentOptions) (_ *Comment, err for i := range attachments { attachments[i].IssueID = opts.Issue.ID attachments[i].CommentID = comment.ID - // No assign value could be 0, so ignore AllCols(). - if _, err = e.ID(attachments[i].ID).Update(attachments[i]); err != nil { + if err = tx.Model(attachments[i]).Where("id = ?", attachments[i].ID).Updates(map[string]any{ + "issue_id": attachments[i].IssueID, + "comment_id": attachments[i].CommentID, + }).Error; err != nil { return nil, errors.Newf("update attachment [%d]: %v", attachments[i].ID, err) } } @@ -249,9 +253,9 @@ func createComment(e *xorm.Session, opts *CreateCommentOptions) (_ *Comment, err } if opts.Issue.IsPull { - _, err = e.Exec("UPDATE `repository` SET num_closed_pulls=num_closed_pulls-1 WHERE id=?", opts.Repo.ID) + err = tx.Exec("UPDATE `repository` SET num_closed_pulls=num_closed_pulls-1 WHERE id=?", opts.Repo.ID).Error } else { - _, err = e.Exec("UPDATE `repository` SET num_closed_issues=num_closed_issues-1 WHERE id=?", opts.Repo.ID) + err = tx.Exec("UPDATE `repository` SET num_closed_issues=num_closed_issues-1 WHERE id=?", opts.Repo.ID).Error } if err != nil { return nil, err @@ -264,38 +268,38 @@ func createComment(e *xorm.Session, opts *CreateCommentOptions) (_ *Comment, err } if opts.Issue.IsPull { - _, err = e.Exec("UPDATE `repository` SET num_closed_pulls=num_closed_pulls+1 WHERE id=?", opts.Repo.ID) + err = tx.Exec("UPDATE `repository` SET num_closed_pulls=num_closed_pulls+1 WHERE id=?", opts.Repo.ID).Error } else { - _, err = e.Exec("UPDATE `repository` SET num_closed_issues=num_closed_issues+1 WHERE id=?", opts.Repo.ID) + err = tx.Exec("UPDATE `repository` SET num_closed_issues=num_closed_issues+1 WHERE id=?", opts.Repo.ID).Error } if err != nil { return nil, err } } - if _, err = e.Exec("UPDATE `issue` SET updated_unix = ? WHERE id = ?", time.Now().Unix(), opts.Issue.ID); err != nil { + if err = tx.Exec("UPDATE `issue` SET updated_unix = ? WHERE id = ?", tx.NowFunc().Unix(), opts.Issue.ID).Error; err != nil { return nil, errors.Newf("update issue 'updated_unix': %v", err) } // Notify watchers for whatever action comes in, ignore if no action type. if act.OpType > 0 { - if err = notifyWatchers(e, act); err != nil { + if err = notifyWatchers(tx, act); err != nil { log.Error("notifyWatchers: %v", err) } - if err = comment.mailParticipants(e, act.OpType, opts.Issue); err != nil { + if err = comment.mailParticipants(tx, act.OpType, opts.Issue); err != nil { log.Error("MailParticipants: %v", err) } } - return comment, comment.loadAttributes(e) + return comment, comment.loadAttributes(tx) } -func createStatusComment(e *xorm.Session, doer *User, repo *Repository, issue *Issue) (*Comment, error) { +func createStatusComment(tx *gorm.DB, doer *User, repo *Repository, issue *Issue) (*Comment, error) { cmtType := CommentTypeClose if !issue.IsClosed { cmtType = CommentTypeReopen } - return createComment(e, &CreateCommentOptions{ + return createComment(tx, &CreateCommentOptions{ Type: cmtType, Doer: doer, Repo: repo, @@ -318,18 +322,12 @@ type CreateCommentOptions struct { // CreateComment creates comment of issue or commit. func CreateComment(opts *CreateCommentOptions) (comment *Comment, err error) { - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return nil, err - } - - comment, err = createComment(sess, opts) - if err != nil { - return nil, err - } - - return comment, sess.Commit() + err = db.Transaction(func(tx *gorm.DB) error { + var err error + comment, err = createComment(tx, opts) + return err + }) + return comment, err } // CreateIssueComment creates a plain issue comment. @@ -367,14 +365,12 @@ func CreateRefComment(doer *User, repo *Repository, issue *Issue, content, commi } // Check if same reference from same commit has already existed. - has, err := x.Get(&Comment{ - Type: CommentTypeCommitRef, - IssueID: issue.ID, - CommitSHA: commitSHA, - }) + var count int64 + err := db.Model(new(Comment)).Where("type = ? AND issue_id = ? AND commit_sha = ?", + CommentTypeCommitRef, issue.ID, commitSHA).Count(&count).Error if err != nil { return errors.Newf("check reference comment: %v", err) - } else if has { + } else if count > 0 { return nil } @@ -411,19 +407,20 @@ func (ErrCommentNotExist) NotFound() bool { // GetCommentByID returns the comment by given ID. func GetCommentByID(id int64) (*Comment, error) { c := new(Comment) - has, err := x.Id(id).Get(c) + err := db.Where("id = ?", id).First(c).Error if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrCommentNotExist{args: map[string]any{"commentID": id}} + } return nil, err - } else if !has { - return nil, ErrCommentNotExist{args: map[string]any{"commentID": id}} } return c, c.LoadAttributes() } // FIXME: use CommentList to improve performance. -func loadCommentsAttributes(e Engine, comments []*Comment) (err error) { +func loadCommentsAttributes(tx *gorm.DB, comments []*Comment) (err error) { for i := range comments { - if err = comments[i].loadAttributes(e); err != nil { + if err = comments[i].loadAttributes(tx); err != nil { return errors.Newf("loadAttributes [%d]: %v", comments[i].ID, err) } } @@ -431,53 +428,55 @@ func loadCommentsAttributes(e Engine, comments []*Comment) (err error) { return nil } -func getCommentsByIssueIDSince(e Engine, issueID, since int64) ([]*Comment, error) { +func getCommentsByIssueIDSince(tx *gorm.DB, issueID, since int64) ([]*Comment, error) { comments := make([]*Comment, 0, 10) - sess := e.Where("issue_id = ?", issueID).Asc("created_unix") + query := tx.Where("issue_id = ?", issueID).Order("created_unix ASC") if since > 0 { - sess.And("updated_unix >= ?", since) + query = query.Where("updated_unix >= ?", since) } - if err := sess.Find(&comments); err != nil { + if err := query.Find(&comments).Error; err != nil { return nil, err } - return comments, loadCommentsAttributes(e, comments) + return comments, loadCommentsAttributes(tx, comments) } -func getCommentsByRepoIDSince(e Engine, repoID, since int64) ([]*Comment, error) { +func getCommentsByRepoIDSince(tx *gorm.DB, repoID, since int64) ([]*Comment, error) { comments := make([]*Comment, 0, 10) - sess := e.Where("issue.repo_id = ?", repoID).Join("INNER", "issue", "issue.id = comment.issue_id").Asc("comment.created_unix") + query := tx.Joins("INNER JOIN issue ON issue.id = comment.issue_id"). + Where("issue.repo_id = ?", repoID). + Order("comment.created_unix ASC") if since > 0 { - sess.And("comment.updated_unix >= ?", since) + query = query.Where("comment.updated_unix >= ?", since) } - if err := sess.Find(&comments); err != nil { + if err := query.Find(&comments).Error; err != nil { return nil, err } - return comments, loadCommentsAttributes(e, comments) + return comments, loadCommentsAttributes(tx, comments) } -func getCommentsByIssueID(e Engine, issueID int64) ([]*Comment, error) { - return getCommentsByIssueIDSince(e, issueID, -1) +func getCommentsByIssueID(tx *gorm.DB, issueID int64) ([]*Comment, error) { + return getCommentsByIssueIDSince(tx, issueID, -1) } // GetCommentsByIssueID returns all comments of an issue. func GetCommentsByIssueID(issueID int64) ([]*Comment, error) { - return getCommentsByIssueID(x, issueID) + return getCommentsByIssueID(db, issueID) } // GetCommentsByIssueIDSince returns a list of comments of an issue since a given time point. func GetCommentsByIssueIDSince(issueID, since int64) ([]*Comment, error) { - return getCommentsByIssueIDSince(x, issueID, since) + return getCommentsByIssueIDSince(db, issueID, since) } // GetCommentsByRepoIDSince returns a list of comments for all issues in a repo since a given time point. func GetCommentsByRepoIDSince(repoID, since int64) ([]*Comment, error) { - return getCommentsByRepoIDSince(x, repoID, since) + return getCommentsByRepoIDSince(db, repoID, since) } // UpdateComment updates information of comment. func UpdateComment(doer *User, c *Comment, oldContent string) (err error) { - if _, err = x.Id(c.ID).AllCols().Update(c); err != nil { + if err = db.Model(c).Where("id = ?", c.ID).Updates(c).Error; err != nil { return err } @@ -511,24 +510,21 @@ func DeleteCommentByID(doer *User, id int64) error { return err } - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if _, err = sess.ID(comment.ID).Delete(new(Comment)); err != nil { - return err - } - - if comment.Type == CommentTypeComment { - if _, err = sess.Exec("UPDATE `issue` SET num_comments = num_comments - 1 WHERE id = ?", comment.IssueID); err != nil { + err = db.Transaction(func(tx *gorm.DB) error { + if err := tx.Where("id = ?", comment.ID).Delete(new(Comment)).Error; err != nil { return err } - } - if err = sess.Commit(); err != nil { - return errors.Newf("commit: %v", err) + if comment.Type == CommentTypeComment { + if err := tx.Exec("UPDATE `issue` SET num_comments = num_comments - 1 WHERE id = ?", comment.IssueID).Error; err != nil { + return err + } + } + + return nil + }) + if err != nil { + return errors.Newf("transaction: %v", err) } _, err = DeleteAttachmentsByComment(comment.ID, true) diff --git a/internal/database/issue_label.go b/internal/database/issue_label.go index 58d2c71bd..187bfafef 100644 --- a/internal/database/issue_label.go +++ b/internal/database/issue_label.go @@ -6,10 +6,9 @@ import ( "strconv" "strings" - "xorm.io/xorm" - "github.com/cockroachdb/errors" api "github.com/gogs/go-gogs-client" + "gorm.io/gorm" "gogs.io/gogs/internal/errutil" "gogs.io/gogs/internal/lazyregexp" @@ -53,13 +52,13 @@ func GetLabelTemplateFile(name string) ([][2]string, error) { // Label represents a label of repository for issues. type Label struct { ID int64 - RepoID int64 `xorm:"INDEX"` + RepoID int64 `gorm:"index"` Name string - Color string `xorm:"VARCHAR(7)"` + Color string `gorm:"type:varchar(7)"` NumIssues int NumClosedIssues int - NumOpenIssues int `xorm:"-" json:"-" gorm:"-"` - IsChecked bool `xorm:"-" json:"-" gorm:"-"` + NumOpenIssues int `gorm:"-" json:"-"` + IsChecked bool `gorm:"-" json:"-"` } func (l *Label) APIFormat() *api.Label { @@ -97,8 +96,7 @@ func (l *Label) ForegroundColor() template.CSS { // NewLabels creates new label(s) for a repository. func NewLabels(labels ...*Label) error { - _, err := x.Insert(labels) - return err + return db.Create(labels).Error } var _ errutil.NotFound = (*ErrLabelNotExist)(nil) @@ -123,20 +121,22 @@ func (ErrLabelNotExist) NotFound() bool { // getLabelOfRepoByName returns a label by Name in given repository. // If pass repoID as 0, then ORM will ignore limitation of repository // and can return arbitrary label with any valid ID. -func getLabelOfRepoByName(e Engine, repoID int64, labelName string) (*Label, error) { +func getLabelOfRepoByName(tx *gorm.DB, repoID int64, labelName string) (*Label, error) { if len(labelName) <= 0 { return nil, ErrLabelNotExist{args: map[string]any{"repoID": repoID}} } - l := &Label{ - Name: labelName, - RepoID: repoID, + l := &Label{} + query := tx.Where("name = ?", labelName) + if repoID > 0 { + query = query.Where("repo_id = ?", repoID) } - has, err := e.Get(l) + err := query.First(l).Error if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrLabelNotExist{args: map[string]any{"repoID": repoID}} + } return nil, err - } else if !has { - return nil, ErrLabelNotExist{args: map[string]any{"repoID": repoID}} } return l, nil } @@ -144,54 +144,56 @@ func getLabelOfRepoByName(e Engine, repoID int64, labelName string) (*Label, err // getLabelInRepoByID returns a label by ID in given repository. // If pass repoID as 0, then ORM will ignore limitation of repository // and can return arbitrary label with any valid ID. -func getLabelOfRepoByID(e Engine, repoID, labelID int64) (*Label, error) { +func getLabelOfRepoByID(tx *gorm.DB, repoID, labelID int64) (*Label, error) { if labelID <= 0 { return nil, ErrLabelNotExist{args: map[string]any{"repoID": repoID, "labelID": labelID}} } - l := &Label{ - ID: labelID, - RepoID: repoID, + l := &Label{} + query := tx.Where("id = ?", labelID) + if repoID > 0 { + query = query.Where("repo_id = ?", repoID) } - has, err := e.Get(l) + err := query.First(l).Error if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrLabelNotExist{args: map[string]any{"repoID": repoID, "labelID": labelID}} + } return nil, err - } else if !has { - return nil, ErrLabelNotExist{args: map[string]any{"repoID": repoID, "labelID": labelID}} } return l, nil } // GetLabelByID returns a label by given ID. func GetLabelByID(id int64) (*Label, error) { - return getLabelOfRepoByID(x, 0, id) + return getLabelOfRepoByID(db, 0, id) } // GetLabelOfRepoByID returns a label by ID in given repository. func GetLabelOfRepoByID(repoID, labelID int64) (*Label, error) { - return getLabelOfRepoByID(x, repoID, labelID) + return getLabelOfRepoByID(db, repoID, labelID) } // GetLabelOfRepoByName returns a label by name in given repository. func GetLabelOfRepoByName(repoID int64, labelName string) (*Label, error) { - return getLabelOfRepoByName(x, repoID, labelName) + return getLabelOfRepoByName(db, repoID, labelName) } // GetLabelsInRepoByIDs returns a list of labels by IDs in given repository, // it silently ignores label IDs that are not belong to the repository. func GetLabelsInRepoByIDs(repoID int64, labelIDs []int64) ([]*Label, error) { labels := make([]*Label, 0, len(labelIDs)) - return labels, x.Where("repo_id = ?", repoID).In("id", tool.Int64sToStrings(labelIDs)).Asc("name").Find(&labels) + return labels, db.Where("repo_id = ? AND id IN ?", repoID, labelIDs).Order("name ASC").Find(&labels).Error } // GetLabelsByRepoID returns all labels that belong to given repository by ID. func GetLabelsByRepoID(repoID int64) ([]*Label, error) { labels := make([]*Label, 0, 10) - return labels, x.Where("repo_id = ?", repoID).Asc("name").Find(&labels) + return labels, db.Where("repo_id = ?", repoID).Order("name ASC").Find(&labels).Error } -func getLabelsByIssueID(e Engine, issueID int64) ([]*Label, error) { - issueLabels, err := getIssueLabels(e, issueID) +func getLabelsByIssueID(tx *gorm.DB, issueID int64) ([]*Label, error) { + issueLabels, err := getIssueLabels(tx, issueID) if err != nil { return nil, errors.Newf("getIssueLabels: %v", err) } else if len(issueLabels) == 0 { @@ -204,22 +206,21 @@ func getLabelsByIssueID(e Engine, issueID int64) ([]*Label, error) { } labels := make([]*Label, 0, len(labelIDs)) - return labels, e.Where("id > 0").In("id", tool.Int64sToStrings(labelIDs)).Asc("name").Find(&labels) + return labels, tx.Where("id > 0 AND id IN ?", labelIDs).Order("name ASC").Find(&labels).Error } // GetLabelsByIssueID returns all labels that belong to given issue by ID. func GetLabelsByIssueID(issueID int64) ([]*Label, error) { - return getLabelsByIssueID(x, issueID) + return getLabelsByIssueID(db, issueID) } -func updateLabel(e Engine, l *Label) error { - _, err := e.ID(l.ID).AllCols().Update(l) - return err +func updateLabel(tx *gorm.DB, l *Label) error { + return tx.Model(l).Where("id = ?", l.ID).Updates(l).Error } // UpdateLabel updates label information. func UpdateLabel(l *Label) error { - return updateLabel(x, l) + return updateLabel(db, l) } // DeleteLabel delete a label of given repository. @@ -232,19 +233,15 @@ func DeleteLabel(repoID, labelID int64) error { return err } - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if _, err = sess.ID(labelID).Delete(new(Label)); err != nil { - return err - } else if _, err = sess.Where("label_id = ?", labelID).Delete(new(IssueLabel)); err != nil { - return err - } - - return sess.Commit() + return db.Transaction(func(tx *gorm.DB) error { + if err := tx.Where("id = ?", labelID).Delete(new(Label)).Error; err != nil { + return err + } + if err := tx.Where("label_id = ?", labelID).Delete(new(IssueLabel)).Error; err != nil { + return err + } + return nil + }) } // .___ .____ ___. .__ @@ -257,25 +254,26 @@ func DeleteLabel(repoID, labelID int64) error { // IssueLabel represents an issue-lable relation. type IssueLabel struct { ID int64 - IssueID int64 `xorm:"UNIQUE(s)"` - LabelID int64 `xorm:"UNIQUE(s)"` + IssueID int64 `gorm:"uniqueIndex:issue_label_unique"` + LabelID int64 `gorm:"uniqueIndex:issue_label_unique"` } -func hasIssueLabel(e Engine, issueID, labelID int64) bool { - has, _ := e.Where("issue_id = ? AND label_id = ?", issueID, labelID).Get(new(IssueLabel)) - return has +func hasIssueLabel(tx *gorm.DB, issueID, labelID int64) bool { + var count int64 + tx.Model(new(IssueLabel)).Where("issue_id = ? AND label_id = ?", issueID, labelID).Count(&count) + return count > 0 } // HasIssueLabel returns true if issue has been labeled. func HasIssueLabel(issueID, labelID int64) bool { - return hasIssueLabel(x, issueID, labelID) + return hasIssueLabel(db, issueID, labelID) } -func newIssueLabel(e *xorm.Session, issue *Issue, label *Label) (err error) { - if _, err = e.Insert(&IssueLabel{ +func newIssueLabel(tx *gorm.DB, issue *Issue, label *Label) (err error) { + if err = tx.Create(&IssueLabel{ IssueID: issue.ID, LabelID: label.ID, - }); err != nil { + }).Error; err != nil { return err } @@ -284,7 +282,7 @@ func newIssueLabel(e *xorm.Session, issue *Issue, label *Label) (err error) { label.NumClosedIssues++ } - if err = updateLabel(e, label); err != nil { + if err = updateLabel(tx, label); err != nil { return errors.Newf("updateLabel: %v", err) } @@ -298,26 +296,18 @@ func NewIssueLabel(issue *Issue, label *Label) (err error) { return nil } - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if err = newIssueLabel(sess, issue, label); err != nil { - return err - } - - return sess.Commit() + return db.Transaction(func(tx *gorm.DB) error { + return newIssueLabel(tx, issue, label) + }) } -func newIssueLabels(e *xorm.Session, issue *Issue, labels []*Label) (err error) { +func newIssueLabels(tx *gorm.DB, issue *Issue, labels []*Label) (err error) { for i := range labels { - if hasIssueLabel(e, issue.ID, labels[i].ID) { + if hasIssueLabel(tx, issue.ID, labels[i].ID) { continue } - if err = newIssueLabel(e, issue, labels[i]); err != nil { + if err = newIssueLabel(tx, issue, labels[i]); err != nil { return errors.Newf("newIssueLabel: %v", err) } } @@ -327,34 +317,23 @@ func newIssueLabels(e *xorm.Session, issue *Issue, labels []*Label) (err error) // NewIssueLabels creates a list of issue-label relations. func NewIssueLabels(issue *Issue, labels []*Label) (err error) { - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if err = newIssueLabels(sess, issue, labels); err != nil { - return err - } - - return sess.Commit() + return db.Transaction(func(tx *gorm.DB) error { + return newIssueLabels(tx, issue, labels) + }) } -func getIssueLabels(e Engine, issueID int64) ([]*IssueLabel, error) { +func getIssueLabels(tx *gorm.DB, issueID int64) ([]*IssueLabel, error) { issueLabels := make([]*IssueLabel, 0, 10) - return issueLabels, e.Where("issue_id=?", issueID).Asc("label_id").Find(&issueLabels) + return issueLabels, tx.Where("issue_id = ?", issueID).Order("label_id ASC").Find(&issueLabels).Error } // GetIssueLabels returns all issue-label relations of given issue by ID. func GetIssueLabels(issueID int64) ([]*IssueLabel, error) { - return getIssueLabels(x, issueID) + return getIssueLabels(db, issueID) } -func deleteIssueLabel(e *xorm.Session, issue *Issue, label *Label) (err error) { - if _, err = e.Delete(&IssueLabel{ - IssueID: issue.ID, - LabelID: label.ID, - }); err != nil { +func deleteIssueLabel(tx *gorm.DB, issue *Issue, label *Label) (err error) { + if err = tx.Where("issue_id = ? AND label_id = ?", issue.ID, label.ID).Delete(&IssueLabel{}).Error; err != nil { return err } @@ -362,7 +341,7 @@ func deleteIssueLabel(e *xorm.Session, issue *Issue, label *Label) (err error) { if issue.IsClosed { label.NumClosedIssues-- } - if err = updateLabel(e, label); err != nil { + if err = updateLabel(tx, label); err != nil { return errors.Newf("updateLabel: %v", err) } @@ -377,15 +356,7 @@ func deleteIssueLabel(e *xorm.Session, issue *Issue, label *Label) (err error) { // DeleteIssueLabel deletes issue-label relation. func DeleteIssueLabel(issue *Issue, label *Label) (err error) { - sess := x.NewSession() - defer sess.Close() - if err = sess.Begin(); err != nil { - return err - } - - if err = deleteIssueLabel(sess, issue, label); err != nil { - return err - } - - return sess.Commit() + return db.Transaction(func(tx *gorm.DB) error { + return deleteIssueLabel(tx, issue, label) + }) }