diff --git a/internal/entity/embedding.go b/internal/entity/embedding.go index c9e433492..4239b11b7 100644 --- a/internal/entity/embedding.go +++ b/internal/entity/embedding.go @@ -27,7 +27,7 @@ func EmbeddingsMidpoint(m Embeddings) (result Embedding, radius float64, count i } if d := clusters.EuclideanDistance(result, emb); d > radius { - radius = d + radius = d + 0.01 } } diff --git a/internal/entity/face.go b/internal/entity/face.go index 45048840d..02f32af24 100644 --- a/internal/entity/face.go +++ b/internal/entity/face.go @@ -26,6 +26,7 @@ type Face struct { CollisionRadius float64 `json:"CollisionRadius" yaml:"CollisionRadius,omitempty"` EmbeddingJSON json.RawMessage `gorm:"type:MEDIUMBLOB;" json:"-" yaml:"EmbeddingJSON,omitempty"` embedding Embedding `gorm:"-"` + MatchedAt *time.Time `json:"MatchedAt" yaml:"MatchedAt,omitempty"` CreatedAt time.Time `json:"CreatedAt" yaml:"CreatedAt,omitempty"` UpdatedAt time.Time `json:"UpdatedAt" yaml:"UpdatedAt,omitempty"` } @@ -82,6 +83,13 @@ func (m *Face) SetEmbeddings(embeddings Embeddings) (err error) { return nil } +// UpdateMatchTime updates the match timestamp. +func (m *Face) UpdateMatchTime() error { + matched := Timestamp() + m.MatchedAt = &matched + return UnscopedDb().Model(m).UpdateColumns(Values{"MatchedAt": m.MatchedAt}).Error +} + // Embedding returns parsed face embedding. func (m *Face) Embedding() Embedding { if len(m.EmbeddingJSON) == 0 { @@ -168,7 +176,10 @@ func (m *Face) ReportCollision(embeddings Embeddings) (reported bool, err error) if err == nil && revise { var revised Markers revised, err = m.ReviseMatches() - log.Infof("faces: revised %d matches after collision", len(revised)) + + if n := len(revised); n > 0 { + log.Infof("faces: revised %d matches after collision", n) + } } return true, err @@ -198,6 +209,30 @@ func (m *Face) ReviseMatches() (revised Markers, err error) { return revised, nil } +// MatchMarkers finds and references matching markers. +func (m *Face) MatchMarkers() error { + var markers Markers + + err := Db(). + Where("face_id = '' AND marker_invalid = 0 AND marker_type = ?", MarkerFace). + Find(&markers).Error + + if err != nil { + log.Debugf("faces: %s (match markers)", err) + return err + } + + for _, marker := range markers { + if ok, _ := m.Match(marker.Embeddings()); !ok { + // Ignore. + } else if _, err = marker.SetFace(m); err != nil { + return err + } + } + + return nil +} + // Save updates the existing or inserts a new face. func (m *Face) Save() error { faceMutex.Lock() diff --git a/internal/entity/face_fixtures.go b/internal/entity/face_fixtures.go index 9ffd81ea8..103141368 100644 --- a/internal/entity/face_fixtures.go +++ b/internal/entity/face_fixtures.go @@ -38,6 +38,7 @@ var FaceFixtures = FaceMap{ SampleRadius: 0, Samples: 1, Collisions: 0, + MatchedAt: &editTime, CreatedAt: Timestamp(), UpdatedAt: Timestamp(), }, diff --git a/internal/entity/marker.go b/internal/entity/marker.go index 9b5f8fb34..a816284a9 100644 --- a/internal/entity/marker.go +++ b/internal/entity/marker.go @@ -120,8 +120,15 @@ func (m *Marker) SetFace(f *Face) (updated bool, err error) { // Any reason we don't want to set a new face for this marker? if m.SubjectSrc != SrcManual || f.SubjectUID == m.SubjectUID { // Don't skip if subject wasn't set manually, or subjects match. - } else if f.SubjectUID != "" { - log.Debugf("faces: ambiguous subjects %s / %s for marker %d", txt.Quote(f.SubjectUID), txt.Quote(m.SubjectUID), m.ID) + } else if f.SubjectUID != "" && m.SubjectUID == "" { + log.Debugf("faces: rejected subject %s for marker %d with unknown subject, source %s", txt.Quote(f.SubjectUID), m.ID, m.SubjectSrc) + return false, nil + } else if reported, err := f.ReportCollision(m.Embeddings()); err != nil { + return false, err + } else if reported { + log.Infof("faces: marker %d (subject %s) collision with %s (subject %s), source %s", m.ID, m.SubjectUID, f.ID, f.SubjectUID, m.SubjectSrc) + return false, nil + } else { return false, nil } @@ -307,18 +314,18 @@ func (m *Marker) GetFace() (f *Face) { if m.FaceID == "" && m.SubjectSrc == SrcManual { if f = NewFace(m.SubjectUID, SrcManual, m.Embeddings()); f == nil { return nil - } else if f = FirstOrCreateFace(f); f == nil { - log.Debugf("marker: invalid face") + } else if err := f.Create(); err != nil { return nil + } else if err := f.MatchMarkers(); err != nil { + log.Errorf("faces: %s (match markers)", err) } + m.Face = f m.FaceID = f.ID - - return f + } else { + m.Face = FindFace(m.FaceID) } - m.Face = FindFace(m.FaceID) - return m.Face } diff --git a/internal/entity/time.go b/internal/entity/time.go index 83f46106e..3515c4c94 100644 --- a/internal/entity/time.go +++ b/internal/entity/time.go @@ -9,6 +9,12 @@ func Timestamp() time.Time { return time.Now().UTC().Round(time.Second) } +// TimestampPointer returns a current timestamp pointer. +func TimestampPointer() *time.Time { + t := Timestamp() + return &t +} + // Seconds converts an int to a duration in seconds. func Seconds(s int) time.Duration { return time.Duration(s) * time.Second diff --git a/internal/entity/time_test.go b/internal/entity/time_test.go index f1458b9a8..f4e4e0361 100644 --- a/internal/entity/time_test.go +++ b/internal/entity/time_test.go @@ -17,6 +17,22 @@ func TestTimestamp(t *testing.T) { } } +func TestTimestampPointer(t *testing.T) { + result := TimestampPointer() + + if result == nil { + t.Fatal("result must not be nil") + } + + if result.Location() != time.UTC { + t.Fatal("timestamp zone must be utc") + } + + if result.After(time.Now().Add(time.Second)) { + t.Fatal("timestamp should be in the past from now") + } +} + func TestSeconds(t *testing.T) { result := Seconds(23) diff --git a/internal/face/face.go b/internal/face/face.go index ca210e625..f8ed925ac 100644 --- a/internal/face/face.go +++ b/internal/face/face.go @@ -40,7 +40,7 @@ import ( var ClusterCore = 4 var ClusterRadius = 0.6 -var SampleThreshold = 25 +var SampleThreshold = 2 * ClusterCore var log = event.Log diff --git a/internal/photoprism/faces.go b/internal/photoprism/faces.go index 01aa0444a..9d584bf77 100644 --- a/internal/photoprism/faces.go +++ b/internal/photoprism/faces.go @@ -4,6 +4,8 @@ import ( "fmt" "runtime/debug" + "github.com/photoprism/photoprism/internal/entity" + "github.com/photoprism/photoprism/internal/config" "github.com/photoprism/photoprism/internal/mutex" "github.com/photoprism/photoprism/internal/query" @@ -23,6 +25,13 @@ func NewFaces(conf *config.Config) *Faces { return instance } +// StartDefault starts face clustering and matching with default options. +func (w *Faces) StartDefault() (err error) { + return w.Start(FacesOptions{ + Force: false, + }) +} + // Start face clustering and matching. func (w *Faces) Start(opt FacesOptions) (err error) { defer func() { @@ -42,78 +51,46 @@ func (w *Faces) Start(opt FacesOptions) (err error) { defer mutex.MainWorker.Stop() - // Skip clustering if index contains no new face markers and force option isn't set. - if n := query.CountNewFaceMarkers(); n < 1 && !opt.Force { - log.Debugf("faces: no new samples") - - var updated int64 - - // Adds and reference known marker subjects. - if affected, err := query.AddMarkerSubjects(); err != nil { - log.Errorf("faces: %s (match markers with subjects)", err) - } else { - updated += affected - } - - // Match markers with known faces. - if affected, err := query.MatchFaceMarkers(); err != nil { - return err - } else { - updated += affected - } - - // Log result. - if updated > 0 { - log.Infof("faces: %d markers updated", updated) - } else { - log.Debug("faces: no changes") - } - - // Remove invalid ids from marker table. - if err := query.CleanInvalidMarkerReferences(); err != nil { - log.Errorf("faces: %s (clean)", err) - } - - // Optimize existing face clusters. - if res, err := w.Optimize(); err != nil { - return err - } else if res.Merged > 0 { - log.Infof("faces: %d clusters merged", res.Merged) - } - - return nil + // Remove invalid reference IDs from markers table. + if removed, err := query.RemoveInvalidMarkerReferences(); err != nil { + log.Errorf("faces: %s (remove invalid references)", err) + } else if removed > 0 { + log.Infof("faces: removed %d invalid references", removed) } else { - log.Infof("faces: %d new samples", n) - } - - var clustersAdded, clustersRemoved int64 - - // Cluster existing face embeddings. - if clustersAdded, clustersRemoved, err = w.Cluster(opt); err != nil { - log.Errorf("faces: %s (cluster)", err) - } - - // Log face clustering results. - if (clustersAdded - clustersRemoved) != 0 { - log.Infof("faces: %d clusters added, %d removed", clustersAdded, clustersRemoved) - } else { - log.Debugf("faces: %d clusters added, %d removed", clustersAdded, clustersRemoved) - } - - // Remove invalid marker references. - if err = query.CleanInvalidMarkerReferences(); err != nil { - log.Errorf("faces: %s (clean)", err) + log.Debugf("faces: no invalid references") } // Optimize existing face clusters. if res, err := w.Optimize(); err != nil { return err } else if res.Merged > 0 { - log.Infof("faces: %d clusters merged", res.Merged) + log.Infof("faces: merged %d clusters", res.Merged) + } else { + log.Debugf("faces: no clusters could be merged") + } + + // Add known marker subjects. + if affected, err := query.AddMarkerSubjects(); err != nil { + log.Errorf("faces: %s (match markers with subjects)", err) + } else if affected > 0 { + log.Infof("faces: added %d known marker subjects", affected) + } else { + log.Debugf("faces: no subjects were missing") + } + + var added entity.Faces + + // Cluster existing face embeddings. + if added, err = w.Cluster(opt); err != nil { + log.Errorf("faces: %s (cluster)", err) + } else if n := len(added); n > 0 { + log.Infof("faces: added %d new faces", n) + } else { + log.Debugf("faces: found no new faces") } // Match markers with faces and subjects. - matches, err := w.Match() + matches, err := w.Match(opt) if err != nil { log.Errorf("faces: %s (match)", err) diff --git a/internal/photoprism/faces_analyze.go b/internal/photoprism/faces_analyze.go index 9eb39807c..6922adddc 100644 --- a/internal/photoprism/faces_analyze.go +++ b/internal/photoprism/faces_analyze.go @@ -8,7 +8,7 @@ import ( // Analyze face embeddings. func (w *Faces) Analyze() (err error) { - if embeddings, err := query.Embeddings(true); err != nil { + if embeddings, err := query.Embeddings(true, false); err != nil { return err } else if samples := len(embeddings); samples == 0 { log.Infof("faces: no samples found") diff --git a/internal/photoprism/faces_cluster.go b/internal/photoprism/faces_cluster.go index 2608a3b2c..fa87fec40 100644 --- a/internal/photoprism/faces_cluster.go +++ b/internal/photoprism/faces_cluster.go @@ -8,29 +8,47 @@ import ( ) // Cluster clusters indexed face embeddings. -func (w *Faces) Cluster(opt FacesOptions) (added int64, removed int64, err error) { - // Fetch and cluster all face embeddings. - embeddings, err := query.Embeddings(false) +func (w *Faces) Cluster(opt FacesOptions) (added entity.Faces, err error) { + if w.Disabled() { + return added, nil + } + + // Skip clustering if index contains no new face markers, and force option isn't set. + if opt.Force { + log.Infof("faces: forced clustering") + } else if n := query.CountNewFaceMarkers(); n < 1 { + log.Debugf("faces: skipping clustering") + return added, nil + } + + // Fetch unclustered face embeddings. + embeddings, err := query.Embeddings(false, true) + + log.Debugf("faces: %d unclustered samples found", len(embeddings)) // Anything that keeps us from doing this? if err != nil { - return added, removed, err + return added, err } else if samples := len(embeddings); samples < opt.SampleThreshold() { - log.Warnf("faces: at least %d samples needed for matching similar faces", face.SampleThreshold) - return added, removed, nil + log.Debugf("faces: at least %d samples needed for clustering", face.SampleThreshold) + return added, nil } else { var c clusters.HardClusterer // See https://dl.photoprism.org/research/ for research on face clustering algorithms. if c, err = clusters.DBSCAN(face.ClusterCore, face.ClusterRadius, w.conf.Workers(), clusters.EuclideanDistance); err != nil { - return added, removed, err + return added, err } else if err = c.Learn(embeddings); err != nil { - return added, removed, err + return added, err } sizes := c.Sizes() - log.Debugf("faces: %d samples in %d clusters", len(embeddings), len(sizes)) + if len(sizes) > 1 { + log.Infof("faces: found %d new clusters", len(sizes)) + } else { + log.Debugf("faces: found no new clusters") + } results := make([]entity.Embeddings, len(sizes)) @@ -48,23 +66,19 @@ func (w *Faces) Cluster(opt FacesOptions) (added int64, removed int64, err error results[n-1] = append(results[n-1], embeddings[i]) } - if removed, err = query.RemoveAnonymousFaceClusters(); err != nil { - log.Errorf("faces: %s", err) - } else if removed > 0 { - log.Debugf("faces: removed %d anonymous clusters", removed) - } - for _, cluster := range results { if f := entity.NewFace("", entity.SrcAuto, cluster); f == nil { log.Errorf("faces: face should not be nil - bug?") } else if err := f.Create(); err == nil { - added++ - log.Tracef("faces: added face %s", f.ID) + added = append(added, *f) + log.Debugf("faces: added cluster %s based on %d samples, radius %f", f.ID, f.Samples, f.SampleRadius) } else if err := f.Updates(entity.Values{"UpdatedAt": entity.Timestamp()}); err != nil { log.Errorf("faces: %s", err) + } else { + log.Debugf("faces: updated cluster %s", f.ID) } } } - return added, removed, nil + return added, nil } diff --git a/internal/photoprism/faces_match.go b/internal/photoprism/faces_match.go index 7b7ecab94..638e15c2a 100644 --- a/internal/photoprism/faces_match.go +++ b/internal/photoprism/faces_match.go @@ -17,11 +17,21 @@ type FacesMatchResult struct { } // Match matches markers with faces and subjects. -func (w *Faces) Match() (result FacesMatchResult, err error) { +func (w *Faces) Match(opt FacesOptions) (result FacesMatchResult, err error) { if w.Disabled() { return result, nil } + // Skip matching if index contains no new face markers, and force option isn't set. + if opt.Force { + log.Infof("faces: forced matching") + } else if n := query.CountUnmatchedFaceMarkers(); n > 0 { + log.Infof("faces: %d unmatched markers", n) + } else { + result.Recognized, err = query.MatchFaceMarkers() + return result, err + } + faces, err := query.Faces(false, "") if err != nil { @@ -96,17 +106,19 @@ func (w *Faces) Match() (result FacesMatchResult, err error) { time.Sleep(50 * time.Millisecond) } - // Update remaining markers based on current matches. + // Update face match timestamps. + for _, m := range faces { + if err := m.UpdateMatchTime(); err != nil { + log.Warnf("faces: %s (update match time)", err) + } + } + + // Update remaining markers based on previous matches. if m, err := query.MatchFaceMarkers(); err != nil { return result, err } else { result.Recognized += m } - // Reset invalid marker data. - if err := query.CleanInvalidMarkerReferences(); err != nil { - return result, err - } - return result, nil } diff --git a/internal/photoprism/faces_options.go b/internal/photoprism/faces_options.go index 771bdc7fa..7df7f8b65 100644 --- a/internal/photoprism/faces_options.go +++ b/internal/photoprism/faces_options.go @@ -16,3 +16,10 @@ func (o FacesOptions) SampleThreshold() int { // Return default. return face.SampleThreshold } + +// FacesOptionsDefault returns new faces options with default values. +func FacesOptionsDefault() FacesOptions { + result := FacesOptions{} + + return result +} diff --git a/internal/photoprism/import.go b/internal/photoprism/import.go index 1821a8b57..261a0ea91 100644 --- a/internal/photoprism/import.go +++ b/internal/photoprism/import.go @@ -241,7 +241,7 @@ func (imp *Import) Start(opt ImportOptions) fs.Done { // Match existing faces if facial recognition is enabled. if w := NewFaces(imp.conf); w.Disabled() { log.Debugf("import: skipping facial recognition") - } else if matches, err := w.Match(); err != nil { + } else if matches, err := w.Match(FacesOptionsDefault()); err != nil { log.Errorf("import: %s", err) } else if matches.Updated > 0 { log.Infof("import: %d markers updated, %d faces recognized, %d unknown", matches.Updated, matches.Recognized, matches.Unknown) diff --git a/internal/photoprism/index.go b/internal/photoprism/index.go index 0ccd77f07..c37d5e368 100644 --- a/internal/photoprism/index.go +++ b/internal/photoprism/index.go @@ -233,7 +233,7 @@ func (ind *Index) Start(opt IndexOptions) fs.Done { // Match existing faces if facial recognition is enabled. if w := NewFaces(ind.conf); w.Disabled() { log.Debugf("index: skipping facial recognition") - } else if matches, err := w.Match(); err != nil { + } else if matches, err := w.Match(FacesOptionsDefault()); err != nil { log.Errorf("index: %s", err) } else if matches.Updated > 0 { log.Infof("index: %d markers updated, %d faces recognized, %d unknown", matches.Updated, matches.Recognized, matches.Unknown) diff --git a/internal/query/faces.go b/internal/query/faces.go index 84024bced..6ddf318d5 100644 --- a/internal/query/faces.go +++ b/internal/query/faces.go @@ -19,7 +19,7 @@ func Faces(knownOnly bool, src string) (result entity.Faces, err error) { if knownOnly { stmt = stmt.Where("subject_uid <> ''").Order("subject_uid, samples DESC") } else { - stmt = stmt.Order("id") + stmt = stmt.Order("samples DESC") } err = stmt.Find(&result).Error @@ -35,16 +35,20 @@ func MatchFaceMarkers() (affected int64, err error) { return affected, err } - for _, match := range faces { + for _, f := range faces { if res := Db().Model(&entity.Marker{}). - Where("face_id = ?", match.ID). + Where("face_id = ?", f.ID). Where("subject_src = ?", entity.SrcAuto). - Where("subject_uid <> ?", match.SubjectUID). - Updates(entity.Values{"SubjectUID": match.SubjectUID}); res.Error != nil { + Where("subject_uid <> ?", f.SubjectUID). + Updates(entity.Values{"SubjectUID": f.SubjectUID}); res.Error != nil { return affected, err } else if res.RowsAffected > 0 { affected += res.RowsAffected } + + if err := f.UpdateMatchTime(); err != nil { + return affected, err + } } return affected, nil @@ -67,27 +71,50 @@ func RemoveAutoFaceClusters() (removed int64, err error) { return res.RowsAffected, res.Error } -// CountNewFaceMarkers returns the number of new face markers in the index. +// CountNewFaceMarkers counts the number of new face markers in the index. func CountNewFaceMarkers() (n int) { var f entity.Face - if err := Db().Where("face_src = ?", entity.SrcAuto).Order("created_at DESC").Take(&f).Error; err != nil { + if err := Db().Where("face_src = ?", entity.SrcAuto).Order("created_at DESC").Limit(1).Take(&f).Error; err != nil { log.Debugf("faces: no existing clusters") } - q := Db().Model(&entity.Markers{}).Where("marker_type = ? AND marker_invalid = 0 AND embeddings_json <> ''", entity.MarkerFace) + q := Db().Model(&entity.Markers{}). + Where("marker_type = ?", entity.MarkerFace). + Where("face_id = '' AND marker_invalid = 0 AND embeddings_json <> ''") if !f.CreatedAt.IsZero() { q = q.Where("created_at > ?", f.CreatedAt) } - if err := q.Order("created_at DESC").Count(&n).Error; err != nil { + if err := q.Count(&n).Error; err != nil { log.Errorf("faces: %s (count new markers)", err) } return n } +// CountUnmatchedFaceMarkers counts the number of unmatched face markers in the index. +func CountUnmatchedFaceMarkers() (n int) { + var f entity.Face + + if err := Db().Where("face_src <> ?", entity.SrcDefault).Order("matched_at ASC").Limit(1).Take(&f).Error; err != nil { + log.Debugf("faces: no unmatched clusters") + } + + q := Db().Model(&entity.Markers{}).Where("marker_type = ? AND marker_invalid = 0 AND embeddings_json <> ''", entity.MarkerFace) + + if f.MatchedAt != nil { + q = q.Where("updated_at > ?", f.MatchedAt) + } + + if err := q.Count(&n).Error; err != nil { + log.Errorf("faces: %s (count unmatched markers)", err) + } + + return n +} + // MergeFaces returns a new face that replaces multiple others. func MergeFaces(merge entity.Faces) (merged *entity.Face, err error) { if len(merge) < 2 { @@ -113,21 +140,9 @@ func MergeFaces(merge entity.Faces) (merged *entity.Face, err error) { return merged, err } - // Find matching markers. - var markers entity.Markers - - if err := Db().Where("face_id = '' AND marker_invalid = 0 AND marker_type = ?", entity.MarkerFace). - Find(&markers).Error; err != nil { - log.Debugf("faces: %s (find matching markers)", err) + // Find and reference additional matching markers. + if err := merged.MatchMarkers(); err != nil { return merged, err - } else { - for _, marker := range markers { - if ok, _ := merged.Match(marker.Embeddings()); !ok { - // Ignore. - } else if _, err := marker.SetFace(merged); err != nil { - return merged, err - } - } } return merged, err diff --git a/internal/query/faces_test.go b/internal/query/faces_test.go index 26ae818d3..ababa9a0e 100644 --- a/internal/query/faces_test.go +++ b/internal/query/faces_test.go @@ -68,3 +68,8 @@ func TestCountNewFaceMarkers(t *testing.T) { n := CountNewFaceMarkers() assert.GreaterOrEqual(t, n, 1) } + +func TestCountUnmatchedFaceMarkers(t *testing.T) { + n := CountUnmatchedFaceMarkers() + assert.GreaterOrEqual(t, n, 0) +} diff --git a/internal/query/markers.go b/internal/query/markers.go index e29046070..c9e3ea5a0 100644 --- a/internal/query/markers.go +++ b/internal/query/markers.go @@ -41,7 +41,7 @@ func Markers(limit, offset int, markerType string, embeddings, subjects bool) (r } // Embeddings returns existing face embeddings. -func Embeddings(single bool) (result entity.Embeddings, err error) { +func Embeddings(single, unclustered bool) (result entity.Embeddings, err error) { var col []string stmt := Db(). @@ -51,6 +51,10 @@ func Embeddings(single bool) (result entity.Embeddings, err error) { Where("embeddings_json <> ''"). Order("id") + if unclustered { + stmt = stmt.Where("face_id = ''") + } + if err := stmt.Pluck("embeddings_json", &col).Error; err != nil { return result, err } @@ -104,25 +108,40 @@ func AddMarkerSubjects() (affected int64, err error) { return affected, err } -// CleanInvalidMarkerReferences deletes invalid reference IDs from the markers table. -func CleanInvalidMarkerReferences() (err error) { - // Reset subject and face relationships for invalid markers. - err = Db(). +// RemoveInvalidMarkerReferences deletes invalid reference IDs from the markers table. +func RemoveInvalidMarkerReferences() (removed int64, err error) { + // Remove subject and face relationships for invalid markers. + if res := Db(). Model(&entity.Marker{}). - Where("marker_invalid = 1"). - UpdateColumns(entity.Values{"subject_uid": "", "subject_src": "", "face_id": ""}). - Error - - if err != nil { - return err + Where("marker_invalid = 1 AND (subject_uid <> '' OR face_id <> '')"). + UpdateColumns(entity.Values{"subject_uid": "", "face_id": ""}); res.Error != nil { + return removed, res.Error + } else { + removed += res.RowsAffected } - // Reset invalid face IDs. - return Db(). + // Remove invalid face IDs. + if res := Db(). Model(&entity.Marker{}). + Where("marker_type = ?", entity.MarkerFace). Where(fmt.Sprintf("face_id <> '' AND face_id NOT IN (SELECT id FROM %s)", entity.Face{}.TableName())). - UpdateColumns(entity.Values{"face_id": ""}). - Error + UpdateColumns(entity.Values{"face_id": ""}); res.Error != nil { + return removed, res.Error + } else { + removed += res.RowsAffected + } + + // Remove invalid subject UIDs. + if res := Db(). + Model(&entity.Marker{}). + Where(fmt.Sprintf("subject_uid <> '' AND subject_uid NOT IN (SELECT subject_uid FROM %s)", entity.Subject{}.TableName())). + UpdateColumns(entity.Values{"subject_uid": ""}); res.Error != nil { + return removed, res.Error + } else { + removed += res.RowsAffected + } + + return removed, nil } // ResetFaceMarkerMatches removes automatically added subject and face references from the markers table. diff --git a/internal/query/markers_test.go b/internal/query/markers_test.go index 75731f396..57a94e0cd 100644 --- a/internal/query/markers_test.go +++ b/internal/query/markers_test.go @@ -46,7 +46,7 @@ func TestMarkers(t *testing.T) { } func TestEmbeddings(t *testing.T) { - results, err := Embeddings(false) + results, err := Embeddings(false, false) if err != nil { t.Fatal(err) @@ -66,6 +66,9 @@ func TestAddMarkerSubjects(t *testing.T) { assert.GreaterOrEqual(t, affected, int64(1)) } -func TestCleanInvalidMarkerReferences(t *testing.T) { - assert.NoError(t, CleanInvalidMarkerReferences()) +func TestRemoveInvalidMarkerReferences(t *testing.T) { + affected, err := RemoveInvalidMarkerReferences() + + assert.NoError(t, err) + assert.GreaterOrEqual(t, affected, int64(1)) }