From e03dbe5d16ac47c27552369bca2b12bac1ea279d Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Tue, 9 Jan 2024 13:46:55 +0100 Subject: [PATCH] OAuth2: Refactor limit for number of access tokens / sessions #808 #3943 Signed-off-by: Michael Mayer --- internal/entity/auth_client.go | 2 +- internal/entity/auth_session.go | 11 ++++++++--- internal/entity/auth_session_test.go | 17 +++++++++-------- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/internal/entity/auth_client.go b/internal/entity/auth_client.go index 716ff4e99..90129ff88 100644 --- a/internal/entity/auth_client.go +++ b/internal/entity/auth_client.go @@ -277,7 +277,7 @@ func (m *Client) EnforceAuthTokenLimit() (deleted int) { return 0 } - return DeleteClientSessions(m.ClientUID, m.AuthTokens) + return DeleteClientSessions(m.ClientUID, authn.MethodOAuth2, m.AuthTokens) } // Expires returns the auth expiration duration. diff --git a/internal/entity/auth_session.go b/internal/entity/auth_session.go index 7900a1562..9e74ae4ba 100644 --- a/internal/entity/auth_session.go +++ b/internal/entity/auth_session.go @@ -114,15 +114,20 @@ func DeleteExpiredSessions() (deleted int) { } // DeleteClientSessions deletes client sessions above the specified limit. -func DeleteClientSessions(clientUID string, limit int64) (deleted int) { +func DeleteClientSessions(clientUID string, authMethod authn.MethodType, limit int64) (deleted int) { if !rnd.IsUID(clientUID, ClientUID) || limit < 0 { return 0 } found := Sessions{} - if err := Db().Where("auth_id = ?", clientUID). - Order("created_at DESC").Limit(2147483648).Offset(limit). + q := Db().Where("auth_id = ? AND auth_provider = ?", clientUID, authn.ProviderClient.String()) + + if !authMethod.IsDefault() { + q = q.Where("auth_method = ?", authMethod.String()) + } + + if err := q.Order("created_at DESC").Limit(2147483648).Offset(limit). Find(&found).Error; err != nil { event.AuditErr([]string{"failed to fetch client sessions", "%s"}, err) return deleted diff --git a/internal/entity/auth_session_test.go b/internal/entity/auth_session_test.go index 71dd19dca..db6660e26 100644 --- a/internal/entity/auth_session_test.go +++ b/internal/entity/auth_session_test.go @@ -101,10 +101,10 @@ func TestDeleteClientSessions(t *testing.T) { clientUID := "cs5gfen1bgx00000" // Make sure no sessions exist yet and test missing arguments. - assert.Equal(t, 0, DeleteClientSessions("", -1)) - assert.Equal(t, 0, DeleteClientSessions(clientUID, -1)) - assert.Equal(t, 0, DeleteClientSessions(clientUID, 0)) - assert.Equal(t, 0, DeleteClientSessions("", 0)) + assert.Equal(t, 0, DeleteClientSessions("", "", -1)) + assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOAuth2, -1)) + assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOAuth2, 0)) + assert.Equal(t, 0, DeleteClientSessions("", authn.MethodDefault, 0)) // Create 10 client sessions. for i := 0; i < 10; i++ { @@ -121,10 +121,11 @@ func TestDeleteClientSessions(t *testing.T) { } // Check if the expected number of sessions is deleted until none are left. - assert.Equal(t, 0, DeleteClientSessions(clientUID, -1)) - assert.Equal(t, 9, DeleteClientSessions(clientUID, 1)) - assert.Equal(t, 1, DeleteClientSessions(clientUID, 0)) - assert.Equal(t, 0, DeleteClientSessions(clientUID, 0)) + assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOAuth2, -1)) + assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOIDC, 1)) + assert.Equal(t, 9, DeleteClientSessions(clientUID, authn.MethodOAuth2, 1)) + assert.Equal(t, 1, DeleteClientSessions(clientUID, authn.MethodOAuth2, 0)) + assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOAuth2, 0)) } func TestSessionStatusUnauthorized(t *testing.T) {