People: Improve face clustering and matching #22

This commit is contained in:
Michael Mayer 2021-08-23 16:22:01 +02:00
parent 5442c04c75
commit fefe70f9a4
18 changed files with 259 additions and 142 deletions

View File

@ -27,7 +27,7 @@ func EmbeddingsMidpoint(m Embeddings) (result Embedding, radius float64, count i
}
if d := clusters.EuclideanDistance(result, emb); d > radius {
radius = d
radius = d + 0.01
}
}

View File

@ -26,6 +26,7 @@ type Face struct {
CollisionRadius float64 `json:"CollisionRadius" yaml:"CollisionRadius,omitempty"`
EmbeddingJSON json.RawMessage `gorm:"type:MEDIUMBLOB;" json:"-" yaml:"EmbeddingJSON,omitempty"`
embedding Embedding `gorm:"-"`
MatchedAt *time.Time `json:"MatchedAt" yaml:"MatchedAt,omitempty"`
CreatedAt time.Time `json:"CreatedAt" yaml:"CreatedAt,omitempty"`
UpdatedAt time.Time `json:"UpdatedAt" yaml:"UpdatedAt,omitempty"`
}
@ -82,6 +83,13 @@ func (m *Face) SetEmbeddings(embeddings Embeddings) (err error) {
return nil
}
// UpdateMatchTime updates the match timestamp.
func (m *Face) UpdateMatchTime() error {
matched := Timestamp()
m.MatchedAt = &matched
return UnscopedDb().Model(m).UpdateColumns(Values{"MatchedAt": m.MatchedAt}).Error
}
// Embedding returns parsed face embedding.
func (m *Face) Embedding() Embedding {
if len(m.EmbeddingJSON) == 0 {
@ -168,7 +176,10 @@ func (m *Face) ReportCollision(embeddings Embeddings) (reported bool, err error)
if err == nil && revise {
var revised Markers
revised, err = m.ReviseMatches()
log.Infof("faces: revised %d matches after collision", len(revised))
if n := len(revised); n > 0 {
log.Infof("faces: revised %d matches after collision", n)
}
}
return true, err
@ -198,6 +209,30 @@ func (m *Face) ReviseMatches() (revised Markers, err error) {
return revised, nil
}
// MatchMarkers finds and references matching markers.
func (m *Face) MatchMarkers() error {
var markers Markers
err := Db().
Where("face_id = '' AND marker_invalid = 0 AND marker_type = ?", MarkerFace).
Find(&markers).Error
if err != nil {
log.Debugf("faces: %s (match markers)", err)
return err
}
for _, marker := range markers {
if ok, _ := m.Match(marker.Embeddings()); !ok {
// Ignore.
} else if _, err = marker.SetFace(m); err != nil {
return err
}
}
return nil
}
// Save updates the existing or inserts a new face.
func (m *Face) Save() error {
faceMutex.Lock()

View File

@ -38,6 +38,7 @@ var FaceFixtures = FaceMap{
SampleRadius: 0,
Samples: 1,
Collisions: 0,
MatchedAt: &editTime,
CreatedAt: Timestamp(),
UpdatedAt: Timestamp(),
},

View File

@ -120,8 +120,15 @@ func (m *Marker) SetFace(f *Face) (updated bool, err error) {
// Any reason we don't want to set a new face for this marker?
if m.SubjectSrc != SrcManual || f.SubjectUID == m.SubjectUID {
// Don't skip if subject wasn't set manually, or subjects match.
} else if f.SubjectUID != "" {
log.Debugf("faces: ambiguous subjects %s / %s for marker %d", txt.Quote(f.SubjectUID), txt.Quote(m.SubjectUID), m.ID)
} else if f.SubjectUID != "" && m.SubjectUID == "" {
log.Debugf("faces: rejected subject %s for marker %d with unknown subject, source %s", txt.Quote(f.SubjectUID), m.ID, m.SubjectSrc)
return false, nil
} else if reported, err := f.ReportCollision(m.Embeddings()); err != nil {
return false, err
} else if reported {
log.Infof("faces: marker %d (subject %s) collision with %s (subject %s), source %s", m.ID, m.SubjectUID, f.ID, f.SubjectUID, m.SubjectSrc)
return false, nil
} else {
return false, nil
}
@ -307,18 +314,18 @@ func (m *Marker) GetFace() (f *Face) {
if m.FaceID == "" && m.SubjectSrc == SrcManual {
if f = NewFace(m.SubjectUID, SrcManual, m.Embeddings()); f == nil {
return nil
} else if f = FirstOrCreateFace(f); f == nil {
log.Debugf("marker: invalid face")
} else if err := f.Create(); err != nil {
return nil
} else if err := f.MatchMarkers(); err != nil {
log.Errorf("faces: %s (match markers)", err)
}
m.Face = f
m.FaceID = f.ID
return f
} else {
m.Face = FindFace(m.FaceID)
}
m.Face = FindFace(m.FaceID)
return m.Face
}

View File

@ -9,6 +9,12 @@ func Timestamp() time.Time {
return time.Now().UTC().Round(time.Second)
}
// TimestampPointer returns a current timestamp pointer.
func TimestampPointer() *time.Time {
t := Timestamp()
return &t
}
// Seconds converts an int to a duration in seconds.
func Seconds(s int) time.Duration {
return time.Duration(s) * time.Second

View File

@ -17,6 +17,22 @@ func TestTimestamp(t *testing.T) {
}
}
func TestTimestampPointer(t *testing.T) {
result := TimestampPointer()
if result == nil {
t.Fatal("result must not be nil")
}
if result.Location() != time.UTC {
t.Fatal("timestamp zone must be utc")
}
if result.After(time.Now().Add(time.Second)) {
t.Fatal("timestamp should be in the past from now")
}
}
func TestSeconds(t *testing.T) {
result := Seconds(23)

View File

@ -40,7 +40,7 @@ import (
var ClusterCore = 4
var ClusterRadius = 0.6
var SampleThreshold = 25
var SampleThreshold = 2 * ClusterCore
var log = event.Log

View File

@ -4,6 +4,8 @@ import (
"fmt"
"runtime/debug"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/mutex"
"github.com/photoprism/photoprism/internal/query"
@ -23,6 +25,13 @@ func NewFaces(conf *config.Config) *Faces {
return instance
}
// StartDefault starts face clustering and matching with default options.
func (w *Faces) StartDefault() (err error) {
return w.Start(FacesOptions{
Force: false,
})
}
// Start face clustering and matching.
func (w *Faces) Start(opt FacesOptions) (err error) {
defer func() {
@ -42,78 +51,46 @@ func (w *Faces) Start(opt FacesOptions) (err error) {
defer mutex.MainWorker.Stop()
// Skip clustering if index contains no new face markers and force option isn't set.
if n := query.CountNewFaceMarkers(); n < 1 && !opt.Force {
log.Debugf("faces: no new samples")
var updated int64
// Adds and reference known marker subjects.
if affected, err := query.AddMarkerSubjects(); err != nil {
log.Errorf("faces: %s (match markers with subjects)", err)
} else {
updated += affected
}
// Match markers with known faces.
if affected, err := query.MatchFaceMarkers(); err != nil {
return err
} else {
updated += affected
}
// Log result.
if updated > 0 {
log.Infof("faces: %d markers updated", updated)
} else {
log.Debug("faces: no changes")
}
// Remove invalid ids from marker table.
if err := query.CleanInvalidMarkerReferences(); err != nil {
log.Errorf("faces: %s (clean)", err)
}
// Optimize existing face clusters.
if res, err := w.Optimize(); err != nil {
return err
} else if res.Merged > 0 {
log.Infof("faces: %d clusters merged", res.Merged)
}
return nil
// Remove invalid reference IDs from markers table.
if removed, err := query.RemoveInvalidMarkerReferences(); err != nil {
log.Errorf("faces: %s (remove invalid references)", err)
} else if removed > 0 {
log.Infof("faces: removed %d invalid references", removed)
} else {
log.Infof("faces: %d new samples", n)
}
var clustersAdded, clustersRemoved int64
// Cluster existing face embeddings.
if clustersAdded, clustersRemoved, err = w.Cluster(opt); err != nil {
log.Errorf("faces: %s (cluster)", err)
}
// Log face clustering results.
if (clustersAdded - clustersRemoved) != 0 {
log.Infof("faces: %d clusters added, %d removed", clustersAdded, clustersRemoved)
} else {
log.Debugf("faces: %d clusters added, %d removed", clustersAdded, clustersRemoved)
}
// Remove invalid marker references.
if err = query.CleanInvalidMarkerReferences(); err != nil {
log.Errorf("faces: %s (clean)", err)
log.Debugf("faces: no invalid references")
}
// Optimize existing face clusters.
if res, err := w.Optimize(); err != nil {
return err
} else if res.Merged > 0 {
log.Infof("faces: %d clusters merged", res.Merged)
log.Infof("faces: merged %d clusters", res.Merged)
} else {
log.Debugf("faces: no clusters could be merged")
}
// Add known marker subjects.
if affected, err := query.AddMarkerSubjects(); err != nil {
log.Errorf("faces: %s (match markers with subjects)", err)
} else if affected > 0 {
log.Infof("faces: added %d known marker subjects", affected)
} else {
log.Debugf("faces: no subjects were missing")
}
var added entity.Faces
// Cluster existing face embeddings.
if added, err = w.Cluster(opt); err != nil {
log.Errorf("faces: %s (cluster)", err)
} else if n := len(added); n > 0 {
log.Infof("faces: added %d new faces", n)
} else {
log.Debugf("faces: found no new faces")
}
// Match markers with faces and subjects.
matches, err := w.Match()
matches, err := w.Match(opt)
if err != nil {
log.Errorf("faces: %s (match)", err)

View File

@ -8,7 +8,7 @@ import (
// Analyze face embeddings.
func (w *Faces) Analyze() (err error) {
if embeddings, err := query.Embeddings(true); err != nil {
if embeddings, err := query.Embeddings(true, false); err != nil {
return err
} else if samples := len(embeddings); samples == 0 {
log.Infof("faces: no samples found")

View File

@ -8,29 +8,47 @@ import (
)
// Cluster clusters indexed face embeddings.
func (w *Faces) Cluster(opt FacesOptions) (added int64, removed int64, err error) {
// Fetch and cluster all face embeddings.
embeddings, err := query.Embeddings(false)
func (w *Faces) Cluster(opt FacesOptions) (added entity.Faces, err error) {
if w.Disabled() {
return added, nil
}
// Skip clustering if index contains no new face markers, and force option isn't set.
if opt.Force {
log.Infof("faces: forced clustering")
} else if n := query.CountNewFaceMarkers(); n < 1 {
log.Debugf("faces: skipping clustering")
return added, nil
}
// Fetch unclustered face embeddings.
embeddings, err := query.Embeddings(false, true)
log.Debugf("faces: %d unclustered samples found", len(embeddings))
// Anything that keeps us from doing this?
if err != nil {
return added, removed, err
return added, err
} else if samples := len(embeddings); samples < opt.SampleThreshold() {
log.Warnf("faces: at least %d samples needed for matching similar faces", face.SampleThreshold)
return added, removed, nil
log.Debugf("faces: at least %d samples needed for clustering", face.SampleThreshold)
return added, nil
} else {
var c clusters.HardClusterer
// See https://dl.photoprism.org/research/ for research on face clustering algorithms.
if c, err = clusters.DBSCAN(face.ClusterCore, face.ClusterRadius, w.conf.Workers(), clusters.EuclideanDistance); err != nil {
return added, removed, err
return added, err
} else if err = c.Learn(embeddings); err != nil {
return added, removed, err
return added, err
}
sizes := c.Sizes()
log.Debugf("faces: %d samples in %d clusters", len(embeddings), len(sizes))
if len(sizes) > 1 {
log.Infof("faces: found %d new clusters", len(sizes))
} else {
log.Debugf("faces: found no new clusters")
}
results := make([]entity.Embeddings, len(sizes))
@ -48,23 +66,19 @@ func (w *Faces) Cluster(opt FacesOptions) (added int64, removed int64, err error
results[n-1] = append(results[n-1], embeddings[i])
}
if removed, err = query.RemoveAnonymousFaceClusters(); err != nil {
log.Errorf("faces: %s", err)
} else if removed > 0 {
log.Debugf("faces: removed %d anonymous clusters", removed)
}
for _, cluster := range results {
if f := entity.NewFace("", entity.SrcAuto, cluster); f == nil {
log.Errorf("faces: face should not be nil - bug?")
} else if err := f.Create(); err == nil {
added++
log.Tracef("faces: added face %s", f.ID)
added = append(added, *f)
log.Debugf("faces: added cluster %s based on %d samples, radius %f", f.ID, f.Samples, f.SampleRadius)
} else if err := f.Updates(entity.Values{"UpdatedAt": entity.Timestamp()}); err != nil {
log.Errorf("faces: %s", err)
} else {
log.Debugf("faces: updated cluster %s", f.ID)
}
}
}
return added, removed, nil
return added, nil
}

View File

@ -17,11 +17,21 @@ type FacesMatchResult struct {
}
// Match matches markers with faces and subjects.
func (w *Faces) Match() (result FacesMatchResult, err error) {
func (w *Faces) Match(opt FacesOptions) (result FacesMatchResult, err error) {
if w.Disabled() {
return result, nil
}
// Skip matching if index contains no new face markers, and force option isn't set.
if opt.Force {
log.Infof("faces: forced matching")
} else if n := query.CountUnmatchedFaceMarkers(); n > 0 {
log.Infof("faces: %d unmatched markers", n)
} else {
result.Recognized, err = query.MatchFaceMarkers()
return result, err
}
faces, err := query.Faces(false, "")
if err != nil {
@ -96,17 +106,19 @@ func (w *Faces) Match() (result FacesMatchResult, err error) {
time.Sleep(50 * time.Millisecond)
}
// Update remaining markers based on current matches.
// Update face match timestamps.
for _, m := range faces {
if err := m.UpdateMatchTime(); err != nil {
log.Warnf("faces: %s (update match time)", err)
}
}
// Update remaining markers based on previous matches.
if m, err := query.MatchFaceMarkers(); err != nil {
return result, err
} else {
result.Recognized += m
}
// Reset invalid marker data.
if err := query.CleanInvalidMarkerReferences(); err != nil {
return result, err
}
return result, nil
}

View File

@ -16,3 +16,10 @@ func (o FacesOptions) SampleThreshold() int {
// Return default.
return face.SampleThreshold
}
// FacesOptionsDefault returns new faces options with default values.
func FacesOptionsDefault() FacesOptions {
result := FacesOptions{}
return result
}

View File

@ -241,7 +241,7 @@ func (imp *Import) Start(opt ImportOptions) fs.Done {
// Match existing faces if facial recognition is enabled.
if w := NewFaces(imp.conf); w.Disabled() {
log.Debugf("import: skipping facial recognition")
} else if matches, err := w.Match(); err != nil {
} else if matches, err := w.Match(FacesOptionsDefault()); err != nil {
log.Errorf("import: %s", err)
} else if matches.Updated > 0 {
log.Infof("import: %d markers updated, %d faces recognized, %d unknown", matches.Updated, matches.Recognized, matches.Unknown)

View File

@ -233,7 +233,7 @@ func (ind *Index) Start(opt IndexOptions) fs.Done {
// Match existing faces if facial recognition is enabled.
if w := NewFaces(ind.conf); w.Disabled() {
log.Debugf("index: skipping facial recognition")
} else if matches, err := w.Match(); err != nil {
} else if matches, err := w.Match(FacesOptionsDefault()); err != nil {
log.Errorf("index: %s", err)
} else if matches.Updated > 0 {
log.Infof("index: %d markers updated, %d faces recognized, %d unknown", matches.Updated, matches.Recognized, matches.Unknown)

View File

@ -19,7 +19,7 @@ func Faces(knownOnly bool, src string) (result entity.Faces, err error) {
if knownOnly {
stmt = stmt.Where("subject_uid <> ''").Order("subject_uid, samples DESC")
} else {
stmt = stmt.Order("id")
stmt = stmt.Order("samples DESC")
}
err = stmt.Find(&result).Error
@ -35,16 +35,20 @@ func MatchFaceMarkers() (affected int64, err error) {
return affected, err
}
for _, match := range faces {
for _, f := range faces {
if res := Db().Model(&entity.Marker{}).
Where("face_id = ?", match.ID).
Where("face_id = ?", f.ID).
Where("subject_src = ?", entity.SrcAuto).
Where("subject_uid <> ?", match.SubjectUID).
Updates(entity.Values{"SubjectUID": match.SubjectUID}); res.Error != nil {
Where("subject_uid <> ?", f.SubjectUID).
Updates(entity.Values{"SubjectUID": f.SubjectUID}); res.Error != nil {
return affected, err
} else if res.RowsAffected > 0 {
affected += res.RowsAffected
}
if err := f.UpdateMatchTime(); err != nil {
return affected, err
}
}
return affected, nil
@ -67,27 +71,50 @@ func RemoveAutoFaceClusters() (removed int64, err error) {
return res.RowsAffected, res.Error
}
// CountNewFaceMarkers returns the number of new face markers in the index.
// CountNewFaceMarkers counts the number of new face markers in the index.
func CountNewFaceMarkers() (n int) {
var f entity.Face
if err := Db().Where("face_src = ?", entity.SrcAuto).Order("created_at DESC").Take(&f).Error; err != nil {
if err := Db().Where("face_src = ?", entity.SrcAuto).Order("created_at DESC").Limit(1).Take(&f).Error; err != nil {
log.Debugf("faces: no existing clusters")
}
q := Db().Model(&entity.Markers{}).Where("marker_type = ? AND marker_invalid = 0 AND embeddings_json <> ''", entity.MarkerFace)
q := Db().Model(&entity.Markers{}).
Where("marker_type = ?", entity.MarkerFace).
Where("face_id = '' AND marker_invalid = 0 AND embeddings_json <> ''")
if !f.CreatedAt.IsZero() {
q = q.Where("created_at > ?", f.CreatedAt)
}
if err := q.Order("created_at DESC").Count(&n).Error; err != nil {
if err := q.Count(&n).Error; err != nil {
log.Errorf("faces: %s (count new markers)", err)
}
return n
}
// CountUnmatchedFaceMarkers counts the number of unmatched face markers in the index.
func CountUnmatchedFaceMarkers() (n int) {
var f entity.Face
if err := Db().Where("face_src <> ?", entity.SrcDefault).Order("matched_at ASC").Limit(1).Take(&f).Error; err != nil {
log.Debugf("faces: no unmatched clusters")
}
q := Db().Model(&entity.Markers{}).Where("marker_type = ? AND marker_invalid = 0 AND embeddings_json <> ''", entity.MarkerFace)
if f.MatchedAt != nil {
q = q.Where("updated_at > ?", f.MatchedAt)
}
if err := q.Count(&n).Error; err != nil {
log.Errorf("faces: %s (count unmatched markers)", err)
}
return n
}
// MergeFaces returns a new face that replaces multiple others.
func MergeFaces(merge entity.Faces) (merged *entity.Face, err error) {
if len(merge) < 2 {
@ -113,21 +140,9 @@ func MergeFaces(merge entity.Faces) (merged *entity.Face, err error) {
return merged, err
}
// Find matching markers.
var markers entity.Markers
if err := Db().Where("face_id = '' AND marker_invalid = 0 AND marker_type = ?", entity.MarkerFace).
Find(&markers).Error; err != nil {
log.Debugf("faces: %s (find matching markers)", err)
// Find and reference additional matching markers.
if err := merged.MatchMarkers(); err != nil {
return merged, err
} else {
for _, marker := range markers {
if ok, _ := merged.Match(marker.Embeddings()); !ok {
// Ignore.
} else if _, err := marker.SetFace(merged); err != nil {
return merged, err
}
}
}
return merged, err

View File

@ -68,3 +68,8 @@ func TestCountNewFaceMarkers(t *testing.T) {
n := CountNewFaceMarkers()
assert.GreaterOrEqual(t, n, 1)
}
func TestCountUnmatchedFaceMarkers(t *testing.T) {
n := CountUnmatchedFaceMarkers()
assert.GreaterOrEqual(t, n, 0)
}

View File

@ -41,7 +41,7 @@ func Markers(limit, offset int, markerType string, embeddings, subjects bool) (r
}
// Embeddings returns existing face embeddings.
func Embeddings(single bool) (result entity.Embeddings, err error) {
func Embeddings(single, unclustered bool) (result entity.Embeddings, err error) {
var col []string
stmt := Db().
@ -51,6 +51,10 @@ func Embeddings(single bool) (result entity.Embeddings, err error) {
Where("embeddings_json <> ''").
Order("id")
if unclustered {
stmt = stmt.Where("face_id = ''")
}
if err := stmt.Pluck("embeddings_json", &col).Error; err != nil {
return result, err
}
@ -104,25 +108,40 @@ func AddMarkerSubjects() (affected int64, err error) {
return affected, err
}
// CleanInvalidMarkerReferences deletes invalid reference IDs from the markers table.
func CleanInvalidMarkerReferences() (err error) {
// Reset subject and face relationships for invalid markers.
err = Db().
// RemoveInvalidMarkerReferences deletes invalid reference IDs from the markers table.
func RemoveInvalidMarkerReferences() (removed int64, err error) {
// Remove subject and face relationships for invalid markers.
if res := Db().
Model(&entity.Marker{}).
Where("marker_invalid = 1").
UpdateColumns(entity.Values{"subject_uid": "", "subject_src": "", "face_id": ""}).
Error
if err != nil {
return err
Where("marker_invalid = 1 AND (subject_uid <> '' OR face_id <> '')").
UpdateColumns(entity.Values{"subject_uid": "", "face_id": ""}); res.Error != nil {
return removed, res.Error
} else {
removed += res.RowsAffected
}
// Reset invalid face IDs.
return Db().
// Remove invalid face IDs.
if res := Db().
Model(&entity.Marker{}).
Where("marker_type = ?", entity.MarkerFace).
Where(fmt.Sprintf("face_id <> '' AND face_id NOT IN (SELECT id FROM %s)", entity.Face{}.TableName())).
UpdateColumns(entity.Values{"face_id": ""}).
Error
UpdateColumns(entity.Values{"face_id": ""}); res.Error != nil {
return removed, res.Error
} else {
removed += res.RowsAffected
}
// Remove invalid subject UIDs.
if res := Db().
Model(&entity.Marker{}).
Where(fmt.Sprintf("subject_uid <> '' AND subject_uid NOT IN (SELECT subject_uid FROM %s)", entity.Subject{}.TableName())).
UpdateColumns(entity.Values{"subject_uid": ""}); res.Error != nil {
return removed, res.Error
} else {
removed += res.RowsAffected
}
return removed, nil
}
// ResetFaceMarkerMatches removes automatically added subject and face references from the markers table.

View File

@ -46,7 +46,7 @@ func TestMarkers(t *testing.T) {
}
func TestEmbeddings(t *testing.T) {
results, err := Embeddings(false)
results, err := Embeddings(false, false)
if err != nil {
t.Fatal(err)
@ -66,6 +66,9 @@ func TestAddMarkerSubjects(t *testing.T) {
assert.GreaterOrEqual(t, affected, int64(1))
}
func TestCleanInvalidMarkerReferences(t *testing.T) {
assert.NoError(t, CleanInvalidMarkerReferences())
func TestRemoveInvalidMarkerReferences(t *testing.T) {
affected, err := RemoveInvalidMarkerReferences()
assert.NoError(t, err)
assert.GreaterOrEqual(t, affected, int64(1))
}