diff --git a/internal/entity/embedding_test.go b/internal/entity/embedding_test.go index d0f293bc4..b57e215a4 100644 --- a/internal/entity/embedding_test.go +++ b/internal/entity/embedding_test.go @@ -58,6 +58,10 @@ func TestUnmarshalEmbedding(t *testing.T) { r := UnmarshalEmbedding("-0.013,-0.031]") assert.Nil(t, r) }) + t.Run("invalid json", func(t *testing.T) { + r := UnmarshalEmbedding("[true, false]") + assert.Equal(t, []float64{0, 0}, r) + }) } func TestUnmarshalEmbeddings(t *testing.T) { @@ -69,4 +73,8 @@ func TestUnmarshalEmbeddings(t *testing.T) { r := UnmarshalEmbeddings("-0.013,-0.031]") assert.Nil(t, r) }) + t.Run("invalid json", func(t *testing.T) { + r := UnmarshalEmbeddings("[[true, false]]") + assert.Equal(t, [][]float64{{0, 0}}, r) + }) } diff --git a/internal/entity/face_test.go b/internal/entity/face_test.go index d68aeb5bf..90195118a 100644 --- a/internal/entity/face_test.go +++ b/internal/entity/face_test.go @@ -1,9 +1,8 @@ package entity import ( - "testing" - "github.com/stretchr/testify/assert" + "testing" ) func TestFace_TableName(t *testing.T) { @@ -29,6 +28,27 @@ func TestFace_Match(t *testing.T) { assert.Greater(t, dist, 1.27) assert.Less(t, dist, 1.28) }) + + t.Run("len(embeddings) == 0", func(t *testing.T) { + m := FaceFixtures.Get("joe-biden") + match, dist := m.Match(Embeddings{}) + + assert.False(t, match) + assert.Equal(t, dist, float64(-1)) + }) + t.Run("len(efacEmbeddings) == 0", func(t *testing.T) { + m := NewFace("12345", SrcAuto, Embeddings{}) + match, dist := m.Match(MarkerFixtures.Pointer("1000003-6").Embeddings()) + + assert.False(t, match) + assert.Equal(t, dist, float64(-1)) + }) + t.Run("jane doe- no match", func(t *testing.T) { + m := FaceFixtures.Get("jane-doe") + match, _ := m.Match(MarkerFixtures.Pointer("1000003-5").Embeddings()) + + assert.False(t, match) + }) } func TestFace_ReportCollision(t *testing.T) { @@ -74,3 +94,56 @@ func TestFace_ReviseMatches(t *testing.T) { assert.Empty(t, removed) } + +func TestNewFace(t *testing.T) { + t.Run("success", func(t *testing.T) { + marker := MarkerFixtures.Get("1000003-4") + e := marker.Embeddings() + + r := NewFace("123", SrcAuto, e) + assert.Equal(t, "", r.FaceSrc) + assert.Equal(t, "123", r.SubjectUID) + }) +} + +func TestFace_SetEmbeddings(t *testing.T) { + t.Run("success", func(t *testing.T) { + marker := MarkerFixtures.Get("1000003-4") + e := marker.Embeddings() + f := FaceFixtures.Get("joe-biden") + assert.NotEqual(t, e[0][0], f.Embedding()[0]) + + err := f.SetEmbeddings(e) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, e[0][0], f.Embedding()[0]) + }) +} + +func TestFace_Embedding(t *testing.T) { + t.Run("success", func(t *testing.T) { + f := FaceFixtures.Get("joe-biden") + + assert.Equal(t, 0.10730543085474682, f.Embedding()[0]) + }) + t.Run("empty embedding", func(t *testing.T) { + f := NewFace("12345", SrcAuto, Embeddings{}) + + assert.Empty(t, f.Embedding()) + }) + t.Run("invalid embedding json", func(t *testing.T) { + f := NewFace("12345", SrcAuto, Embeddings{}) + f.EmbeddingJSON = []byte("[false]") + + assert.Equal(t, float64(0), f.Embedding()[0]) + }) +} + +func TestFace_UpdateMatchTime(t *testing.T) { + f := NewFace("12345", SrcAuto, Embeddings{}) + initialMatchTime := f.MatchedAt + assert.Equal(t, initialMatchTime, f.MatchedAt) + f.UpdateMatchTime() + assert.NotEqual(t, initialMatchTime, f.MatchedAt) +}