Auth: Prevent duplicate super admin accounts from being created #98
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
parent
d8ab9616a5
commit
50913e301c
3 changed files with 22 additions and 8 deletions
|
@ -112,7 +112,9 @@ func FindUser(find User) *User {
|
|||
|
||||
// Build query.
|
||||
stmt := UnscopedDb()
|
||||
if find.ID != 0 {
|
||||
if find.ID != 0 && find.UserName != "" {
|
||||
stmt = stmt.Where("id = ? OR user_name = ?", find.ID, find.UserName)
|
||||
} else if find.ID != 0 {
|
||||
stmt = stmt.Where("id = ?", find.ID)
|
||||
} else if rnd.IsUID(find.UserUID, UserUID) {
|
||||
stmt = stmt.Where("user_uid = ?", find.UserUID)
|
||||
|
@ -495,8 +497,6 @@ func (m *User) Provider() authn.ProviderType {
|
|||
func (m *User) SetProvider(t authn.ProviderType) *User {
|
||||
if m == nil {
|
||||
return nil
|
||||
} else if m.ID <= 0 {
|
||||
return m
|
||||
}
|
||||
|
||||
m.AuthProvider = t.String()
|
||||
|
@ -781,7 +781,7 @@ func (m *User) Validate() (err error) {
|
|||
if err = Db().
|
||||
Where("user_name = ? AND id <> ?", m.UserName, m.ID).
|
||||
First(&duplicate).Error; err == nil {
|
||||
return fmt.Errorf("username %s already exists", clean.LogQuote(m.UserName))
|
||||
return fmt.Errorf("user %s already exists", clean.LogQuote(m.UserName))
|
||||
} else if err != gorm.ErrRecordNotFound {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -331,6 +331,19 @@ func TestFindUser(t *testing.T) {
|
|||
assert.NotEmpty(t, m.CreatedAt)
|
||||
assert.NotEmpty(t, m.UpdatedAt)
|
||||
})
|
||||
t.Run("Admin", func(t *testing.T) {
|
||||
m := FindUser(User{ID: 2, UserName: "admin"})
|
||||
|
||||
if m == nil {
|
||||
t.Fatal("result should not be nil")
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, m.ID)
|
||||
assert.NotEmpty(t, m.UserUID)
|
||||
assert.Equal(t, "admin", m.UserName)
|
||||
assert.NotEmpty(t, m.CreatedAt)
|
||||
assert.NotEmpty(t, m.UpdatedAt)
|
||||
})
|
||||
t.Run("UserUID", func(t *testing.T) {
|
||||
m := FindUser(User{UserUID: "u000000000000002"})
|
||||
|
||||
|
|
|
@ -41,13 +41,14 @@ func (t ProviderType) IsLocal() bool {
|
|||
|
||||
// String returns the provider identifier as a string.
|
||||
func (t ProviderType) String() string {
|
||||
if t == ProviderUnknown {
|
||||
switch t {
|
||||
case "":
|
||||
return string(ProviderDefault)
|
||||
} else if t == "password" {
|
||||
case "password":
|
||||
return string(ProviderLocal)
|
||||
default:
|
||||
return string(t)
|
||||
}
|
||||
|
||||
return string(t)
|
||||
}
|
||||
|
||||
// Pretty returns the provider identifier in an easy-to-read format.
|
||||
|
|
Loading…
Reference in a new issue