224 lines
6 KiB
Go
224 lines
6 KiB
Go
package entity
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestFace_TableName(t *testing.T) {
|
|
m := &Face{}
|
|
assert.Contains(t, m.TableName(), "faces")
|
|
}
|
|
|
|
func TestFace_Match(t *testing.T) {
|
|
t.Run("1000003-4", func(t *testing.T) {
|
|
m := FaceFixtures.Get("joe-biden")
|
|
match, dist := m.Match(MarkerFixtures.Pointer("1000003-4").Embeddings())
|
|
|
|
assert.True(t, match)
|
|
assert.Greater(t, dist, 1.31)
|
|
assert.Less(t, dist, 1.32)
|
|
})
|
|
|
|
t.Run("1000003-6", func(t *testing.T) {
|
|
m := FaceFixtures.Get("joe-biden")
|
|
match, dist := m.Match(MarkerFixtures.Pointer("1000003-6").Embeddings())
|
|
|
|
assert.True(t, match)
|
|
assert.Greater(t, dist, 1.27)
|
|
assert.Less(t, dist, 1.28)
|
|
})
|
|
|
|
t.Run("len(embeddings) == 0", func(t *testing.T) {
|
|
m := FaceFixtures.Get("joe-biden")
|
|
match, dist := m.Match(Embeddings{})
|
|
|
|
assert.False(t, match)
|
|
assert.Equal(t, dist, float64(-1))
|
|
})
|
|
t.Run("len(efacEmbeddings) == 0", func(t *testing.T) {
|
|
m := NewFace("12345", SrcAuto, Embeddings{})
|
|
match, dist := m.Match(MarkerFixtures.Pointer("1000003-6").Embeddings())
|
|
|
|
assert.False(t, match)
|
|
assert.Equal(t, dist, float64(-1))
|
|
})
|
|
t.Run("jane doe- no match", func(t *testing.T) {
|
|
m := FaceFixtures.Get("jane-doe")
|
|
match, _ := m.Match(MarkerFixtures.Pointer("1000003-5").Embeddings())
|
|
|
|
assert.False(t, match)
|
|
})
|
|
}
|
|
|
|
func TestFace_ReportCollision(t *testing.T) {
|
|
t.Run("collision", func(t *testing.T) {
|
|
m := FaceFixtures.Get("joe-biden")
|
|
|
|
assert.Zero(t, m.Collisions)
|
|
assert.Zero(t, m.CollisionRadius)
|
|
|
|
if reported, err := m.ReportCollision(MarkerFixtures.Pointer("1000003-4").Embeddings()); err != nil {
|
|
t.Fatal(err)
|
|
} else {
|
|
assert.True(t, reported)
|
|
}
|
|
|
|
// Number of collisions must have increased by one.
|
|
assert.Equal(t, 1, m.Collisions)
|
|
|
|
// Actual distance is ~1.314040
|
|
assert.Greater(t, m.CollisionRadius, 1.2)
|
|
assert.Less(t, m.CollisionRadius, 1.314)
|
|
|
|
if reported, err := m.ReportCollision(MarkerFixtures.Pointer("1000003-6").Embeddings()); err != nil {
|
|
t.Fatal(err)
|
|
} else {
|
|
assert.False(t, reported)
|
|
}
|
|
|
|
// Number of collisions must not have increased.
|
|
assert.Equal(t, 1, m.Collisions)
|
|
|
|
// Actual distance is ~1.272604
|
|
assert.Greater(t, m.CollisionRadius, 1.1)
|
|
assert.Less(t, m.CollisionRadius, 1.272)
|
|
})
|
|
t.Run("subject id empty", func(t *testing.T) {
|
|
m := NewFace("", SrcAuto, Embeddings{})
|
|
if reported, err := m.ReportCollision(MarkerFixtures.Pointer("1000003-4").Embeddings()); err != nil {
|
|
t.Fatal(err)
|
|
} else {
|
|
assert.False(t, reported)
|
|
}
|
|
})
|
|
t.Run("invalid face id", func(t *testing.T) {
|
|
m := NewFace("123", SrcAuto, Embeddings{})
|
|
m.ID = ""
|
|
if reported, err := m.ReportCollision(MarkerFixtures.Pointer("1000003-4").Embeddings()); err == nil {
|
|
t.Fatal(err)
|
|
} else {
|
|
assert.False(t, reported)
|
|
assert.Equal(t, "invalid face id", err.Error())
|
|
}
|
|
})
|
|
t.Run("embedding empty", func(t *testing.T) {
|
|
m := NewFace("123", SrcAuto, Embeddings{})
|
|
m.EmbeddingJSON = []byte("")
|
|
if reported, err := m.ReportCollision(MarkerFixtures.Pointer("1000003-4").Embeddings()); err == nil {
|
|
t.Fatal(err)
|
|
} else {
|
|
assert.False(t, reported)
|
|
assert.Equal(t, "embedding must not be empty", err.Error())
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestFace_ReviseMatches(t *testing.T) {
|
|
m := FaceFixtures.Get("joe-biden")
|
|
removed, err := m.ReviseMatches()
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
assert.Empty(t, removed)
|
|
}
|
|
|
|
func TestNewFace(t *testing.T) {
|
|
t.Run("success", func(t *testing.T) {
|
|
marker := MarkerFixtures.Get("1000003-4")
|
|
e := marker.Embeddings()
|
|
|
|
r := NewFace("123", SrcAuto, e)
|
|
assert.Equal(t, "", r.FaceSrc)
|
|
assert.Equal(t, "123", r.SubjectUID)
|
|
})
|
|
}
|
|
|
|
func TestFace_SetEmbeddings(t *testing.T) {
|
|
t.Run("success", func(t *testing.T) {
|
|
marker := MarkerFixtures.Get("1000003-4")
|
|
e := marker.Embeddings()
|
|
m := FaceFixtures.Get("joe-biden")
|
|
assert.NotEqual(t, e[0][0], m.Embedding()[0])
|
|
|
|
err := m.SetEmbeddings(e)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
assert.Equal(t, e[0][0], m.Embedding()[0])
|
|
})
|
|
}
|
|
|
|
func TestFace_Embedding(t *testing.T) {
|
|
t.Run("success", func(t *testing.T) {
|
|
m := FaceFixtures.Get("joe-biden")
|
|
|
|
assert.Equal(t, 0.10730543085474682, m.Embedding()[0])
|
|
})
|
|
t.Run("empty embedding", func(t *testing.T) {
|
|
m := NewFace("12345", SrcAuto, Embeddings{})
|
|
m.EmbeddingJSON = []byte("")
|
|
|
|
assert.Empty(t, m.Embedding())
|
|
})
|
|
t.Run("invalid embedding json", func(t *testing.T) {
|
|
m := NewFace("12345", SrcAuto, Embeddings{})
|
|
m.EmbeddingJSON = []byte("[false]")
|
|
|
|
assert.Equal(t, float64(0), m.Embedding()[0])
|
|
})
|
|
}
|
|
|
|
func TestFace_UpdateMatchTime(t *testing.T) {
|
|
m := NewFace("12345", SrcAuto, Embeddings{})
|
|
initialMatchTime := m.MatchedAt
|
|
assert.Equal(t, initialMatchTime, m.MatchedAt)
|
|
m.UpdateMatchTime()
|
|
assert.NotEqual(t, initialMatchTime, m.MatchedAt)
|
|
}
|
|
|
|
func TestFace_Save(t *testing.T) {
|
|
m := NewFace("12345fde", SrcAuto, Embeddings{Embedding{1}, Embedding{2}})
|
|
assert.Nil(t, FindFace(m.ID))
|
|
m.Save()
|
|
assert.NotNil(t, FindFace(m.ID))
|
|
assert.Equal(t, "12345fde", FindFace(m.ID).SubjectUID)
|
|
}
|
|
|
|
func TestFace_Update(t *testing.T) {
|
|
m := NewFace("12345fdef", SrcAuto, Embeddings{Embedding{8}, Embedding{16}})
|
|
assert.Nil(t, FindFace(m.ID))
|
|
m.Save()
|
|
assert.NotNil(t, FindFace(m.ID))
|
|
assert.Equal(t, "12345fdef", FindFace(m.ID).SubjectUID)
|
|
|
|
m2 := FindFace(m.ID)
|
|
m2.Update("SubjectUID", "new")
|
|
assert.Equal(t, "new", FindFace(m.ID).SubjectUID)
|
|
}
|
|
|
|
func TestFirstOrCreateFace(t *testing.T) {
|
|
t.Run("create new face", func(t *testing.T) {
|
|
m := NewFace("12345unique", SrcAuto, Embeddings{Embedding{99}, Embedding{2}})
|
|
r := FirstOrCreateFace(m)
|
|
assert.Equal(t, "12345unique", r.SubjectUID)
|
|
})
|
|
t.Run("return existing entity", func(t *testing.T) {
|
|
m := FaceFixtures.Pointer("joe-biden")
|
|
r := FirstOrCreateFace(m)
|
|
assert.Equal(t, "jqy3y652h8njw0sx", r.SubjectUID)
|
|
assert.Equal(t, 33, r.Samples)
|
|
})
|
|
}
|
|
|
|
func TestFindFace(t *testing.T) {
|
|
t.Run("existing face", func(t *testing.T) {
|
|
assert.Equal(t, 3, FindFace("VF7ANLDET2BKZNT4VQWJMMC6HBEFDOG7").Samples)
|
|
})
|
|
t.Run("empty id", func(t *testing.T) {
|
|
assert.Nil(t, FindFace(""))
|
|
})
|
|
}
|