2021-08-14 18:13:03 +02:00
|
|
|
package query
|
|
|
|
|
|
|
|
import (
|
2021-08-22 21:06:44 +02:00
|
|
|
"fmt"
|
|
|
|
|
2021-09-01 12:48:17 +02:00
|
|
|
"github.com/photoprism/photoprism/internal/face"
|
|
|
|
|
2021-08-31 15:33:42 +02:00
|
|
|
"github.com/photoprism/photoprism/pkg/txt"
|
|
|
|
|
2021-08-14 18:13:03 +02:00
|
|
|
"github.com/photoprism/photoprism/internal/entity"
|
|
|
|
)
|
|
|
|
|
2021-08-29 13:26:05 +02:00
|
|
|
// Faces returns all (known / unmatched) faces from the index.
|
|
|
|
func Faces(knownOnly, unmatched bool) (result entity.Faces, err error) {
|
2021-09-03 16:14:09 +02:00
|
|
|
stmt := Db()
|
2021-08-22 21:06:44 +02:00
|
|
|
|
2021-08-29 13:26:05 +02:00
|
|
|
if unmatched {
|
|
|
|
stmt = stmt.Where("matched_at IS NULL")
|
2021-08-22 21:06:44 +02:00
|
|
|
}
|
2021-08-14 18:13:03 +02:00
|
|
|
|
|
|
|
if knownOnly {
|
2021-09-17 14:26:12 +02:00
|
|
|
stmt = stmt.Where("subj_uid <> ''")
|
2021-08-14 18:13:03 +02:00
|
|
|
}
|
|
|
|
|
2021-09-17 14:26:12 +02:00
|
|
|
err = stmt.Order("subj_uid, samples DESC").Find(&result).Error
|
2021-08-29 13:26:05 +02:00
|
|
|
|
|
|
|
return result, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// ManuallyAddedFaces returns all manually added face clusters.
|
|
|
|
func ManuallyAddedFaces() (result entity.Faces, err error) {
|
|
|
|
err = Db().
|
|
|
|
Where("face_src = ?", entity.SrcManual).
|
2021-09-17 14:26:12 +02:00
|
|
|
Where("subj_uid <> ''").Order("subj_uid, samples DESC").
|
2021-08-29 13:26:05 +02:00
|
|
|
Find(&result).Error
|
2021-08-14 18:13:03 +02:00
|
|
|
|
|
|
|
return result, err
|
|
|
|
}
|
|
|
|
|
2021-08-19 23:12:51 +02:00
|
|
|
// MatchFaceMarkers matches markers with known faces.
|
|
|
|
func MatchFaceMarkers() (affected int64, err error) {
|
2021-08-29 13:26:05 +02:00
|
|
|
faces, err := Faces(true, false)
|
2021-08-14 18:13:03 +02:00
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return affected, err
|
|
|
|
}
|
|
|
|
|
2021-08-23 16:22:01 +02:00
|
|
|
for _, f := range faces {
|
2021-08-14 18:13:03 +02:00
|
|
|
if res := Db().Model(&entity.Marker{}).
|
2021-08-23 16:22:01 +02:00
|
|
|
Where("face_id = ?", f.ID).
|
2021-09-17 14:26:12 +02:00
|
|
|
Where("subj_src = ?", entity.SrcAuto).
|
|
|
|
Where("subj_uid <> ?", f.SubjUID).
|
|
|
|
Updates(entity.Values{"SubjUID": f.SubjUID, "MarkerReview": false}); res.Error != nil {
|
2021-08-14 18:13:03 +02:00
|
|
|
return affected, err
|
|
|
|
} else if res.RowsAffected > 0 {
|
|
|
|
affected += res.RowsAffected
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return affected, nil
|
|
|
|
}
|
|
|
|
|
2021-08-22 16:14:34 +02:00
|
|
|
// RemoveAnonymousFaceClusters removes anonymous faces from the index.
|
|
|
|
func RemoveAnonymousFaceClusters() (removed int64, err error) {
|
|
|
|
res := UnscopedDb().Delete(
|
2021-08-14 18:13:03 +02:00
|
|
|
entity.Face{},
|
2021-09-17 14:26:12 +02:00
|
|
|
"face_src = ? AND subj_uid = ''", entity.SrcAuto)
|
2021-08-22 16:14:34 +02:00
|
|
|
|
|
|
|
return res.RowsAffected, res.Error
|
2021-08-14 19:52:49 +02:00
|
|
|
}
|
|
|
|
|
2021-08-22 16:14:34 +02:00
|
|
|
// RemoveAutoFaceClusters removes automatically added face clusters from the index.
|
|
|
|
func RemoveAutoFaceClusters() (removed int64, err error) {
|
|
|
|
res := UnscopedDb().
|
2021-09-03 16:14:09 +02:00
|
|
|
Delete(entity.Face{}, "face_src = ?", entity.SrcAuto)
|
2021-08-22 16:14:34 +02:00
|
|
|
|
|
|
|
return res.RowsAffected, res.Error
|
2021-08-15 14:14:27 +02:00
|
|
|
}
|
|
|
|
|
2021-08-23 16:22:01 +02:00
|
|
|
// CountNewFaceMarkers counts the number of new face markers in the index.
|
2021-08-29 13:26:05 +02:00
|
|
|
func CountNewFaceMarkers(size, score int) (n int) {
|
2021-08-14 19:52:49 +02:00
|
|
|
var f entity.Face
|
|
|
|
|
2021-08-24 15:20:05 +02:00
|
|
|
if err := Db().Where("face_src = ?", entity.SrcAuto).
|
|
|
|
Order("created_at DESC").Limit(1).Take(&f).Error; err != nil {
|
2021-08-15 14:14:27 +02:00
|
|
|
log.Debugf("faces: no existing clusters")
|
2021-08-14 19:52:49 +02:00
|
|
|
}
|
|
|
|
|
2021-08-23 16:22:01 +02:00
|
|
|
q := Db().Model(&entity.Markers{}).
|
|
|
|
Where("marker_type = ?", entity.MarkerFace).
|
|
|
|
Where("face_id = '' AND marker_invalid = 0 AND embeddings_json <> ''")
|
2021-08-14 19:52:49 +02:00
|
|
|
|
2021-08-29 13:26:05 +02:00
|
|
|
if size > 0 {
|
|
|
|
q = q.Where("size >= ?", size)
|
|
|
|
}
|
|
|
|
|
2021-08-24 20:15:36 +02:00
|
|
|
if score > 0 {
|
|
|
|
q = q.Where("score >= ?", score)
|
|
|
|
}
|
|
|
|
|
2021-08-14 19:52:49 +02:00
|
|
|
if !f.CreatedAt.IsZero() {
|
|
|
|
q = q.Where("created_at > ?", f.CreatedAt)
|
|
|
|
}
|
|
|
|
|
2021-08-23 16:22:01 +02:00
|
|
|
if err := q.Count(&n).Error; err != nil {
|
2021-08-14 19:52:49 +02:00
|
|
|
log.Errorf("faces: %s (count new markers)", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return n
|
2021-08-14 18:13:03 +02:00
|
|
|
}
|
2021-08-22 21:06:44 +02:00
|
|
|
|
2021-08-31 15:33:42 +02:00
|
|
|
// PurgeOrphanFaces removes unused faces from the index.
|
|
|
|
func PurgeOrphanFaces(faceIds []string) (removed int64, err error) {
|
2021-08-29 13:26:05 +02:00
|
|
|
// Remove invalid face IDs.
|
|
|
|
if res := Db().
|
|
|
|
Where("id IN (?)", faceIds).
|
|
|
|
Where(fmt.Sprintf("id NOT IN (SELECT face_id FROM %s)", entity.Marker{}.TableName())).
|
|
|
|
Delete(&entity.Face{}); res.Error != nil {
|
2021-08-31 20:08:53 +02:00
|
|
|
return removed, fmt.Errorf("faces: %s while purging orphans", res.Error)
|
2021-08-29 13:26:05 +02:00
|
|
|
} else {
|
|
|
|
removed += res.RowsAffected
|
|
|
|
}
|
|
|
|
|
|
|
|
return removed, nil
|
|
|
|
}
|
|
|
|
|
2021-08-22 21:06:44 +02:00
|
|
|
// MergeFaces returns a new face that replaces multiple others.
|
|
|
|
func MergeFaces(merge entity.Faces) (merged *entity.Face, err error) {
|
|
|
|
if len(merge) < 2 {
|
|
|
|
// Nothing to merge.
|
2021-08-31 15:33:42 +02:00
|
|
|
return merged, fmt.Errorf("faces: two or more clusters required for merging")
|
2021-08-22 21:06:44 +02:00
|
|
|
}
|
|
|
|
|
2021-09-17 14:26:12 +02:00
|
|
|
subjUID := merge[0].SubjUID
|
2021-08-31 15:33:42 +02:00
|
|
|
|
|
|
|
for i := 1; i < len(merge); i++ {
|
2021-09-17 14:26:12 +02:00
|
|
|
if merge[i].SubjUID != subjUID {
|
2021-08-31 15:33:42 +02:00
|
|
|
return merged, fmt.Errorf("faces: can't merge clusters with conflicting subjects %s <> %s",
|
2021-09-17 14:26:12 +02:00
|
|
|
txt.Quote(subjUID), txt.Quote(merge[i].SubjUID))
|
2021-08-31 15:33:42 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Find or create merged face cluster.
|
2021-09-17 14:26:12 +02:00
|
|
|
if merged = entity.NewFace(merge[0].SubjUID, merge[0].FaceSrc, merge.Embeddings()); merged == nil {
|
|
|
|
return merged, fmt.Errorf("faces: new cluster is nil for subject %s", txt.Quote(subjUID))
|
2021-08-31 15:33:42 +02:00
|
|
|
} else if merged = entity.FirstOrCreateFace(merged); merged == nil {
|
2021-09-17 14:26:12 +02:00
|
|
|
return merged, fmt.Errorf("faces: failed creating new cluster for subject %s", txt.Quote(subjUID))
|
2021-08-29 13:26:05 +02:00
|
|
|
} else if err := merged.MatchMarkers(append(merge.IDs(), "")); err != nil {
|
2021-08-22 21:06:44 +02:00
|
|
|
return merged, err
|
|
|
|
}
|
|
|
|
|
2021-08-31 15:33:42 +02:00
|
|
|
// PurgeOrphanFaces removes unused faces from the index.
|
|
|
|
if removed, err := PurgeOrphanFaces(merge.IDs()); err != nil {
|
|
|
|
return merged, err
|
|
|
|
} else if removed > 0 {
|
2021-09-17 14:26:12 +02:00
|
|
|
log.Debugf("faces: removed %d orphans for subject %s", removed, txt.Quote(subjUID))
|
2021-08-29 13:26:05 +02:00
|
|
|
} else {
|
2021-09-17 14:26:12 +02:00
|
|
|
log.Warnf("faces: failed removing merged clusters for subject %s", txt.Quote(subjUID))
|
2021-08-22 21:06:44 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
return merged, err
|
|
|
|
}
|
2021-09-01 12:48:17 +02:00
|
|
|
|
|
|
|
// ResolveFaceCollisions resolves collisions of different subject's faces.
|
|
|
|
func ResolveFaceCollisions() (conflicts, resolved int, err error) {
|
|
|
|
faces, err := Faces(true, false)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return conflicts, resolved, err
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, f1 := range faces {
|
|
|
|
for _, f2 := range faces {
|
|
|
|
if matched, dist := f1.Match(entity.Embeddings{f2.Embedding()}); matched {
|
2021-09-17 14:26:12 +02:00
|
|
|
if f1.SubjUID == f2.SubjUID {
|
2021-09-01 12:48:17 +02:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
conflicts++
|
|
|
|
|
|
|
|
r := f1.SampleRadius + face.ClusterRadius
|
|
|
|
|
|
|
|
log.Infof("face %s: conflict at dist %f, Ø %f from %d samples, collision Ø %f", f1.ID, dist, r, f1.Samples, f1.CollisionRadius)
|
|
|
|
|
2021-09-17 14:26:12 +02:00
|
|
|
if f1.SubjUID != "" {
|
|
|
|
log.Debugf("face %s: subject %s (%s %s)", f1.ID, txt.Quote(f1.SubjUID), f1.SubjUID, entity.SrcString(f1.FaceSrc))
|
2021-09-01 12:48:17 +02:00
|
|
|
} else {
|
|
|
|
log.Debugf("face %s: no subject (%s)", f1.ID, entity.SrcString(f1.FaceSrc))
|
|
|
|
}
|
|
|
|
|
2021-09-17 14:26:12 +02:00
|
|
|
if f2.SubjUID != "" {
|
|
|
|
log.Debugf("face %s: subject %s (%s %s)", f2.ID, txt.Quote(f2.SubjUID), f2.SubjUID, entity.SrcString(f2.FaceSrc))
|
2021-09-01 12:48:17 +02:00
|
|
|
} else {
|
|
|
|
log.Debugf("face %s: no subject (%s)", f2.ID, entity.SrcString(f2.FaceSrc))
|
|
|
|
}
|
|
|
|
|
|
|
|
if ok, err := f1.ResolveCollision(entity.Embeddings{f2.Embedding()}); err != nil {
|
|
|
|
log.Errorf("face %s: %s", f1.ID, err)
|
|
|
|
} else if ok {
|
|
|
|
log.Infof("face %s: collision has been resolved", f1.ID)
|
|
|
|
resolved++
|
|
|
|
} else {
|
|
|
|
log.Debugf("face %s: collision could not be resolved", f1.ID)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return conflicts, resolved, nil
|
|
|
|
}
|