From ec13ccb6d5c40e658d4df9ab02df328f411bd55d Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Mon, 8 Jan 2024 16:57:07 +0100 Subject: [PATCH] OAuth2: Enforce limit for number of access tokens / sessions #808 #3943 These changes ensure that OAuth2 clients cannot create an unlimited number of access tokens (sessions) with their client credentials. Signed-off-by: Michael Mayer --- Makefile | 2 + internal/api/{session_acl.go => api_acl.go} | 0 .../{session_acl_test.go => api_acl_test.go} | 0 internal/api/session_oauth.go | 20 +++++---- internal/entity/auth_client.go | 16 +++++-- internal/entity/auth_session.go | 42 +++++++++++++++---- internal/entity/auth_session_test.go | 30 +++++++++++++ internal/session/session_save.go | 8 +++- 8 files changed, 98 insertions(+), 20 deletions(-) rename internal/api/{session_acl.go => api_acl.go} (100%) rename internal/api/{session_acl_test.go => api_acl_test.go} (100%) diff --git a/Makefile b/Makefile index 1dc55a1fe..3a395177e 100644 --- a/Makefile +++ b/Makefile @@ -163,6 +163,8 @@ stop: ./photoprism stop terminal: $(DOCKER_COMPOSE) exec -u $(UID) photoprism bash +mariadb: + $(DOCKER_COMPOSE) exec mariadb mariadb -uroot -pphotoprism photoprism rootshell: root-terminal root-terminal: $(DOCKER_COMPOSE) exec -u root photoprism bash diff --git a/internal/api/session_acl.go b/internal/api/api_acl.go similarity index 100% rename from internal/api/session_acl.go rename to internal/api/api_acl.go diff --git a/internal/api/session_acl_test.go b/internal/api/api_acl_test.go similarity index 100% rename from internal/api/session_acl_test.go rename to internal/api/api_acl_test.go diff --git a/internal/api/session_oauth.go b/internal/api/session_oauth.go index 053ec12e9..3e441c932 100644 --- a/internal/api/session_oauth.go +++ b/internal/api/session_oauth.go @@ -3,8 +3,8 @@ package api import ( "net/http" + "github.com/dustin/go-humanize/english" "github.com/gin-gonic/gin" - "github.com/photoprism/photoprism/pkg/authn" "github.com/photoprism/photoprism/internal/acl" "github.com/photoprism/photoprism/internal/entity" @@ -13,6 +13,7 @@ import ( "github.com/photoprism/photoprism/internal/get" "github.com/photoprism/photoprism/internal/i18n" "github.com/photoprism/photoprism/internal/server/limiter" + "github.com/photoprism/photoprism/pkg/authn" "github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/rnd" ) @@ -28,7 +29,7 @@ func CreateOAuthToken(router *gin.RouterGroup) { // Abort if running in public mode. if get.Config().Public() { - event.AuditErr([]string{clientIP, "create client session in public mode", "denied"}) + event.AuditErr([]string{clientIP, "create client session", "disabled in public mode"}) Abort(c, http.StatusForbidden, i18n.ErrForbidden) return } @@ -42,14 +43,14 @@ func CreateOAuthToken(router *gin.RouterGroup) { f.ClientID = clientId f.ClientSecret = clientSecret } else if err = c.Bind(&f); err != nil { - event.AuditWarn([]string{clientIP, "oauth", "%s"}, err) + event.AuditWarn([]string{clientIP, "create client session", "%s"}, err) AbortBadRequest(c) return } // Check the credentials for completeness and the correct format. if err = f.Validate(); err != nil { - event.AuditWarn([]string{clientIP, "oauth", "%s"}, err) + event.AuditWarn([]string{clientIP, "create client session", "%s"}, err) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) return } @@ -70,7 +71,7 @@ func CreateOAuthToken(router *gin.RouterGroup) { limiter.Login.Reserve(clientIP) return } else if !client.AuthEnabled { - event.AuditWarn([]string{clientIP, "client %s", "create session", "disabled"}, f.ClientID) + event.AuditWarn([]string{clientIP, "client %s", "create session", "authentication disabled"}, f.ClientID) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) return } else if method := client.Method(); !method.IsDefault() && method != authn.MethodOAuth2 { @@ -87,8 +88,6 @@ func CreateOAuthToken(router *gin.RouterGroup) { // Create new client session. sess := client.NewSession(c) - // TODO: Enforce limit for maximum number of access tokens. - // Try to log in and save session if successful. if sess, err = get.Session().Save(sess); err != nil { event.AuditErr([]string{clientIP, "client %s", "create session", "%s"}, f.ClientID, err) @@ -102,6 +101,11 @@ func CreateOAuthToken(router *gin.RouterGroup) { event.AuditInfo([]string{clientIP, "client %s", "session %s", "created"}, f.ClientID, sess.RefID) } + // Deletes old client sessions above the configured limit. + if deleted := client.EnforceAuthTokenLimit(); deleted > 0 { + event.AuditInfo([]string{clientIP, "client %s", "%s deleted"}, f.ClientID, english.Plural(deleted, "old session", "old sessions")) + } + // Response includes access token, token type, and token lifetime. data := gin.H{ "access_token": sess.AuthToken(), @@ -125,7 +129,7 @@ func DeleteOAuthToken(router *gin.RouterGroup) { // Abort if running in public mode. if get.Config().Public() { - event.AuditErr([]string{clientIP, "delete client session in public mode", "denied"}) + event.AuditErr([]string{clientIP, "delete client session", "disabled in public mode"}) Abort(c, http.StatusForbidden, i18n.ErrForbidden) return } diff --git a/internal/entity/auth_client.go b/internal/entity/auth_client.go index 3410cfe94..716ff4e99 100644 --- a/internal/entity/auth_client.go +++ b/internal/entity/auth_client.go @@ -256,9 +256,6 @@ func (m *Client) UpdateLastActive() *Client { // NewSession creates a new client session. func (m *Client) NewSession(c *gin.Context) *Session { - // Update activity timestamp. - m.UpdateLastActive() - // Create, initialize, and return new session. sess := NewSession(m.AuthExpires, 0).SetContext(c) sess.AuthID = m.UID() @@ -270,6 +267,19 @@ func (m *Client) NewSession(c *gin.Context) *Session { return sess } +// EnforceAuthTokenLimit deletes client sessions above the configured limit and returns the number of deleted sessions. +func (m *Client) EnforceAuthTokenLimit() (deleted int) { + if m == nil { + return 0 + } else if !m.HasUID() { + return 0 + } else if m.AuthTokens < 0 { + return 0 + } + + return DeleteClientSessions(m.ClientUID, m.AuthTokens) +} + // Expires returns the auth expiration duration. func (m *Client) Expires() time.Duration { return time.Duration(m.AuthExpires) * time.Second diff --git a/internal/entity/auth_session.go b/internal/entity/auth_session.go index a48cc4838..52120f20c 100644 --- a/internal/entity/auth_session.go +++ b/internal/entity/auth_session.go @@ -93,18 +93,44 @@ func (m *Session) Expires(t time.Time) *Session { return m } -// DeleteExpiredSessions deletes expired sessions. +// DeleteExpiredSessions deletes all expired sessions. func DeleteExpiredSessions() (deleted int) { - expired := Sessions{} + found := Sessions{} - if err := Db().Where("sess_expires > 0 AND sess_expires < ?", UnixTime()).Find(&expired).Error; err != nil { - event.AuditErr([]string{"failed to fetch sessions sessions", "%s"}, err) + if err := Db().Where("sess_expires > 0 AND sess_expires < ?", UnixTime()).Find(&found).Error; err != nil { + event.AuditErr([]string{"failed to fetch expired sessions", "%s"}, err) return deleted } - for _, s := range expired { - if err := s.Delete(); err != nil { - event.AuditErr([]string{s.IP(), "session %s", "failed to delete", "%s"}, s.RefID, err) + for _, sess := range found { + if err := sess.Delete(); err != nil { + event.AuditErr([]string{sess.IP(), "session %s", "failed to delete", "%s"}, sess.RefID, err) + } else { + deleted++ + } + } + + return deleted +} + +// DeleteClientSessions deletes client sessions above the specified limit. +func DeleteClientSessions(clientUID string, limit int64) (deleted int) { + if !rnd.IsUID(clientUID, ClientUID) || limit < 0 { + return 0 + } + + found := Sessions{} + + if err := Db().Where("auth_id = ?", clientUID). + Order("created_at DESC").Limit(2147483648).Offset(limit). + Find(&found).Error; err != nil { + event.AuditErr([]string{"failed to fetch client sessions", "%s"}, err) + return deleted + } + + for _, sess := range found { + if err := sess.Delete(); err != nil { + event.AuditErr([]string{sess.IP(), "session %s", "failed to delete", "%s"}, sess.RefID, err) } else { deleted++ } @@ -616,7 +642,7 @@ func (m *Session) Invalid() bool { // Valid checks whether the session belongs to a registered user or a visitor with shares. func (m *Session) Valid() bool { - if m.AuthMethod == authn.MethodOAuth2.String() { + if m.IsClient() { return true } diff --git a/internal/entity/auth_session_test.go b/internal/entity/auth_session_test.go index d607efa13..e27d3246c 100644 --- a/internal/entity/auth_session_test.go +++ b/internal/entity/auth_session_test.go @@ -97,6 +97,36 @@ func TestDeleteExpiredSessions(t *testing.T) { assert.Equal(t, 1, DeleteExpiredSessions()) } +func TestDeleteClientSessions(t *testing.T) { + clientUID := "cs5gfen1bgx00000" + + // Make sure no sessions exist yet and test missing arguments. + assert.Equal(t, 0, DeleteClientSessions("", -1)) + assert.Equal(t, 0, DeleteClientSessions(clientUID, -1)) + assert.Equal(t, 0, DeleteClientSessions(clientUID, 0)) + assert.Equal(t, 0, DeleteClientSessions("", 0)) + + // Create 10 client sessions. + for i := 0; i < 10; i++ { + sess := NewSession(3600, 0) + sess.SetClientIP(UnknownIP) + sess.AuthID = clientUID + sess.AuthProvider = authn.ProviderClient.String() + sess.AuthMethod = authn.MethodOAuth2.String() + sess.AuthScope = "*" + + if err := sess.Save(); err != nil { + t.Fatal(err) + } + } + + // Check if the expected number of sessions is deleted until none are left. + assert.Equal(t, 0, DeleteClientSessions(clientUID, -1)) + assert.Equal(t, 9, DeleteClientSessions(clientUID, 1)) + assert.Equal(t, 1, DeleteClientSessions(clientUID, 0)) + assert.Equal(t, 0, DeleteClientSessions(clientUID, 0)) +} + func TestSessionStatusUnauthorized(t *testing.T) { m := SessionStatusUnauthorized() assert.Equal(t, 401, m.Status) diff --git a/internal/session/session_save.go b/internal/session/session_save.go index 12226cf46..07eb8b49e 100644 --- a/internal/session/session_save.go +++ b/internal/session/session_save.go @@ -14,8 +14,14 @@ func (s *Session) Save(m *entity.Session) (*entity.Session, error) { return nil, fmt.Errorf("session is nil") } + // Update last active timestamp. + m.LastActive = entity.UnixTime() + // Save session. - return m.UpdateLastActive(), m.Save() + err := m.Save() + + // Return session. + return m, err } // Create initializes a new client session and returns it.