People: Improve face clustering and matching #22
This commit is contained in:
parent
5442c04c75
commit
fefe70f9a4
|
@ -27,7 +27,7 @@ func EmbeddingsMidpoint(m Embeddings) (result Embedding, radius float64, count i
|
|||
}
|
||||
|
||||
if d := clusters.EuclideanDistance(result, emb); d > radius {
|
||||
radius = d
|
||||
radius = d + 0.01
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@ type Face struct {
|
|||
CollisionRadius float64 `json:"CollisionRadius" yaml:"CollisionRadius,omitempty"`
|
||||
EmbeddingJSON json.RawMessage `gorm:"type:MEDIUMBLOB;" json:"-" yaml:"EmbeddingJSON,omitempty"`
|
||||
embedding Embedding `gorm:"-"`
|
||||
MatchedAt *time.Time `json:"MatchedAt" yaml:"MatchedAt,omitempty"`
|
||||
CreatedAt time.Time `json:"CreatedAt" yaml:"CreatedAt,omitempty"`
|
||||
UpdatedAt time.Time `json:"UpdatedAt" yaml:"UpdatedAt,omitempty"`
|
||||
}
|
||||
|
@ -82,6 +83,13 @@ func (m *Face) SetEmbeddings(embeddings Embeddings) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// UpdateMatchTime updates the match timestamp.
|
||||
func (m *Face) UpdateMatchTime() error {
|
||||
matched := Timestamp()
|
||||
m.MatchedAt = &matched
|
||||
return UnscopedDb().Model(m).UpdateColumns(Values{"MatchedAt": m.MatchedAt}).Error
|
||||
}
|
||||
|
||||
// Embedding returns parsed face embedding.
|
||||
func (m *Face) Embedding() Embedding {
|
||||
if len(m.EmbeddingJSON) == 0 {
|
||||
|
@ -168,7 +176,10 @@ func (m *Face) ReportCollision(embeddings Embeddings) (reported bool, err error)
|
|||
if err == nil && revise {
|
||||
var revised Markers
|
||||
revised, err = m.ReviseMatches()
|
||||
log.Infof("faces: revised %d matches after collision", len(revised))
|
||||
|
||||
if n := len(revised); n > 0 {
|
||||
log.Infof("faces: revised %d matches after collision", n)
|
||||
}
|
||||
}
|
||||
|
||||
return true, err
|
||||
|
@ -198,6 +209,30 @@ func (m *Face) ReviseMatches() (revised Markers, err error) {
|
|||
return revised, nil
|
||||
}
|
||||
|
||||
// MatchMarkers finds and references matching markers.
|
||||
func (m *Face) MatchMarkers() error {
|
||||
var markers Markers
|
||||
|
||||
err := Db().
|
||||
Where("face_id = '' AND marker_invalid = 0 AND marker_type = ?", MarkerFace).
|
||||
Find(&markers).Error
|
||||
|
||||
if err != nil {
|
||||
log.Debugf("faces: %s (match markers)", err)
|
||||
return err
|
||||
}
|
||||
|
||||
for _, marker := range markers {
|
||||
if ok, _ := m.Match(marker.Embeddings()); !ok {
|
||||
// Ignore.
|
||||
} else if _, err = marker.SetFace(m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save updates the existing or inserts a new face.
|
||||
func (m *Face) Save() error {
|
||||
faceMutex.Lock()
|
||||
|
|
|
@ -38,6 +38,7 @@ var FaceFixtures = FaceMap{
|
|||
SampleRadius: 0,
|
||||
Samples: 1,
|
||||
Collisions: 0,
|
||||
MatchedAt: &editTime,
|
||||
CreatedAt: Timestamp(),
|
||||
UpdatedAt: Timestamp(),
|
||||
},
|
||||
|
|
|
@ -120,8 +120,15 @@ func (m *Marker) SetFace(f *Face) (updated bool, err error) {
|
|||
// Any reason we don't want to set a new face for this marker?
|
||||
if m.SubjectSrc != SrcManual || f.SubjectUID == m.SubjectUID {
|
||||
// Don't skip if subject wasn't set manually, or subjects match.
|
||||
} else if f.SubjectUID != "" {
|
||||
log.Debugf("faces: ambiguous subjects %s / %s for marker %d", txt.Quote(f.SubjectUID), txt.Quote(m.SubjectUID), m.ID)
|
||||
} else if f.SubjectUID != "" && m.SubjectUID == "" {
|
||||
log.Debugf("faces: rejected subject %s for marker %d with unknown subject, source %s", txt.Quote(f.SubjectUID), m.ID, m.SubjectSrc)
|
||||
return false, nil
|
||||
} else if reported, err := f.ReportCollision(m.Embeddings()); err != nil {
|
||||
return false, err
|
||||
} else if reported {
|
||||
log.Infof("faces: marker %d (subject %s) collision with %s (subject %s), source %s", m.ID, m.SubjectUID, f.ID, f.SubjectUID, m.SubjectSrc)
|
||||
return false, nil
|
||||
} else {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
@ -307,18 +314,18 @@ func (m *Marker) GetFace() (f *Face) {
|
|||
if m.FaceID == "" && m.SubjectSrc == SrcManual {
|
||||
if f = NewFace(m.SubjectUID, SrcManual, m.Embeddings()); f == nil {
|
||||
return nil
|
||||
} else if f = FirstOrCreateFace(f); f == nil {
|
||||
log.Debugf("marker: invalid face")
|
||||
} else if err := f.Create(); err != nil {
|
||||
return nil
|
||||
} else if err := f.MatchMarkers(); err != nil {
|
||||
log.Errorf("faces: %s (match markers)", err)
|
||||
}
|
||||
|
||||
m.Face = f
|
||||
m.FaceID = f.ID
|
||||
|
||||
return f
|
||||
} else {
|
||||
m.Face = FindFace(m.FaceID)
|
||||
}
|
||||
|
||||
m.Face = FindFace(m.FaceID)
|
||||
|
||||
return m.Face
|
||||
}
|
||||
|
||||
|
|
|
@ -9,6 +9,12 @@ func Timestamp() time.Time {
|
|||
return time.Now().UTC().Round(time.Second)
|
||||
}
|
||||
|
||||
// TimestampPointer returns a current timestamp pointer.
|
||||
func TimestampPointer() *time.Time {
|
||||
t := Timestamp()
|
||||
return &t
|
||||
}
|
||||
|
||||
// Seconds converts an int to a duration in seconds.
|
||||
func Seconds(s int) time.Duration {
|
||||
return time.Duration(s) * time.Second
|
||||
|
|
|
@ -17,6 +17,22 @@ func TestTimestamp(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTimestampPointer(t *testing.T) {
|
||||
result := TimestampPointer()
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("result must not be nil")
|
||||
}
|
||||
|
||||
if result.Location() != time.UTC {
|
||||
t.Fatal("timestamp zone must be utc")
|
||||
}
|
||||
|
||||
if result.After(time.Now().Add(time.Second)) {
|
||||
t.Fatal("timestamp should be in the past from now")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeconds(t *testing.T) {
|
||||
result := Seconds(23)
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ import (
|
|||
|
||||
var ClusterCore = 4
|
||||
var ClusterRadius = 0.6
|
||||
var SampleThreshold = 25
|
||||
var SampleThreshold = 2 * ClusterCore
|
||||
|
||||
var log = event.Log
|
||||
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"fmt"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/mutex"
|
||||
"github.com/photoprism/photoprism/internal/query"
|
||||
|
@ -23,6 +25,13 @@ func NewFaces(conf *config.Config) *Faces {
|
|||
return instance
|
||||
}
|
||||
|
||||
// StartDefault starts face clustering and matching with default options.
|
||||
func (w *Faces) StartDefault() (err error) {
|
||||
return w.Start(FacesOptions{
|
||||
Force: false,
|
||||
})
|
||||
}
|
||||
|
||||
// Start face clustering and matching.
|
||||
func (w *Faces) Start(opt FacesOptions) (err error) {
|
||||
defer func() {
|
||||
|
@ -42,78 +51,46 @@ func (w *Faces) Start(opt FacesOptions) (err error) {
|
|||
|
||||
defer mutex.MainWorker.Stop()
|
||||
|
||||
// Skip clustering if index contains no new face markers and force option isn't set.
|
||||
if n := query.CountNewFaceMarkers(); n < 1 && !opt.Force {
|
||||
log.Debugf("faces: no new samples")
|
||||
|
||||
var updated int64
|
||||
|
||||
// Adds and reference known marker subjects.
|
||||
if affected, err := query.AddMarkerSubjects(); err != nil {
|
||||
log.Errorf("faces: %s (match markers with subjects)", err)
|
||||
} else {
|
||||
updated += affected
|
||||
}
|
||||
|
||||
// Match markers with known faces.
|
||||
if affected, err := query.MatchFaceMarkers(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
updated += affected
|
||||
}
|
||||
|
||||
// Log result.
|
||||
if updated > 0 {
|
||||
log.Infof("faces: %d markers updated", updated)
|
||||
} else {
|
||||
log.Debug("faces: no changes")
|
||||
}
|
||||
|
||||
// Remove invalid ids from marker table.
|
||||
if err := query.CleanInvalidMarkerReferences(); err != nil {
|
||||
log.Errorf("faces: %s (clean)", err)
|
||||
}
|
||||
|
||||
// Optimize existing face clusters.
|
||||
if res, err := w.Optimize(); err != nil {
|
||||
return err
|
||||
} else if res.Merged > 0 {
|
||||
log.Infof("faces: %d clusters merged", res.Merged)
|
||||
}
|
||||
|
||||
return nil
|
||||
// Remove invalid reference IDs from markers table.
|
||||
if removed, err := query.RemoveInvalidMarkerReferences(); err != nil {
|
||||
log.Errorf("faces: %s (remove invalid references)", err)
|
||||
} else if removed > 0 {
|
||||
log.Infof("faces: removed %d invalid references", removed)
|
||||
} else {
|
||||
log.Infof("faces: %d new samples", n)
|
||||
}
|
||||
|
||||
var clustersAdded, clustersRemoved int64
|
||||
|
||||
// Cluster existing face embeddings.
|
||||
if clustersAdded, clustersRemoved, err = w.Cluster(opt); err != nil {
|
||||
log.Errorf("faces: %s (cluster)", err)
|
||||
}
|
||||
|
||||
// Log face clustering results.
|
||||
if (clustersAdded - clustersRemoved) != 0 {
|
||||
log.Infof("faces: %d clusters added, %d removed", clustersAdded, clustersRemoved)
|
||||
} else {
|
||||
log.Debugf("faces: %d clusters added, %d removed", clustersAdded, clustersRemoved)
|
||||
}
|
||||
|
||||
// Remove invalid marker references.
|
||||
if err = query.CleanInvalidMarkerReferences(); err != nil {
|
||||
log.Errorf("faces: %s (clean)", err)
|
||||
log.Debugf("faces: no invalid references")
|
||||
}
|
||||
|
||||
// Optimize existing face clusters.
|
||||
if res, err := w.Optimize(); err != nil {
|
||||
return err
|
||||
} else if res.Merged > 0 {
|
||||
log.Infof("faces: %d clusters merged", res.Merged)
|
||||
log.Infof("faces: merged %d clusters", res.Merged)
|
||||
} else {
|
||||
log.Debugf("faces: no clusters could be merged")
|
||||
}
|
||||
|
||||
// Add known marker subjects.
|
||||
if affected, err := query.AddMarkerSubjects(); err != nil {
|
||||
log.Errorf("faces: %s (match markers with subjects)", err)
|
||||
} else if affected > 0 {
|
||||
log.Infof("faces: added %d known marker subjects", affected)
|
||||
} else {
|
||||
log.Debugf("faces: no subjects were missing")
|
||||
}
|
||||
|
||||
var added entity.Faces
|
||||
|
||||
// Cluster existing face embeddings.
|
||||
if added, err = w.Cluster(opt); err != nil {
|
||||
log.Errorf("faces: %s (cluster)", err)
|
||||
} else if n := len(added); n > 0 {
|
||||
log.Infof("faces: added %d new faces", n)
|
||||
} else {
|
||||
log.Debugf("faces: found no new faces")
|
||||
}
|
||||
|
||||
// Match markers with faces and subjects.
|
||||
matches, err := w.Match()
|
||||
matches, err := w.Match(opt)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("faces: %s (match)", err)
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
// Analyze face embeddings.
|
||||
func (w *Faces) Analyze() (err error) {
|
||||
if embeddings, err := query.Embeddings(true); err != nil {
|
||||
if embeddings, err := query.Embeddings(true, false); err != nil {
|
||||
return err
|
||||
} else if samples := len(embeddings); samples == 0 {
|
||||
log.Infof("faces: no samples found")
|
||||
|
|
|
@ -8,29 +8,47 @@ import (
|
|||
)
|
||||
|
||||
// Cluster clusters indexed face embeddings.
|
||||
func (w *Faces) Cluster(opt FacesOptions) (added int64, removed int64, err error) {
|
||||
// Fetch and cluster all face embeddings.
|
||||
embeddings, err := query.Embeddings(false)
|
||||
func (w *Faces) Cluster(opt FacesOptions) (added entity.Faces, err error) {
|
||||
if w.Disabled() {
|
||||
return added, nil
|
||||
}
|
||||
|
||||
// Skip clustering if index contains no new face markers, and force option isn't set.
|
||||
if opt.Force {
|
||||
log.Infof("faces: forced clustering")
|
||||
} else if n := query.CountNewFaceMarkers(); n < 1 {
|
||||
log.Debugf("faces: skipping clustering")
|
||||
return added, nil
|
||||
}
|
||||
|
||||
// Fetch unclustered face embeddings.
|
||||
embeddings, err := query.Embeddings(false, true)
|
||||
|
||||
log.Debugf("faces: %d unclustered samples found", len(embeddings))
|
||||
|
||||
// Anything that keeps us from doing this?
|
||||
if err != nil {
|
||||
return added, removed, err
|
||||
return added, err
|
||||
} else if samples := len(embeddings); samples < opt.SampleThreshold() {
|
||||
log.Warnf("faces: at least %d samples needed for matching similar faces", face.SampleThreshold)
|
||||
return added, removed, nil
|
||||
log.Debugf("faces: at least %d samples needed for clustering", face.SampleThreshold)
|
||||
return added, nil
|
||||
} else {
|
||||
var c clusters.HardClusterer
|
||||
|
||||
// See https://dl.photoprism.org/research/ for research on face clustering algorithms.
|
||||
if c, err = clusters.DBSCAN(face.ClusterCore, face.ClusterRadius, w.conf.Workers(), clusters.EuclideanDistance); err != nil {
|
||||
return added, removed, err
|
||||
return added, err
|
||||
} else if err = c.Learn(embeddings); err != nil {
|
||||
return added, removed, err
|
||||
return added, err
|
||||
}
|
||||
|
||||
sizes := c.Sizes()
|
||||
|
||||
log.Debugf("faces: %d samples in %d clusters", len(embeddings), len(sizes))
|
||||
if len(sizes) > 1 {
|
||||
log.Infof("faces: found %d new clusters", len(sizes))
|
||||
} else {
|
||||
log.Debugf("faces: found no new clusters")
|
||||
}
|
||||
|
||||
results := make([]entity.Embeddings, len(sizes))
|
||||
|
||||
|
@ -48,23 +66,19 @@ func (w *Faces) Cluster(opt FacesOptions) (added int64, removed int64, err error
|
|||
results[n-1] = append(results[n-1], embeddings[i])
|
||||
}
|
||||
|
||||
if removed, err = query.RemoveAnonymousFaceClusters(); err != nil {
|
||||
log.Errorf("faces: %s", err)
|
||||
} else if removed > 0 {
|
||||
log.Debugf("faces: removed %d anonymous clusters", removed)
|
||||
}
|
||||
|
||||
for _, cluster := range results {
|
||||
if f := entity.NewFace("", entity.SrcAuto, cluster); f == nil {
|
||||
log.Errorf("faces: face should not be nil - bug?")
|
||||
} else if err := f.Create(); err == nil {
|
||||
added++
|
||||
log.Tracef("faces: added face %s", f.ID)
|
||||
added = append(added, *f)
|
||||
log.Debugf("faces: added cluster %s based on %d samples, radius %f", f.ID, f.Samples, f.SampleRadius)
|
||||
} else if err := f.Updates(entity.Values{"UpdatedAt": entity.Timestamp()}); err != nil {
|
||||
log.Errorf("faces: %s", err)
|
||||
} else {
|
||||
log.Debugf("faces: updated cluster %s", f.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return added, removed, nil
|
||||
return added, nil
|
||||
}
|
||||
|
|
|
@ -17,11 +17,21 @@ type FacesMatchResult struct {
|
|||
}
|
||||
|
||||
// Match matches markers with faces and subjects.
|
||||
func (w *Faces) Match() (result FacesMatchResult, err error) {
|
||||
func (w *Faces) Match(opt FacesOptions) (result FacesMatchResult, err error) {
|
||||
if w.Disabled() {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Skip matching if index contains no new face markers, and force option isn't set.
|
||||
if opt.Force {
|
||||
log.Infof("faces: forced matching")
|
||||
} else if n := query.CountUnmatchedFaceMarkers(); n > 0 {
|
||||
log.Infof("faces: %d unmatched markers", n)
|
||||
} else {
|
||||
result.Recognized, err = query.MatchFaceMarkers()
|
||||
return result, err
|
||||
}
|
||||
|
||||
faces, err := query.Faces(false, "")
|
||||
|
||||
if err != nil {
|
||||
|
@ -96,17 +106,19 @@ func (w *Faces) Match() (result FacesMatchResult, err error) {
|
|||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Update remaining markers based on current matches.
|
||||
// Update face match timestamps.
|
||||
for _, m := range faces {
|
||||
if err := m.UpdateMatchTime(); err != nil {
|
||||
log.Warnf("faces: %s (update match time)", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update remaining markers based on previous matches.
|
||||
if m, err := query.MatchFaceMarkers(); err != nil {
|
||||
return result, err
|
||||
} else {
|
||||
result.Recognized += m
|
||||
}
|
||||
|
||||
// Reset invalid marker data.
|
||||
if err := query.CleanInvalidMarkerReferences(); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
@ -16,3 +16,10 @@ func (o FacesOptions) SampleThreshold() int {
|
|||
// Return default.
|
||||
return face.SampleThreshold
|
||||
}
|
||||
|
||||
// FacesOptionsDefault returns new faces options with default values.
|
||||
func FacesOptionsDefault() FacesOptions {
|
||||
result := FacesOptions{}
|
||||
|
||||
return result
|
||||
}
|
||||
|
|
|
@ -241,7 +241,7 @@ func (imp *Import) Start(opt ImportOptions) fs.Done {
|
|||
// Match existing faces if facial recognition is enabled.
|
||||
if w := NewFaces(imp.conf); w.Disabled() {
|
||||
log.Debugf("import: skipping facial recognition")
|
||||
} else if matches, err := w.Match(); err != nil {
|
||||
} else if matches, err := w.Match(FacesOptionsDefault()); err != nil {
|
||||
log.Errorf("import: %s", err)
|
||||
} else if matches.Updated > 0 {
|
||||
log.Infof("import: %d markers updated, %d faces recognized, %d unknown", matches.Updated, matches.Recognized, matches.Unknown)
|
||||
|
|
|
@ -233,7 +233,7 @@ func (ind *Index) Start(opt IndexOptions) fs.Done {
|
|||
// Match existing faces if facial recognition is enabled.
|
||||
if w := NewFaces(ind.conf); w.Disabled() {
|
||||
log.Debugf("index: skipping facial recognition")
|
||||
} else if matches, err := w.Match(); err != nil {
|
||||
} else if matches, err := w.Match(FacesOptionsDefault()); err != nil {
|
||||
log.Errorf("index: %s", err)
|
||||
} else if matches.Updated > 0 {
|
||||
log.Infof("index: %d markers updated, %d faces recognized, %d unknown", matches.Updated, matches.Recognized, matches.Unknown)
|
||||
|
|
|
@ -19,7 +19,7 @@ func Faces(knownOnly bool, src string) (result entity.Faces, err error) {
|
|||
if knownOnly {
|
||||
stmt = stmt.Where("subject_uid <> ''").Order("subject_uid, samples DESC")
|
||||
} else {
|
||||
stmt = stmt.Order("id")
|
||||
stmt = stmt.Order("samples DESC")
|
||||
}
|
||||
|
||||
err = stmt.Find(&result).Error
|
||||
|
@ -35,16 +35,20 @@ func MatchFaceMarkers() (affected int64, err error) {
|
|||
return affected, err
|
||||
}
|
||||
|
||||
for _, match := range faces {
|
||||
for _, f := range faces {
|
||||
if res := Db().Model(&entity.Marker{}).
|
||||
Where("face_id = ?", match.ID).
|
||||
Where("face_id = ?", f.ID).
|
||||
Where("subject_src = ?", entity.SrcAuto).
|
||||
Where("subject_uid <> ?", match.SubjectUID).
|
||||
Updates(entity.Values{"SubjectUID": match.SubjectUID}); res.Error != nil {
|
||||
Where("subject_uid <> ?", f.SubjectUID).
|
||||
Updates(entity.Values{"SubjectUID": f.SubjectUID}); res.Error != nil {
|
||||
return affected, err
|
||||
} else if res.RowsAffected > 0 {
|
||||
affected += res.RowsAffected
|
||||
}
|
||||
|
||||
if err := f.UpdateMatchTime(); err != nil {
|
||||
return affected, err
|
||||
}
|
||||
}
|
||||
|
||||
return affected, nil
|
||||
|
@ -67,27 +71,50 @@ func RemoveAutoFaceClusters() (removed int64, err error) {
|
|||
return res.RowsAffected, res.Error
|
||||
}
|
||||
|
||||
// CountNewFaceMarkers returns the number of new face markers in the index.
|
||||
// CountNewFaceMarkers counts the number of new face markers in the index.
|
||||
func CountNewFaceMarkers() (n int) {
|
||||
var f entity.Face
|
||||
|
||||
if err := Db().Where("face_src = ?", entity.SrcAuto).Order("created_at DESC").Take(&f).Error; err != nil {
|
||||
if err := Db().Where("face_src = ?", entity.SrcAuto).Order("created_at DESC").Limit(1).Take(&f).Error; err != nil {
|
||||
log.Debugf("faces: no existing clusters")
|
||||
}
|
||||
|
||||
q := Db().Model(&entity.Markers{}).Where("marker_type = ? AND marker_invalid = 0 AND embeddings_json <> ''", entity.MarkerFace)
|
||||
q := Db().Model(&entity.Markers{}).
|
||||
Where("marker_type = ?", entity.MarkerFace).
|
||||
Where("face_id = '' AND marker_invalid = 0 AND embeddings_json <> ''")
|
||||
|
||||
if !f.CreatedAt.IsZero() {
|
||||
q = q.Where("created_at > ?", f.CreatedAt)
|
||||
}
|
||||
|
||||
if err := q.Order("created_at DESC").Count(&n).Error; err != nil {
|
||||
if err := q.Count(&n).Error; err != nil {
|
||||
log.Errorf("faces: %s (count new markers)", err)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// CountUnmatchedFaceMarkers counts the number of unmatched face markers in the index.
|
||||
func CountUnmatchedFaceMarkers() (n int) {
|
||||
var f entity.Face
|
||||
|
||||
if err := Db().Where("face_src <> ?", entity.SrcDefault).Order("matched_at ASC").Limit(1).Take(&f).Error; err != nil {
|
||||
log.Debugf("faces: no unmatched clusters")
|
||||
}
|
||||
|
||||
q := Db().Model(&entity.Markers{}).Where("marker_type = ? AND marker_invalid = 0 AND embeddings_json <> ''", entity.MarkerFace)
|
||||
|
||||
if f.MatchedAt != nil {
|
||||
q = q.Where("updated_at > ?", f.MatchedAt)
|
||||
}
|
||||
|
||||
if err := q.Count(&n).Error; err != nil {
|
||||
log.Errorf("faces: %s (count unmatched markers)", err)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// MergeFaces returns a new face that replaces multiple others.
|
||||
func MergeFaces(merge entity.Faces) (merged *entity.Face, err error) {
|
||||
if len(merge) < 2 {
|
||||
|
@ -113,21 +140,9 @@ func MergeFaces(merge entity.Faces) (merged *entity.Face, err error) {
|
|||
return merged, err
|
||||
}
|
||||
|
||||
// Find matching markers.
|
||||
var markers entity.Markers
|
||||
|
||||
if err := Db().Where("face_id = '' AND marker_invalid = 0 AND marker_type = ?", entity.MarkerFace).
|
||||
Find(&markers).Error; err != nil {
|
||||
log.Debugf("faces: %s (find matching markers)", err)
|
||||
// Find and reference additional matching markers.
|
||||
if err := merged.MatchMarkers(); err != nil {
|
||||
return merged, err
|
||||
} else {
|
||||
for _, marker := range markers {
|
||||
if ok, _ := merged.Match(marker.Embeddings()); !ok {
|
||||
// Ignore.
|
||||
} else if _, err := marker.SetFace(merged); err != nil {
|
||||
return merged, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return merged, err
|
||||
|
|
|
@ -68,3 +68,8 @@ func TestCountNewFaceMarkers(t *testing.T) {
|
|||
n := CountNewFaceMarkers()
|
||||
assert.GreaterOrEqual(t, n, 1)
|
||||
}
|
||||
|
||||
func TestCountUnmatchedFaceMarkers(t *testing.T) {
|
||||
n := CountUnmatchedFaceMarkers()
|
||||
assert.GreaterOrEqual(t, n, 0)
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ func Markers(limit, offset int, markerType string, embeddings, subjects bool) (r
|
|||
}
|
||||
|
||||
// Embeddings returns existing face embeddings.
|
||||
func Embeddings(single bool) (result entity.Embeddings, err error) {
|
||||
func Embeddings(single, unclustered bool) (result entity.Embeddings, err error) {
|
||||
var col []string
|
||||
|
||||
stmt := Db().
|
||||
|
@ -51,6 +51,10 @@ func Embeddings(single bool) (result entity.Embeddings, err error) {
|
|||
Where("embeddings_json <> ''").
|
||||
Order("id")
|
||||
|
||||
if unclustered {
|
||||
stmt = stmt.Where("face_id = ''")
|
||||
}
|
||||
|
||||
if err := stmt.Pluck("embeddings_json", &col).Error; err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
@ -104,25 +108,40 @@ func AddMarkerSubjects() (affected int64, err error) {
|
|||
return affected, err
|
||||
}
|
||||
|
||||
// CleanInvalidMarkerReferences deletes invalid reference IDs from the markers table.
|
||||
func CleanInvalidMarkerReferences() (err error) {
|
||||
// Reset subject and face relationships for invalid markers.
|
||||
err = Db().
|
||||
// RemoveInvalidMarkerReferences deletes invalid reference IDs from the markers table.
|
||||
func RemoveInvalidMarkerReferences() (removed int64, err error) {
|
||||
// Remove subject and face relationships for invalid markers.
|
||||
if res := Db().
|
||||
Model(&entity.Marker{}).
|
||||
Where("marker_invalid = 1").
|
||||
UpdateColumns(entity.Values{"subject_uid": "", "subject_src": "", "face_id": ""}).
|
||||
Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
Where("marker_invalid = 1 AND (subject_uid <> '' OR face_id <> '')").
|
||||
UpdateColumns(entity.Values{"subject_uid": "", "face_id": ""}); res.Error != nil {
|
||||
return removed, res.Error
|
||||
} else {
|
||||
removed += res.RowsAffected
|
||||
}
|
||||
|
||||
// Reset invalid face IDs.
|
||||
return Db().
|
||||
// Remove invalid face IDs.
|
||||
if res := Db().
|
||||
Model(&entity.Marker{}).
|
||||
Where("marker_type = ?", entity.MarkerFace).
|
||||
Where(fmt.Sprintf("face_id <> '' AND face_id NOT IN (SELECT id FROM %s)", entity.Face{}.TableName())).
|
||||
UpdateColumns(entity.Values{"face_id": ""}).
|
||||
Error
|
||||
UpdateColumns(entity.Values{"face_id": ""}); res.Error != nil {
|
||||
return removed, res.Error
|
||||
} else {
|
||||
removed += res.RowsAffected
|
||||
}
|
||||
|
||||
// Remove invalid subject UIDs.
|
||||
if res := Db().
|
||||
Model(&entity.Marker{}).
|
||||
Where(fmt.Sprintf("subject_uid <> '' AND subject_uid NOT IN (SELECT subject_uid FROM %s)", entity.Subject{}.TableName())).
|
||||
UpdateColumns(entity.Values{"subject_uid": ""}); res.Error != nil {
|
||||
return removed, res.Error
|
||||
} else {
|
||||
removed += res.RowsAffected
|
||||
}
|
||||
|
||||
return removed, nil
|
||||
}
|
||||
|
||||
// ResetFaceMarkerMatches removes automatically added subject and face references from the markers table.
|
||||
|
|
|
@ -46,7 +46,7 @@ func TestMarkers(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestEmbeddings(t *testing.T) {
|
||||
results, err := Embeddings(false)
|
||||
results, err := Embeddings(false, false)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -66,6 +66,9 @@ func TestAddMarkerSubjects(t *testing.T) {
|
|||
assert.GreaterOrEqual(t, affected, int64(1))
|
||||
}
|
||||
|
||||
func TestCleanInvalidMarkerReferences(t *testing.T) {
|
||||
assert.NoError(t, CleanInvalidMarkerReferences())
|
||||
func TestRemoveInvalidMarkerReferences(t *testing.T) {
|
||||
affected, err := RemoveInvalidMarkerReferences()
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, affected, int64(1))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user