mirror of
https://github.com/gogs/gogs.git
synced 2026-01-02 13:39:57 +01:00
db: use context and go-mockgen for login sources (#7041)
This commit is contained in:
@@ -9,11 +9,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"gogs.io/gogs/internal/errutil"
|
||||
)
|
||||
|
||||
func Test_loginSourceFiles_GetByID(t *testing.T) {
|
||||
func TestLoginSourceFiles_GetByID(t *testing.T) {
|
||||
store := &loginSourceFiles{
|
||||
sources: []*LoginSource{
|
||||
{ID: 101},
|
||||
@@ -28,14 +29,12 @@ func Test_loginSourceFiles_GetByID(t *testing.T) {
|
||||
|
||||
t.Run("source exists", func(t *testing.T) {
|
||||
source, err := store.GetByID(101)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(101), source.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_loginSourceFiles_Len(t *testing.T) {
|
||||
func TestLoginSourceFiles_Len(t *testing.T) {
|
||||
store := &loginSourceFiles{
|
||||
sources: []*LoginSource{
|
||||
{ID: 101},
|
||||
@@ -45,7 +44,7 @@ func Test_loginSourceFiles_Len(t *testing.T) {
|
||||
assert.Equal(t, 1, store.Len())
|
||||
}
|
||||
|
||||
func Test_loginSourceFiles_List(t *testing.T) {
|
||||
func TestLoginSourceFiles_List(t *testing.T) {
|
||||
store := &loginSourceFiles{
|
||||
sources: []*LoginSource{
|
||||
{ID: 101, IsActived: true},
|
||||
@@ -65,7 +64,7 @@ func Test_loginSourceFiles_List(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_loginSourceFiles_Update(t *testing.T) {
|
||||
func TestLoginSourceFiles_Update(t *testing.T) {
|
||||
store := &loginSourceFiles{
|
||||
sources: []*LoginSource{
|
||||
{ID: 101, IsActived: true, IsDefault: true},
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -25,24 +26,24 @@ import (
|
||||
//
|
||||
// NOTE: All methods are sorted in alphabetical order.
|
||||
type LoginSourcesStore interface {
|
||||
// Create creates a new login source and persist to database.
|
||||
// It returns ErrLoginSourceAlreadyExist when a login source with same name already exists.
|
||||
Create(opts CreateLoginSourceOpts) (*LoginSource, error)
|
||||
// Create creates a new login source and persist to database. It returns
|
||||
// ErrLoginSourceAlreadyExist when a login source with same name already exists.
|
||||
Create(ctx context.Context, opts CreateLoginSourceOpts) (*LoginSource, error)
|
||||
// Count returns the total number of login sources.
|
||||
Count() int64
|
||||
// DeleteByID deletes a login source by given ID.
|
||||
// It returns ErrLoginSourceInUse if at least one user is associated with the login source.
|
||||
DeleteByID(id int64) error
|
||||
// GetByID returns the login source with given ID.
|
||||
// It returns ErrLoginSourceNotExist when not found.
|
||||
GetByID(id int64) (*LoginSource, error)
|
||||
Count(ctx context.Context) int64
|
||||
// DeleteByID deletes a login source by given ID. It returns ErrLoginSourceInUse
|
||||
// if at least one user is associated with the login source.
|
||||
DeleteByID(ctx context.Context, id int64) error
|
||||
// GetByID returns the login source with given ID. It returns
|
||||
// ErrLoginSourceNotExist when not found.
|
||||
GetByID(ctx context.Context, id int64) (*LoginSource, error)
|
||||
// List returns a list of login sources filtered by options.
|
||||
List(opts ListLoginSourceOpts) ([]*LoginSource, error)
|
||||
List(ctx context.Context, opts ListLoginSourceOpts) ([]*LoginSource, error)
|
||||
// ResetNonDefault clears default flag for all the other login sources.
|
||||
ResetNonDefault(source *LoginSource) error
|
||||
// Save persists all values of given login source to database or local file.
|
||||
// The Updated field is set to current time automatically.
|
||||
Save(t *LoginSource) error
|
||||
ResetNonDefault(ctx context.Context, source *LoginSource) error
|
||||
// Save persists all values of given login source to database or local file. The
|
||||
// Updated field is set to current time automatically.
|
||||
Save(ctx context.Context, t *LoginSource) error
|
||||
}
|
||||
|
||||
var LoginSources LoginSourcesStore
|
||||
@@ -65,7 +66,7 @@ type LoginSource struct {
|
||||
File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"`
|
||||
}
|
||||
|
||||
// NOTE: This is a GORM save hook.
|
||||
// BeforeSave implements the GORM save hook.
|
||||
func (s *LoginSource) BeforeSave(_ *gorm.DB) (err error) {
|
||||
if s.Provider == nil {
|
||||
return nil
|
||||
@@ -74,7 +75,7 @@ func (s *LoginSource) BeforeSave(_ *gorm.DB) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// NOTE: This is a GORM create hook.
|
||||
// BeforeCreate implements the GORM create hook.
|
||||
func (s *LoginSource) BeforeCreate(tx *gorm.DB) error {
|
||||
if s.CreatedUnix == 0 {
|
||||
s.CreatedUnix = tx.NowFunc().Unix()
|
||||
@@ -83,13 +84,13 @@ func (s *LoginSource) BeforeCreate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// NOTE: This is a GORM update hook.
|
||||
// BeforeUpdate implements the GORM update hook.
|
||||
func (s *LoginSource) BeforeUpdate(tx *gorm.DB) error {
|
||||
s.UpdatedUnix = tx.NowFunc().Unix()
|
||||
return nil
|
||||
}
|
||||
|
||||
// NOTE: This is a GORM query hook.
|
||||
// AfterFind implements the GORM query hook.
|
||||
func (s *LoginSource) AfterFind(_ *gorm.DB) error {
|
||||
s.Created = time.Unix(s.CreatedUnix, 0).Local()
|
||||
s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
|
||||
@@ -209,8 +210,8 @@ func (err ErrLoginSourceAlreadyExist) Error() string {
|
||||
return fmt.Sprintf("login source already exists: %v", err.args)
|
||||
}
|
||||
|
||||
func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error) {
|
||||
err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error
|
||||
func (db *loginSources) Create(ctx context.Context, opts CreateLoginSourceOpts) (*LoginSource, error) {
|
||||
err := db.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
|
||||
if err == nil {
|
||||
return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
|
||||
} else if err != gorm.ErrRecordNotFound {
|
||||
@@ -227,12 +228,12 @@ func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return source, db.DB.Create(source).Error
|
||||
return source, db.WithContext(ctx).Create(source).Error
|
||||
}
|
||||
|
||||
func (db *loginSources) Count() int64 {
|
||||
func (db *loginSources) Count(ctx context.Context) int64 {
|
||||
var count int64
|
||||
db.Model(new(LoginSource)).Count(&count)
|
||||
db.WithContext(ctx).Model(new(LoginSource)).Count(&count)
|
||||
return count + int64(db.files.Len())
|
||||
}
|
||||
|
||||
@@ -249,21 +250,21 @@ func (err ErrLoginSourceInUse) Error() string {
|
||||
return fmt.Sprintf("login source is still used by some users: %v", err.args)
|
||||
}
|
||||
|
||||
func (db *loginSources) DeleteByID(id int64) error {
|
||||
func (db *loginSources) DeleteByID(ctx context.Context, id int64) error {
|
||||
var count int64
|
||||
err := db.Model(new(User)).Where("login_source = ?", id).Count(&count).Error
|
||||
err := db.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
|
||||
if err != nil {
|
||||
return err
|
||||
} else if count > 0 {
|
||||
return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
|
||||
}
|
||||
|
||||
return db.Where("id = ?", id).Delete(new(LoginSource)).Error
|
||||
return db.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
|
||||
}
|
||||
|
||||
func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
|
||||
func (db *loginSources) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
|
||||
source := new(LoginSource)
|
||||
err := db.Where("id = ?", id).First(source).Error
|
||||
err := db.WithContext(ctx).Where("id = ?", id).First(source).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return db.files.GetByID(id)
|
||||
@@ -278,9 +279,9 @@ type ListLoginSourceOpts struct {
|
||||
OnlyActivated bool
|
||||
}
|
||||
|
||||
func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) {
|
||||
func (db *loginSources) List(ctx context.Context, opts ListLoginSourceOpts) ([]*LoginSource, error) {
|
||||
var sources []*LoginSource
|
||||
query := db.Order("id ASC")
|
||||
query := db.WithContext(ctx).Order("id ASC")
|
||||
if opts.OnlyActivated {
|
||||
query = query.Where("is_actived = ?", true)
|
||||
}
|
||||
@@ -292,8 +293,12 @@ func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) {
|
||||
return append(sources, db.files.List(opts)...), nil
|
||||
}
|
||||
|
||||
func (db *loginSources) ResetNonDefault(dflt *LoginSource) error {
|
||||
err := db.Model(new(LoginSource)).Where("id != ?", dflt.ID).Updates(map[string]interface{}{"is_default": false}).Error
|
||||
func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
|
||||
err := db.WithContext(ctx).
|
||||
Model(new(LoginSource)).
|
||||
Where("id != ?", dflt.ID).
|
||||
Updates(map[string]interface{}{"is_default": false}).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -311,9 +316,9 @@ func (db *loginSources) ResetNonDefault(dflt *LoginSource) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *loginSources) Save(source *LoginSource) error {
|
||||
func (db *loginSources) Save(ctx context.Context, source *LoginSource) error {
|
||||
if source.File == nil {
|
||||
return db.DB.Save(source).Error
|
||||
return db.WithContext(ctx).Save(source).Error
|
||||
}
|
||||
|
||||
source.File.SetGeneral("name", source.Name)
|
||||
|
||||
@@ -5,10 +5,13 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mockrequire "github.com/derision-test/go-mockgen/testutil/require"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"gogs.io/gogs/internal/auth"
|
||||
@@ -31,9 +34,7 @@ func TestLoginSource_BeforeSave(t *testing.T) {
|
||||
t.Run("Config has not been set", func(t *testing.T) {
|
||||
s := &LoginSource{}
|
||||
err := s.BeforeSave(db)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, s.Config)
|
||||
})
|
||||
|
||||
@@ -44,9 +45,7 @@ func TestLoginSource_BeforeSave(t *testing.T) {
|
||||
}),
|
||||
}
|
||||
err := s.BeforeSave(db)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `{"ServiceName":"pam_service"}`, s.Config)
|
||||
})
|
||||
}
|
||||
@@ -93,20 +92,18 @@ func Test_loginSources(t *testing.T) {
|
||||
name string
|
||||
test func(*testing.T, *loginSources)
|
||||
}{
|
||||
{"Create", test_loginSources_Create},
|
||||
{"Count", test_loginSources_Count},
|
||||
{"DeleteByID", test_loginSources_DeleteByID},
|
||||
{"GetByID", test_loginSources_GetByID},
|
||||
{"List", test_loginSources_List},
|
||||
{"ResetNonDefault", test_loginSources_ResetNonDefault},
|
||||
{"Save", test_loginSources_Save},
|
||||
{"Create", loginSourcesCreate},
|
||||
{"Count", loginSourcesCount},
|
||||
{"DeleteByID", loginSourcesDeleteByID},
|
||||
{"GetByID", loginSourcesGetByID},
|
||||
{"List", loginSourcesList},
|
||||
{"ResetNonDefault", loginSourcesResetNonDefault},
|
||||
{"Save", loginSourcesSave},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
err := clearTables(t, db.DB, tables...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
tc.test(t, db)
|
||||
})
|
||||
@@ -116,62 +113,12 @@ func Test_loginSources(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func test_loginSources_Create(t *testing.T, db *loginSources) {
|
||||
func loginSourcesCreate(t *testing.T, db *loginSources) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create first login source with name "GitHub"
|
||||
source, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Get it back and check the Created field
|
||||
source, err = db.GetByID(source.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
|
||||
assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
|
||||
|
||||
// Try create second login source with same name should fail
|
||||
_, err = db.Create(CreateLoginSourceOpts{Name: source.Name})
|
||||
expErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
|
||||
assert.Equal(t, expErr, err)
|
||||
}
|
||||
|
||||
func test_loginSources_Count(t *testing.T, db *loginSources) {
|
||||
// Create two login sources, one in database and one as source file.
|
||||
_, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
|
||||
MockLen: func() int {
|
||||
return 2
|
||||
},
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(3), db.Count())
|
||||
}
|
||||
|
||||
func test_loginSources_DeleteByID(t *testing.T, db *loginSources) {
|
||||
t.Run("delete but in used", func(t *testing.T) {
|
||||
source, err := db.Create(CreateLoginSourceOpts{
|
||||
source, err := db.Create(ctx,
|
||||
CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
@@ -179,279 +126,300 @@ func test_loginSources_DeleteByID(t *testing.T, db *loginSources) {
|
||||
Config: &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get it back and check the Created field
|
||||
source, err = db.GetByID(ctx, source.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
|
||||
assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
|
||||
|
||||
// Try create second login source with same name should fail
|
||||
_, err = db.Create(ctx, CreateLoginSourceOpts{Name: source.Name})
|
||||
expErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
|
||||
assert.Equal(t, expErr, err)
|
||||
}
|
||||
|
||||
func loginSourcesCount(t *testing.T, db *loginSources) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create two login sources, one in database and one as source file.
|
||||
_, err := db.Create(ctx,
|
||||
CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := NewMockLoginSourceFilesStore()
|
||||
mock.LenFunc.SetDefaultReturn(2)
|
||||
setMockLoginSourceFilesStore(t, db, mock)
|
||||
|
||||
assert.Equal(t, int64(3), db.Count(ctx))
|
||||
}
|
||||
|
||||
func loginSourcesDeleteByID(t *testing.T, db *loginSources) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("delete but in used", func(t *testing.T) {
|
||||
source, err := db.Create(ctx,
|
||||
CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a user that uses this login source
|
||||
_, err = (&users{DB: db.DB}).Create("alice", "", CreateUserOpts{
|
||||
LoginSource: source.ID,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the login source will result in error
|
||||
err = db.DeleteByID(source.ID)
|
||||
err = db.DeleteByID(ctx, source.ID)
|
||||
expErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
|
||||
assert.Equal(t, expErr, err)
|
||||
})
|
||||
|
||||
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
|
||||
MockGetByID: func(id int64) (*LoginSource, error) {
|
||||
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
|
||||
},
|
||||
mock := NewMockLoginSourceFilesStore()
|
||||
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
|
||||
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
|
||||
})
|
||||
setMockLoginSourceFilesStore(t, db, mock)
|
||||
|
||||
// Create a login source with name "GitHub2"
|
||||
source, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub2",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
source, err := db.Create(ctx,
|
||||
CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub2",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete a non-existent ID is noop
|
||||
err = db.DeleteByID(9999)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.DeleteByID(ctx, 9999)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We should be able to get it back
|
||||
_, err = db.GetByID(source.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = db.GetByID(ctx, source.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now delete this login source with ID
|
||||
err = db.DeleteByID(source.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.DeleteByID(ctx, source.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We should get token not found error
|
||||
_, err = db.GetByID(source.ID)
|
||||
_, err = db.GetByID(ctx, source.ID)
|
||||
expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
|
||||
assert.Equal(t, expErr, err)
|
||||
}
|
||||
|
||||
func test_loginSources_GetByID(t *testing.T, db *loginSources) {
|
||||
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
|
||||
MockGetByID: func(id int64) (*LoginSource, error) {
|
||||
if id != 101 {
|
||||
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
|
||||
}
|
||||
return &LoginSource{ID: id}, nil
|
||||
},
|
||||
func loginSourcesGetByID(t *testing.T, db *loginSources) {
|
||||
ctx := context.Background()
|
||||
|
||||
mock := NewMockLoginSourceFilesStore()
|
||||
mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
|
||||
if id != 101 {
|
||||
return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
|
||||
}
|
||||
return &LoginSource{ID: id}, nil
|
||||
})
|
||||
setMockLoginSourceFilesStore(t, db, mock)
|
||||
|
||||
expConfig := &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
}
|
||||
|
||||
// Create a login source with name "GitHub"
|
||||
source, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: expConfig,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Get the one in the database and test the read/write hooks
|
||||
source, err = db.GetByID(source.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, expConfig, source.Provider.Config())
|
||||
|
||||
// Get the one in source file store
|
||||
_, err = db.GetByID(101)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func test_loginSources_List(t *testing.T, db *loginSources) {
|
||||
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
|
||||
MockList: func(opts ListLoginSourceOpts) []*LoginSource {
|
||||
if opts.OnlyActivated {
|
||||
return []*LoginSource{
|
||||
{ID: 1},
|
||||
}
|
||||
}
|
||||
return []*LoginSource{
|
||||
{ID: 1},
|
||||
{ID: 2},
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
// Create two login sources in database, one activated and the other one not
|
||||
_, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: auth.PAM,
|
||||
Name: "PAM",
|
||||
Config: &pam.Config{
|
||||
ServiceName: "PAM",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = db.Create(CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Config: &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// List all login sources
|
||||
sources, err := db.List(ListLoginSourceOpts{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, 4, len(sources), "number of sources")
|
||||
|
||||
// Only list activated login sources
|
||||
sources, err = db.List(ListLoginSourceOpts{OnlyActivated: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, 2, len(sources), "number of sources")
|
||||
}
|
||||
|
||||
func test_loginSources_ResetNonDefault(t *testing.T, db *loginSources) {
|
||||
setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
|
||||
MockList: func(opts ListLoginSourceOpts) []*LoginSource {
|
||||
return []*LoginSource{
|
||||
{
|
||||
File: &mockLoginSourceFileStore{
|
||||
MockSetGeneral: func(name, value string) {
|
||||
assert.Equal(t, "is_default", name)
|
||||
assert.Equal(t, "false", value)
|
||||
},
|
||||
MockSave: func() error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
MockUpdate: func(source *LoginSource) {},
|
||||
})
|
||||
|
||||
// Create two login sources both have default on
|
||||
source1, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: auth.PAM,
|
||||
Name: "PAM",
|
||||
Default: true,
|
||||
Config: &pam.Config{
|
||||
ServiceName: "PAM",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
source2, err := db.Create(CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: true,
|
||||
Config: &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Set source 1 as default
|
||||
err = db.ResetNonDefault(source1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify the default state
|
||||
source1, err = db.GetByID(source1.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.True(t, source1.IsDefault)
|
||||
|
||||
source2, err = db.GetByID(source2.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.False(t, source2.IsDefault)
|
||||
}
|
||||
|
||||
func test_loginSources_Save(t *testing.T, db *loginSources) {
|
||||
t.Run("save to database", func(t *testing.T) {
|
||||
// Create a login source with name "GitHub"
|
||||
source, err := db.Create(CreateLoginSourceOpts{
|
||||
source, err := db.Create(ctx,
|
||||
CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: expConfig,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the one in the database and test the read/write hooks
|
||||
source, err = db.GetByID(ctx, source.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expConfig, source.Provider.Config())
|
||||
|
||||
// Get the one in source file store
|
||||
_, err = db.GetByID(ctx, 101)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func loginSourcesList(t *testing.T, db *loginSources) {
|
||||
ctx := context.Background()
|
||||
|
||||
mock := NewMockLoginSourceFilesStore()
|
||||
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOpts) []*LoginSource {
|
||||
if opts.OnlyActivated {
|
||||
return []*LoginSource{
|
||||
{ID: 1},
|
||||
}
|
||||
}
|
||||
return []*LoginSource{
|
||||
{ID: 1},
|
||||
{ID: 2},
|
||||
}
|
||||
})
|
||||
setMockLoginSourceFilesStore(t, db, mock)
|
||||
|
||||
// Create two login sources in database, one activated and the other one not
|
||||
_, err := db.Create(ctx,
|
||||
CreateLoginSourceOpts{
|
||||
Type: auth.PAM,
|
||||
Name: "PAM",
|
||||
Config: &pam.Config{
|
||||
ServiceName: "PAM",
|
||||
},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
_, err = db.Create(ctx,
|
||||
CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Config: &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List all login sources
|
||||
sources, err := db.List(ctx, ListLoginSourceOpts{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 4, len(sources), "number of sources")
|
||||
|
||||
// Only list activated login sources
|
||||
sources, err = db.List(ctx, ListLoginSourceOpts{OnlyActivated: true})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(sources), "number of sources")
|
||||
}
|
||||
|
||||
func loginSourcesResetNonDefault(t *testing.T, db *loginSources) {
|
||||
ctx := context.Background()
|
||||
|
||||
mock := NewMockLoginSourceFilesStore()
|
||||
mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOpts) []*LoginSource {
|
||||
mockFile := NewMockLoginSourceFileStore()
|
||||
mockFile.SetGeneralFunc.SetDefaultHook(func(name, value string) {
|
||||
assert.Equal(t, "is_default", name)
|
||||
assert.Equal(t, "false", value)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return []*LoginSource{
|
||||
{
|
||||
File: mockFile,
|
||||
},
|
||||
}
|
||||
})
|
||||
setMockLoginSourceFilesStore(t, db, mock)
|
||||
|
||||
// Create two login sources both have default on
|
||||
source1, err := db.Create(ctx,
|
||||
CreateLoginSourceOpts{
|
||||
Type: auth.PAM,
|
||||
Name: "PAM",
|
||||
Default: true,
|
||||
Config: &pam.Config{
|
||||
ServiceName: "PAM",
|
||||
},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
source2, err := db.Create(ctx,
|
||||
CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: true,
|
||||
Config: &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set source 1 as default
|
||||
err = db.ResetNonDefault(ctx, source1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the default state
|
||||
source1, err = db.GetByID(ctx, source1.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, source1.IsDefault)
|
||||
|
||||
source2, err = db.GetByID(ctx, source2.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, source2.IsDefault)
|
||||
}
|
||||
|
||||
func loginSourcesSave(t *testing.T, db *loginSources) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save to database", func(t *testing.T) {
|
||||
// Create a login source with name "GitHub"
|
||||
source, err := db.Create(ctx,
|
||||
CreateLoginSourceOpts{
|
||||
Type: auth.GitHub,
|
||||
Name: "GitHub",
|
||||
Activated: true,
|
||||
Default: false,
|
||||
Config: &github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
source.IsActived = false
|
||||
source.Provider = github.NewProvider(&github.Config{
|
||||
APIEndpoint: "https://api2.github.com",
|
||||
})
|
||||
err = db.Save(source)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.Save(ctx, source)
|
||||
require.NoError(t, err)
|
||||
|
||||
source, err = db.GetByID(source.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
source, err = db.GetByID(ctx, source.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, source.IsActived)
|
||||
assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
|
||||
})
|
||||
|
||||
t.Run("save to file", func(t *testing.T) {
|
||||
calledSave := false
|
||||
mockFile := NewMockLoginSourceFileStore()
|
||||
source := &LoginSource{
|
||||
Provider: github.NewProvider(&github.Config{
|
||||
APIEndpoint: "https://api.github.com",
|
||||
}),
|
||||
File: &mockLoginSourceFileStore{
|
||||
MockSetGeneral: func(name, value string) {},
|
||||
MockSetConfig: func(cfg interface{}) error { return nil },
|
||||
MockSave: func() error {
|
||||
calledSave = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
File: mockFile,
|
||||
}
|
||||
err := db.Save(source)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.True(t, calledSave)
|
||||
err := db.Save(ctx, source)
|
||||
require.NoError(t, err)
|
||||
mockrequire.Called(t, mockFile.SaveFunc)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i PermsStore -o mocks.go
|
||||
//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -o mocks.go
|
||||
|
||||
func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) {
|
||||
before := AccessTokens
|
||||
@@ -26,31 +26,6 @@ func SetMockLFSStore(t *testing.T, mock LFSStore) {
|
||||
})
|
||||
}
|
||||
|
||||
var _ loginSourceFilesStore = (*mockLoginSourceFilesStore)(nil)
|
||||
|
||||
type mockLoginSourceFilesStore struct {
|
||||
MockGetByID func(id int64) (*LoginSource, error)
|
||||
MockLen func() int
|
||||
MockList func(opts ListLoginSourceOpts) []*LoginSource
|
||||
MockUpdate func(source *LoginSource)
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFilesStore) GetByID(id int64) (*LoginSource, error) {
|
||||
return m.MockGetByID(id)
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFilesStore) Len() int {
|
||||
return m.MockLen()
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFilesStore) List(opts ListLoginSourceOpts) []*LoginSource {
|
||||
return m.MockList(opts)
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFilesStore) Update(source *LoginSource) {
|
||||
m.MockUpdate(source)
|
||||
}
|
||||
|
||||
func setMockLoginSourceFilesStore(t *testing.T, db *loginSources, mock loginSourceFilesStore) {
|
||||
before := db.files
|
||||
db.files = mock
|
||||
@@ -59,26 +34,6 @@ func setMockLoginSourceFilesStore(t *testing.T, db *loginSources, mock loginSour
|
||||
})
|
||||
}
|
||||
|
||||
var _ loginSourceFileStore = (*mockLoginSourceFileStore)(nil)
|
||||
|
||||
type mockLoginSourceFileStore struct {
|
||||
MockSetGeneral func(name, value string)
|
||||
MockSetConfig func(cfg interface{}) error
|
||||
MockSave func() error
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFileStore) SetGeneral(name, value string) {
|
||||
m.MockSetGeneral(name, value)
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFileStore) SetConfig(cfg interface{}) error {
|
||||
return m.MockSetConfig(cfg)
|
||||
}
|
||||
|
||||
func (m *mockLoginSourceFileStore) Save() error {
|
||||
return m.MockSave()
|
||||
}
|
||||
|
||||
func SetMockPermsStore(t *testing.T, mock PermsStore) {
|
||||
before := Perms
|
||||
Perms = mock
|
||||
|
||||
1790
internal/db/mocks.go
1790
internal/db/mocks.go
File diff suppressed because it is too large
Load Diff
@@ -5,6 +5,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -208,7 +209,7 @@ type Statistic struct {
|
||||
}
|
||||
}
|
||||
|
||||
func GetStatistic() (stats Statistic) {
|
||||
func GetStatistic(ctx context.Context) (stats Statistic) {
|
||||
stats.Counter.User = CountUsers()
|
||||
stats.Counter.Org = CountOrganizations()
|
||||
stats.Counter.PublicKey, _ = x.Count(new(PublicKey))
|
||||
@@ -223,7 +224,7 @@ func GetStatistic() (stats Statistic) {
|
||||
stats.Counter.Follow, _ = x.Count(new(Follow))
|
||||
stats.Counter.Mirror, _ = x.Count(new(Mirror))
|
||||
stats.Counter.Release, _ = x.Count(new(Release))
|
||||
stats.Counter.LoginSource = LoginSources.Count()
|
||||
stats.Counter.LoginSource = LoginSources.Count(ctx)
|
||||
stats.Counter.Webhook, _ = x.Count(new(Webhook))
|
||||
stats.Counter.Milestone, _ = x.Count(new(Milestone))
|
||||
stats.Counter.Label, _ = x.Count(new(Label))
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -80,6 +81,8 @@ func (err ErrLoginSourceMismatch) Error() string {
|
||||
}
|
||||
|
||||
func (db *users) Authenticate(login, password string, loginSourceID int64) (*User, error) {
|
||||
ctx := context.TODO()
|
||||
|
||||
login = strings.ToLower(login)
|
||||
|
||||
var query *gorm.DB
|
||||
@@ -127,7 +130,7 @@ func (db *users) Authenticate(login, password string, loginSourceID int64) (*Use
|
||||
createNewUser = true
|
||||
}
|
||||
|
||||
source, err := LoginSources.GetByID(authSourceID)
|
||||
source, err := LoginSources.GetByID(ctx, authSourceID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get login source")
|
||||
}
|
||||
|
||||
@@ -119,7 +119,7 @@ func Dashboard(c *context.Context) {
|
||||
c.Data["BuildTime"] = conf.BuildTime
|
||||
c.Data["BuildCommit"] = conf.BuildCommit
|
||||
|
||||
c.Data["Stats"] = db.GetStatistic()
|
||||
c.Data["Stats"] = db.GetStatistic(c.Req.Context())
|
||||
// FIXME: update periodically
|
||||
updateSystemStatus()
|
||||
c.Data["SysStatus"] = sysStatus
|
||||
|
||||
@@ -35,13 +35,13 @@ func Authentications(c *context.Context) {
|
||||
c.PageIs("AdminAuthentications")
|
||||
|
||||
var err error
|
||||
c.Data["Sources"], err = db.LoginSources.List(db.ListLoginSourceOpts{})
|
||||
c.Data["Sources"], err = db.LoginSources.List(c.Req.Context(), db.ListLoginSourceOpts{})
|
||||
if err != nil {
|
||||
c.Error(err, "list login sources")
|
||||
return
|
||||
}
|
||||
|
||||
c.Data["Total"] = db.LoginSources.Count()
|
||||
c.Data["Total"] = db.LoginSources.Count(c.Req.Context())
|
||||
c.Success(AUTHS)
|
||||
}
|
||||
|
||||
@@ -159,13 +159,15 @@ func NewAuthSourcePost(c *context.Context, f form.Authentication) {
|
||||
return
|
||||
}
|
||||
|
||||
source, err := db.LoginSources.Create(db.CreateLoginSourceOpts{
|
||||
Type: auth.Type(f.Type),
|
||||
Name: f.Name,
|
||||
Activated: f.IsActive,
|
||||
Default: f.IsDefault,
|
||||
Config: config,
|
||||
})
|
||||
source, err := db.LoginSources.Create(c.Req.Context(),
|
||||
db.CreateLoginSourceOpts{
|
||||
Type: auth.Type(f.Type),
|
||||
Name: f.Name,
|
||||
Activated: f.IsActive,
|
||||
Default: f.IsDefault,
|
||||
Config: config,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
if db.IsErrLoginSourceAlreadyExist(err) {
|
||||
c.FormErr("Name")
|
||||
@@ -177,7 +179,7 @@ func NewAuthSourcePost(c *context.Context, f form.Authentication) {
|
||||
}
|
||||
|
||||
if source.IsDefault {
|
||||
err = db.LoginSources.ResetNonDefault(source)
|
||||
err = db.LoginSources.ResetNonDefault(c.Req.Context(), source)
|
||||
if err != nil {
|
||||
c.Error(err, "reset non-default login sources")
|
||||
return
|
||||
@@ -198,7 +200,7 @@ func EditAuthSource(c *context.Context) {
|
||||
c.Data["SecurityProtocols"] = securityProtocols
|
||||
c.Data["SMTPAuths"] = smtp.AuthTypes
|
||||
|
||||
source, err := db.LoginSources.GetByID(c.ParamsInt64(":authid"))
|
||||
source, err := db.LoginSources.GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
|
||||
if err != nil {
|
||||
c.Error(err, "get login source by ID")
|
||||
return
|
||||
@@ -216,7 +218,7 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
|
||||
|
||||
c.Data["SMTPAuths"] = smtp.AuthTypes
|
||||
|
||||
source, err := db.LoginSources.GetByID(c.ParamsInt64(":authid"))
|
||||
source, err := db.LoginSources.GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
|
||||
if err != nil {
|
||||
c.Error(err, "get login source by ID")
|
||||
return
|
||||
@@ -255,13 +257,13 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
|
||||
source.IsActived = f.IsActive
|
||||
source.IsDefault = f.IsDefault
|
||||
source.Provider = provider
|
||||
if err := db.LoginSources.Save(source); err != nil {
|
||||
if err := db.LoginSources.Save(c.Req.Context(), source); err != nil {
|
||||
c.Error(err, "update login source")
|
||||
return
|
||||
}
|
||||
|
||||
if source.IsDefault {
|
||||
err = db.LoginSources.ResetNonDefault(source)
|
||||
err = db.LoginSources.ResetNonDefault(c.Req.Context(), source)
|
||||
if err != nil {
|
||||
c.Error(err, "reset non-default login sources")
|
||||
return
|
||||
@@ -276,7 +278,7 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
|
||||
|
||||
func DeleteAuthSource(c *context.Context) {
|
||||
id := c.ParamsInt64(":authid")
|
||||
if err := db.LoginSources.DeleteByID(id); err != nil {
|
||||
if err := db.LoginSources.DeleteByID(c.Req.Context(), id); err != nil {
|
||||
if db.IsErrLoginSourceInUse(err) {
|
||||
c.Flash.Error(c.Tr("admin.auths.still_in_used"))
|
||||
} else {
|
||||
|
||||
@@ -46,7 +46,7 @@ func NewUser(c *context.Context) {
|
||||
|
||||
c.Data["login_type"] = "0-0"
|
||||
|
||||
sources, err := db.LoginSources.List(db.ListLoginSourceOpts{})
|
||||
sources, err := db.LoginSources.List(c.Req.Context(), db.ListLoginSourceOpts{})
|
||||
if err != nil {
|
||||
c.Error(err, "list login sources")
|
||||
return
|
||||
@@ -62,7 +62,7 @@ func NewUserPost(c *context.Context, f form.AdminCrateUser) {
|
||||
c.Data["PageIsAdmin"] = true
|
||||
c.Data["PageIsAdminUsers"] = true
|
||||
|
||||
sources, err := db.LoginSources.List(db.ListLoginSourceOpts{})
|
||||
sources, err := db.LoginSources.List(c.Req.Context(), db.ListLoginSourceOpts{})
|
||||
if err != nil {
|
||||
c.Error(err, "list login sources")
|
||||
return
|
||||
@@ -127,7 +127,7 @@ func prepareUserInfo(c *context.Context) *db.User {
|
||||
c.Data["User"] = u
|
||||
|
||||
if u.LoginSource > 0 {
|
||||
c.Data["LoginSource"], err = db.LoginSources.GetByID(u.LoginSource)
|
||||
c.Data["LoginSource"], err = db.LoginSources.GetByID(c.Req.Context(), u.LoginSource)
|
||||
if err != nil {
|
||||
c.Error(err, "get login source by ID")
|
||||
return nil
|
||||
@@ -136,7 +136,7 @@ func prepareUserInfo(c *context.Context) *db.User {
|
||||
c.Data["LoginSource"] = &db.LoginSource{}
|
||||
}
|
||||
|
||||
sources, err := db.LoginSources.List(db.ListLoginSourceOpts{})
|
||||
sources, err := db.LoginSources.List(c.Req.Context(), db.ListLoginSourceOpts{})
|
||||
if err != nil {
|
||||
c.Error(err, "list login sources")
|
||||
return nil
|
||||
|
||||
@@ -22,7 +22,7 @@ func parseLoginSource(c *context.APIContext, u *db.User, sourceID int64, loginNa
|
||||
return
|
||||
}
|
||||
|
||||
source, err := db.LoginSources.GetByID(sourceID)
|
||||
source, err := db.LoginSources.GetByID(c.Req.Context(), sourceID)
|
||||
if err != nil {
|
||||
if db.IsErrLoginSourceNotExist(err) {
|
||||
c.ErrorStatus(http.StatusUnprocessableEntity, err)
|
||||
|
||||
@@ -102,7 +102,7 @@ func Login(c *context.Context) {
|
||||
}
|
||||
|
||||
// Display normal login page
|
||||
loginSources, err := db.LoginSources.List(db.ListLoginSourceOpts{OnlyActivated: true})
|
||||
loginSources, err := db.LoginSources.List(c.Req.Context(), db.ListLoginSourceOpts{OnlyActivated: true})
|
||||
if err != nil {
|
||||
c.Error(err, "list activated login sources")
|
||||
return
|
||||
@@ -149,7 +149,7 @@ func afterLogin(c *context.Context, u *db.User, remember bool) {
|
||||
func LoginPost(c *context.Context, f form.SignIn) {
|
||||
c.Title("sign_in")
|
||||
|
||||
loginSources, err := db.LoginSources.List(db.ListLoginSourceOpts{OnlyActivated: true})
|
||||
loginSources, err := db.LoginSources.List(c.Req.Context(), db.ListLoginSourceOpts{OnlyActivated: true})
|
||||
if err != nil {
|
||||
c.Error(err, "list activated login sources")
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user