These changes ensure that OAuth2 clients cannot create an unlimited number of access tokens (sessions) with their client credentials. Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
parent
42fbf70dcf
commit
ec13ccb6d5
8 changed files with 98 additions and 20 deletions
2
Makefile
2
Makefile
|
@ -163,6 +163,8 @@ stop:
|
||||||
./photoprism stop
|
./photoprism stop
|
||||||
terminal:
|
terminal:
|
||||||
$(DOCKER_COMPOSE) exec -u $(UID) photoprism bash
|
$(DOCKER_COMPOSE) exec -u $(UID) photoprism bash
|
||||||
|
mariadb:
|
||||||
|
$(DOCKER_COMPOSE) exec mariadb mariadb -uroot -pphotoprism photoprism
|
||||||
rootshell: root-terminal
|
rootshell: root-terminal
|
||||||
root-terminal:
|
root-terminal:
|
||||||
$(DOCKER_COMPOSE) exec -u root photoprism bash
|
$(DOCKER_COMPOSE) exec -u root photoprism bash
|
||||||
|
|
|
@ -3,8 +3,8 @@ package api
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/dustin/go-humanize/english"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/photoprism/photoprism/pkg/authn"
|
|
||||||
|
|
||||||
"github.com/photoprism/photoprism/internal/acl"
|
"github.com/photoprism/photoprism/internal/acl"
|
||||||
"github.com/photoprism/photoprism/internal/entity"
|
"github.com/photoprism/photoprism/internal/entity"
|
||||||
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/photoprism/photoprism/internal/get"
|
"github.com/photoprism/photoprism/internal/get"
|
||||||
"github.com/photoprism/photoprism/internal/i18n"
|
"github.com/photoprism/photoprism/internal/i18n"
|
||||||
"github.com/photoprism/photoprism/internal/server/limiter"
|
"github.com/photoprism/photoprism/internal/server/limiter"
|
||||||
|
"github.com/photoprism/photoprism/pkg/authn"
|
||||||
"github.com/photoprism/photoprism/pkg/clean"
|
"github.com/photoprism/photoprism/pkg/clean"
|
||||||
"github.com/photoprism/photoprism/pkg/rnd"
|
"github.com/photoprism/photoprism/pkg/rnd"
|
||||||
)
|
)
|
||||||
|
@ -28,7 +29,7 @@ func CreateOAuthToken(router *gin.RouterGroup) {
|
||||||
|
|
||||||
// Abort if running in public mode.
|
// Abort if running in public mode.
|
||||||
if get.Config().Public() {
|
if get.Config().Public() {
|
||||||
event.AuditErr([]string{clientIP, "create client session in public mode", "denied"})
|
event.AuditErr([]string{clientIP, "create client session", "disabled in public mode"})
|
||||||
Abort(c, http.StatusForbidden, i18n.ErrForbidden)
|
Abort(c, http.StatusForbidden, i18n.ErrForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -42,14 +43,14 @@ func CreateOAuthToken(router *gin.RouterGroup) {
|
||||||
f.ClientID = clientId
|
f.ClientID = clientId
|
||||||
f.ClientSecret = clientSecret
|
f.ClientSecret = clientSecret
|
||||||
} else if err = c.Bind(&f); err != nil {
|
} else if err = c.Bind(&f); err != nil {
|
||||||
event.AuditWarn([]string{clientIP, "oauth", "%s"}, err)
|
event.AuditWarn([]string{clientIP, "create client session", "%s"}, err)
|
||||||
AbortBadRequest(c)
|
AbortBadRequest(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the credentials for completeness and the correct format.
|
// Check the credentials for completeness and the correct format.
|
||||||
if err = f.Validate(); err != nil {
|
if err = f.Validate(); err != nil {
|
||||||
event.AuditWarn([]string{clientIP, "oauth", "%s"}, err)
|
event.AuditWarn([]string{clientIP, "create client session", "%s"}, err)
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)})
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -70,7 +71,7 @@ func CreateOAuthToken(router *gin.RouterGroup) {
|
||||||
limiter.Login.Reserve(clientIP)
|
limiter.Login.Reserve(clientIP)
|
||||||
return
|
return
|
||||||
} else if !client.AuthEnabled {
|
} else if !client.AuthEnabled {
|
||||||
event.AuditWarn([]string{clientIP, "client %s", "create session", "disabled"}, f.ClientID)
|
event.AuditWarn([]string{clientIP, "client %s", "create session", "authentication disabled"}, f.ClientID)
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)})
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)})
|
||||||
return
|
return
|
||||||
} else if method := client.Method(); !method.IsDefault() && method != authn.MethodOAuth2 {
|
} else if method := client.Method(); !method.IsDefault() && method != authn.MethodOAuth2 {
|
||||||
|
@ -87,8 +88,6 @@ func CreateOAuthToken(router *gin.RouterGroup) {
|
||||||
// Create new client session.
|
// Create new client session.
|
||||||
sess := client.NewSession(c)
|
sess := client.NewSession(c)
|
||||||
|
|
||||||
// TODO: Enforce limit for maximum number of access tokens.
|
|
||||||
|
|
||||||
// Try to log in and save session if successful.
|
// Try to log in and save session if successful.
|
||||||
if sess, err = get.Session().Save(sess); err != nil {
|
if sess, err = get.Session().Save(sess); err != nil {
|
||||||
event.AuditErr([]string{clientIP, "client %s", "create session", "%s"}, f.ClientID, err)
|
event.AuditErr([]string{clientIP, "client %s", "create session", "%s"}, f.ClientID, err)
|
||||||
|
@ -102,6 +101,11 @@ func CreateOAuthToken(router *gin.RouterGroup) {
|
||||||
event.AuditInfo([]string{clientIP, "client %s", "session %s", "created"}, f.ClientID, sess.RefID)
|
event.AuditInfo([]string{clientIP, "client %s", "session %s", "created"}, f.ClientID, sess.RefID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deletes old client sessions above the configured limit.
|
||||||
|
if deleted := client.EnforceAuthTokenLimit(); deleted > 0 {
|
||||||
|
event.AuditInfo([]string{clientIP, "client %s", "%s deleted"}, f.ClientID, english.Plural(deleted, "old session", "old sessions"))
|
||||||
|
}
|
||||||
|
|
||||||
// Response includes access token, token type, and token lifetime.
|
// Response includes access token, token type, and token lifetime.
|
||||||
data := gin.H{
|
data := gin.H{
|
||||||
"access_token": sess.AuthToken(),
|
"access_token": sess.AuthToken(),
|
||||||
|
@ -125,7 +129,7 @@ func DeleteOAuthToken(router *gin.RouterGroup) {
|
||||||
|
|
||||||
// Abort if running in public mode.
|
// Abort if running in public mode.
|
||||||
if get.Config().Public() {
|
if get.Config().Public() {
|
||||||
event.AuditErr([]string{clientIP, "delete client session in public mode", "denied"})
|
event.AuditErr([]string{clientIP, "delete client session", "disabled in public mode"})
|
||||||
Abort(c, http.StatusForbidden, i18n.ErrForbidden)
|
Abort(c, http.StatusForbidden, i18n.ErrForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -256,9 +256,6 @@ func (m *Client) UpdateLastActive() *Client {
|
||||||
|
|
||||||
// NewSession creates a new client session.
|
// NewSession creates a new client session.
|
||||||
func (m *Client) NewSession(c *gin.Context) *Session {
|
func (m *Client) NewSession(c *gin.Context) *Session {
|
||||||
// Update activity timestamp.
|
|
||||||
m.UpdateLastActive()
|
|
||||||
|
|
||||||
// Create, initialize, and return new session.
|
// Create, initialize, and return new session.
|
||||||
sess := NewSession(m.AuthExpires, 0).SetContext(c)
|
sess := NewSession(m.AuthExpires, 0).SetContext(c)
|
||||||
sess.AuthID = m.UID()
|
sess.AuthID = m.UID()
|
||||||
|
@ -270,6 +267,19 @@ func (m *Client) NewSession(c *gin.Context) *Session {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EnforceAuthTokenLimit deletes client sessions above the configured limit and returns the number of deleted sessions.
|
||||||
|
func (m *Client) EnforceAuthTokenLimit() (deleted int) {
|
||||||
|
if m == nil {
|
||||||
|
return 0
|
||||||
|
} else if !m.HasUID() {
|
||||||
|
return 0
|
||||||
|
} else if m.AuthTokens < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return DeleteClientSessions(m.ClientUID, m.AuthTokens)
|
||||||
|
}
|
||||||
|
|
||||||
// Expires returns the auth expiration duration.
|
// Expires returns the auth expiration duration.
|
||||||
func (m *Client) Expires() time.Duration {
|
func (m *Client) Expires() time.Duration {
|
||||||
return time.Duration(m.AuthExpires) * time.Second
|
return time.Duration(m.AuthExpires) * time.Second
|
||||||
|
|
|
@ -93,18 +93,44 @@ func (m *Session) Expires(t time.Time) *Session {
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteExpiredSessions deletes expired sessions.
|
// DeleteExpiredSessions deletes all expired sessions.
|
||||||
func DeleteExpiredSessions() (deleted int) {
|
func DeleteExpiredSessions() (deleted int) {
|
||||||
expired := Sessions{}
|
found := Sessions{}
|
||||||
|
|
||||||
if err := Db().Where("sess_expires > 0 AND sess_expires < ?", UnixTime()).Find(&expired).Error; err != nil {
|
if err := Db().Where("sess_expires > 0 AND sess_expires < ?", UnixTime()).Find(&found).Error; err != nil {
|
||||||
event.AuditErr([]string{"failed to fetch sessions sessions", "%s"}, err)
|
event.AuditErr([]string{"failed to fetch expired sessions", "%s"}, err)
|
||||||
return deleted
|
return deleted
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, s := range expired {
|
for _, sess := range found {
|
||||||
if err := s.Delete(); err != nil {
|
if err := sess.Delete(); err != nil {
|
||||||
event.AuditErr([]string{s.IP(), "session %s", "failed to delete", "%s"}, s.RefID, err)
|
event.AuditErr([]string{sess.IP(), "session %s", "failed to delete", "%s"}, sess.RefID, err)
|
||||||
|
} else {
|
||||||
|
deleted++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return deleted
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteClientSessions deletes client sessions above the specified limit.
|
||||||
|
func DeleteClientSessions(clientUID string, limit int64) (deleted int) {
|
||||||
|
if !rnd.IsUID(clientUID, ClientUID) || limit < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
found := Sessions{}
|
||||||
|
|
||||||
|
if err := Db().Where("auth_id = ?", clientUID).
|
||||||
|
Order("created_at DESC").Limit(2147483648).Offset(limit).
|
||||||
|
Find(&found).Error; err != nil {
|
||||||
|
event.AuditErr([]string{"failed to fetch client sessions", "%s"}, err)
|
||||||
|
return deleted
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sess := range found {
|
||||||
|
if err := sess.Delete(); err != nil {
|
||||||
|
event.AuditErr([]string{sess.IP(), "session %s", "failed to delete", "%s"}, sess.RefID, err)
|
||||||
} else {
|
} else {
|
||||||
deleted++
|
deleted++
|
||||||
}
|
}
|
||||||
|
@ -616,7 +642,7 @@ func (m *Session) Invalid() bool {
|
||||||
|
|
||||||
// Valid checks whether the session belongs to a registered user or a visitor with shares.
|
// Valid checks whether the session belongs to a registered user or a visitor with shares.
|
||||||
func (m *Session) Valid() bool {
|
func (m *Session) Valid() bool {
|
||||||
if m.AuthMethod == authn.MethodOAuth2.String() {
|
if m.IsClient() {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -97,6 +97,36 @@ func TestDeleteExpiredSessions(t *testing.T) {
|
||||||
assert.Equal(t, 1, DeleteExpiredSessions())
|
assert.Equal(t, 1, DeleteExpiredSessions())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeleteClientSessions(t *testing.T) {
|
||||||
|
clientUID := "cs5gfen1bgx00000"
|
||||||
|
|
||||||
|
// Make sure no sessions exist yet and test missing arguments.
|
||||||
|
assert.Equal(t, 0, DeleteClientSessions("", -1))
|
||||||
|
assert.Equal(t, 0, DeleteClientSessions(clientUID, -1))
|
||||||
|
assert.Equal(t, 0, DeleteClientSessions(clientUID, 0))
|
||||||
|
assert.Equal(t, 0, DeleteClientSessions("", 0))
|
||||||
|
|
||||||
|
// Create 10 client sessions.
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
sess := NewSession(3600, 0)
|
||||||
|
sess.SetClientIP(UnknownIP)
|
||||||
|
sess.AuthID = clientUID
|
||||||
|
sess.AuthProvider = authn.ProviderClient.String()
|
||||||
|
sess.AuthMethod = authn.MethodOAuth2.String()
|
||||||
|
sess.AuthScope = "*"
|
||||||
|
|
||||||
|
if err := sess.Save(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the expected number of sessions is deleted until none are left.
|
||||||
|
assert.Equal(t, 0, DeleteClientSessions(clientUID, -1))
|
||||||
|
assert.Equal(t, 9, DeleteClientSessions(clientUID, 1))
|
||||||
|
assert.Equal(t, 1, DeleteClientSessions(clientUID, 0))
|
||||||
|
assert.Equal(t, 0, DeleteClientSessions(clientUID, 0))
|
||||||
|
}
|
||||||
|
|
||||||
func TestSessionStatusUnauthorized(t *testing.T) {
|
func TestSessionStatusUnauthorized(t *testing.T) {
|
||||||
m := SessionStatusUnauthorized()
|
m := SessionStatusUnauthorized()
|
||||||
assert.Equal(t, 401, m.Status)
|
assert.Equal(t, 401, m.Status)
|
||||||
|
|
|
@ -14,8 +14,14 @@ func (s *Session) Save(m *entity.Session) (*entity.Session, error) {
|
||||||
return nil, fmt.Errorf("session is nil")
|
return nil, fmt.Errorf("session is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update last active timestamp.
|
||||||
|
m.LastActive = entity.UnixTime()
|
||||||
|
|
||||||
// Save session.
|
// Save session.
|
||||||
return m.UpdateLastActive(), m.Save()
|
err := m.Save()
|
||||||
|
|
||||||
|
// Return session.
|
||||||
|
return m, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create initializes a new client session and returns it.
|
// Create initializes a new client session and returns it.
|
||||||
|
|
Loading…
Reference in a new issue