Auth: Prevent duplicate super admin accounts from being created #98

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer 2023-03-09 15:59:08 +01:00
parent d8ab9616a5
commit 50913e301c
3 changed files with 22 additions and 8 deletions

View file

@ -112,7 +112,9 @@ func FindUser(find User) *User {
// Build query.
stmt := UnscopedDb()
if find.ID != 0 {
if find.ID != 0 && find.UserName != "" {
stmt = stmt.Where("id = ? OR user_name = ?", find.ID, find.UserName)
} else if find.ID != 0 {
stmt = stmt.Where("id = ?", find.ID)
} else if rnd.IsUID(find.UserUID, UserUID) {
stmt = stmt.Where("user_uid = ?", find.UserUID)
@ -495,8 +497,6 @@ func (m *User) Provider() authn.ProviderType {
func (m *User) SetProvider(t authn.ProviderType) *User {
if m == nil {
return nil
} else if m.ID <= 0 {
return m
}
m.AuthProvider = t.String()
@ -781,7 +781,7 @@ func (m *User) Validate() (err error) {
if err = Db().
Where("user_name = ? AND id <> ?", m.UserName, m.ID).
First(&duplicate).Error; err == nil {
return fmt.Errorf("username %s already exists", clean.LogQuote(m.UserName))
return fmt.Errorf("user %s already exists", clean.LogQuote(m.UserName))
} else if err != gorm.ErrRecordNotFound {
return err
}

View file

@ -331,6 +331,19 @@ func TestFindUser(t *testing.T) {
assert.NotEmpty(t, m.CreatedAt)
assert.NotEmpty(t, m.UpdatedAt)
})
t.Run("Admin", func(t *testing.T) {
m := FindUser(User{ID: 2, UserName: "admin"})
if m == nil {
t.Fatal("result should not be nil")
}
assert.Equal(t, 1, m.ID)
assert.NotEmpty(t, m.UserUID)
assert.Equal(t, "admin", m.UserName)
assert.NotEmpty(t, m.CreatedAt)
assert.NotEmpty(t, m.UpdatedAt)
})
t.Run("UserUID", func(t *testing.T) {
m := FindUser(User{UserUID: "u000000000000002"})

View file

@ -41,13 +41,14 @@ func (t ProviderType) IsLocal() bool {
// String returns the provider identifier as a string.
func (t ProviderType) String() string {
if t == ProviderUnknown {
switch t {
case "":
return string(ProviderDefault)
} else if t == "password" {
case "password":
return string(ProviderLocal)
default:
return string(t)
}
return string(t)
}
// Pretty returns the provider identifier in an easy-to-read format.