photoprism/internal/face/net.go
2021-09-03 01:17:12 +02:00

244 lines
5.5 KiB
Go

package face
import (
"bytes"
"fmt"
"image"
"io/ioutil"
"os"
"path"
"path/filepath"
"runtime/debug"
"strings"
"sync"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/disintegration/imaging"
"github.com/photoprism/photoprism/pkg/txt"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
// Net is a wrapper for the TensorFlow Facenet model.
type Net struct {
model *tf.SavedModel
modelPath string
cachePath string
disabled bool
modelName string
modelTags []string
mutex sync.Mutex
}
// NewNet returns a new TensorFlow Facenet instance.
func NewNet(modelPath, cachePath string, disabled bool) *Net {
return &Net{modelPath: modelPath, cachePath: cachePath, disabled: disabled, modelTags: []string{"serve"}}
}
// Detect runs the detection and facenet algorithms over the provided source image.
func (t *Net) Detect(fileName string, minSize int) (faces Faces, err error) {
faces, err = Detect(fileName, false, minSize)
if err != nil {
return faces, err
}
if t.disabled {
return faces, nil
}
err = t.loadModel()
if err != nil {
return faces, err
}
var cacheHash string
if t.cachePath != "" {
cacheHash = fs.Hash(fileName)
}
for i, f := range faces {
if f.Area.Col == 0 && f.Area.Row == 0 {
continue
}
if img, err := t.getFaceCrop(fileName, cacheHash, &faces[i]); err != nil {
log.Errorf("faces: failed to decode image: %v", err)
} else if embeddings := t.getEmbeddings(img); len(embeddings) > 0 {
faces[i].Embeddings = embeddings
}
}
return faces, nil
}
// ModelLoaded tests if the TensorFlow model is loaded.
func (t *Net) ModelLoaded() bool {
return t.model != nil
}
func (t *Net) loadModel() error {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.ModelLoaded() {
return nil
}
modelPath := path.Join(t.modelPath)
log.Infof("faces: loading %s", txt.Quote(filepath.Base(modelPath)))
// Load model
model, err := tf.LoadSavedModel(modelPath, t.modelTags, nil)
if err != nil {
return err
}
t.model = model
return nil
}
func (t *Net) getCacheFolder(fileName, cacheHash string) string {
if t.cachePath == "" || cacheHash == "" {
return filepath.Dir(fileName)
}
if cacheHash == "" {
log.Debugf("faces: no hash provided for caching %s crops", filepath.Base(fileName))
cacheHash = fs.Hash(fileName)
}
result := filepath.Join(t.cachePath, "faces", string(cacheHash[0]), string(cacheHash[1]), string(cacheHash[2]))
if err := os.MkdirAll(result, os.ModePerm); err != nil {
log.Errorf("faces: failed creating cache folder")
}
return result
}
func (t *Net) getFaceCrop(fileName, cacheHash string, f *Face) (img image.Image, err error) {
if f == nil {
return img, fmt.Errorf("face is nil")
}
area := f.Area
cacheFolder := t.getCacheFolder(fileName, cacheHash)
if cacheHash != "" {
f.Thumb = fmt.Sprintf("%s_%dx%d_crop_%s", cacheHash, CropSize, CropSize, f.Crop().ID())
} else {
base := filepath.Base(fileName)
i := strings.Index(base, "_")
if i > 32 {
base = base[:i]
}
f.Thumb = fmt.Sprintf("%s_%dx%d_crop_%s", base, CropSize, CropSize, f.Crop().ID())
}
cacheFile := filepath.Join(cacheFolder, f.Thumb+fs.JpegExt)
if !fs.FileExists(cacheFile) {
// Do nothing.
} else if img, err := imaging.Open(cacheFile); err != nil {
log.Errorf("faces: failed loading %s", filepath.Base(cacheFile))
} else {
log.Debugf("faces: extracting from %s", filepath.Base(cacheFile))
return img, nil
}
x, y := area.TopLeft()
imageBuffer, err := ioutil.ReadFile(fileName)
img, err = imaging.Decode(bytes.NewReader(imageBuffer), imaging.AutoOrientation(true))
if err != nil {
return img, err
}
img = imaging.Crop(img, image.Rect(y, x, y+area.Scale, x+area.Scale))
img = imaging.Fill(img, CropSize, CropSize, imaging.Center, imaging.Lanczos)
if err := imaging.Save(img, cacheFile); err != nil {
log.Errorf("faces: failed caching %s", filepath.Base(cacheFile))
} else {
log.Debugf("faces: saved %s", filepath.Base(cacheFile))
}
return img, nil
}
func (t *Net) getEmbeddings(img image.Image) [][]float32 {
tensor, err := imageToTensor(img, CropSize, CropSize)
if err != nil {
log.Errorf("faces: failed to convert image to tensor: %v", err)
}
// TODO: pre-whiten image as in facenet
trainPhaseBoolTensor, err := tf.NewTensor(false)
output, err := t.model.Session.Run(
map[tf.Output]*tf.Tensor{
t.model.Graph.Operation("input").Output(0): tensor,
t.model.Graph.Operation("phase_train").Output(0): trainPhaseBoolTensor,
},
[]tf.Output{
t.model.Graph.Operation("embeddings").Output(0),
},
nil)
if err != nil {
log.Errorf("faces: %s", err)
}
if len(output) < 1 {
log.Errorf("faces: inference failed, no output")
} else {
return output[0].Value().([][]float32)
}
return nil
}
func imageToTensor(img image.Image, imageHeight, imageWidth int) (tfTensor *tf.Tensor, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("faces: %s (panic)\nstack: %s", r, debug.Stack())
}
}()
if imageHeight <= 0 || imageWidth <= 0 {
return tfTensor, fmt.Errorf("faces: image width and height must be > 0")
}
var tfImage [1][][][3]float32
for j := 0; j < imageHeight; j++ {
tfImage[0] = append(tfImage[0], make([][3]float32, imageWidth))
}
for i := 0; i < imageWidth; i++ {
for j := 0; j < imageHeight; j++ {
r, g, b, _ := img.At(i, j).RGBA()
tfImage[0][j][i][0] = convertValue(r)
tfImage[0][j][i][1] = convertValue(g)
tfImage[0][j][i][2] = convertValue(b)
}
}
return tf.NewTensor(tfImage)
}
func convertValue(value uint32) float32 {
return (float32(value>>8) - float32(127.5)) / float32(127.5)
}