photoprism/internal/classify/tensorflow.go

259 lines
5.5 KiB
Go
Raw Normal View History

package classify
import (
"bufio"
"bytes"
"fmt"
"image"
"math"
"os"
"path"
"path/filepath"
2021-05-06 12:45:38 +02:00
"runtime/debug"
"sort"
"strings"
"github.com/disintegration/imaging"
"github.com/photoprism/photoprism/pkg/sanitize"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
// TensorFlow is a wrapper for tensorflow low-level API.
//
// It lazily loads a SavedModel plus its labels.txt from
// modelsPath/modelName and uses them to classify images.
type TensorFlow struct {
	// model is the loaded SavedModel; nil until loadModel succeeds
	// (ModelLoaded reports whether it is set).
	model *tf.SavedModel
	// modelsPath is the directory that contains the model directory named modelName.
	modelsPath string
	// disabled turns Init, File and Labels into no-ops.
	disabled bool
	// modelName is the model's sub-directory of modelsPath ("nasnet" when created via New).
	modelName string
	// modelTags are the tags passed to tf.LoadSavedModel.
	modelTags []string
	// labels holds one class name per line of labels.txt; index i names
	// the model's i-th output probability.
	labels []string
}
// New creates a TensorFlow classifier that will load the "nasnet" model,
// tagged "photoprism", from modelsPath. No model is loaded here; call Init
// (or classify an image) to load it. If disabled is true, classification
// is switched off entirely.
func New(modelsPath string, disabled bool) *TensorFlow {
	tfClassifier := &TensorFlow{
		modelsPath: modelsPath,
		disabled:   disabled,
	}

	tfClassifier.modelName = "nasnet"
	tfClassifier.modelTags = []string{"photoprism"}

	return tfClassifier
}
// Init loads the TensorFlow model and its labels up front, unless
// classification is disabled. It returns any error from loading the model.
func (t *TensorFlow) Init() (err error) {
	if !t.disabled {
		err = t.loadModel()
	}

	return err
}
// File reads filename from disk and returns the labels Labels finds in its
// contents (expected to be a JPEG image).
//
// When classification is disabled, the file is not read and File returns nil
// labels with no error. Read errors are returned unchanged.
func (t *TensorFlow) File(filename string) (result Labels, err error) {
	if !t.disabled {
		data, readErr := os.ReadFile(filename)
		if readErr != nil {
			return nil, readErr
		}

		return t.Labels(data)
	}

	return result, nil
}
// Labels returns matching labels for a jpeg media string.
func (t *TensorFlow) Labels(img []byte) (result Labels, err error) {
2021-05-06 12:45:38 +02:00
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("classify: %s (inference panic)\nstack: %s", r, debug.Stack())
2021-05-06 12:45:38 +02:00
}
}()
if t.disabled {
return result, nil
}
if err := t.loadModel(); err != nil {
return nil, err
}
// Create tensor from image.
tensor, err := t.createTensor(img, "jpeg")
if err != nil {
return nil, err
}
// Run inference.
output, err := t.model.Session.Run(
map[tf.Output]*tf.Tensor{
t.model.Graph.Operation("input_1").Output(0): tensor,
},
[]tf.Output{
t.model.Graph.Operation("predictions/Softmax").Output(0),
},
nil)
if err != nil {
return result, fmt.Errorf("classify: %s (run inference)", err.Error())
}
if len(output) < 1 {
return result, fmt.Errorf("classify: inference failed, no output")
}
// Return best labels
result = t.bestLabels(output[0].Value().([][]float32)[0])
if len(result) > 0 {
log.Tracef("classify: image classified as %+v", result)
}
return result, nil
}
2019-07-17 11:53:33 +02:00
func (t *TensorFlow) loadLabels(path string) error {
modelLabels := path + "/labels.txt"
log.Infof("classify: loading labels from labels.txt")
// Load labels
f, err := os.Open(modelLabels)
if err != nil {
return err
}
defer f.Close()
scanner := bufio.NewScanner(f)
// Labels are separated by newlines
for scanner.Scan() {
t.labels = append(t.labels, scanner.Text())
}
if err := scanner.Err(); err != nil {
return err
}
return nil
}
2021-05-06 12:45:38 +02:00
// ModelLoaded tests if the TensorFlow model is loaded.
func (t *TensorFlow) ModelLoaded() bool {
return t.model != nil
}
2019-07-17 11:53:33 +02:00
func (t *TensorFlow) loadModel() error {
if t.ModelLoaded() {
2019-07-17 11:53:33 +02:00
return nil
}
modelPath := path.Join(t.modelsPath, t.modelName)
2019-07-17 11:53:33 +02:00
log.Infof("classify: loading %s", sanitize.Log(filepath.Base(modelPath)))
2019-07-17 11:53:33 +02:00
// Load model
model, err := tf.LoadSavedModel(modelPath, t.modelTags, nil)
2019-07-17 11:53:33 +02:00
if err != nil {
return err
}
t.model = model
return t.loadLabels(modelPath)
2019-07-17 11:53:33 +02:00
}
// bestLabels returns the best 5 labels (if enough high probability labels) from the prediction of the model
func (t *TensorFlow) bestLabels(probabilities []float32) Labels {
var result Labels
for i, p := range probabilities {
if i >= len(t.labels) {
// break if probabilities and labels does not match
break
}
// discard labels with low probabilities
if p < 0.1 {
continue
}
labelText := strings.ToLower(t.labels[i])
2021-09-23 23:46:17 +02:00
rule, _ := Rules.Find(labelText)
// discard labels that don't met the threshold
if p < rule.Threshold {
continue
}
// Get rule label name instead of t.labels name if it exists
if rule.Label != "" {
labelText = rule.Label
}
labelText = strings.TrimSpace(labelText)
uncertainty := 100 - int(math.Round(float64(p*100)))
result = append(result, Label{Name: labelText, Source: SrcImage, Uncertainty: uncertainty, Priority: rule.Priority, Categories: rule.Categories})
}
// Sort by probability
sort.Sort(result)
2021-05-06 12:45:38 +02:00
// Return the best labels only.
if l := len(result); l < 5 {
return result[:l]
} else {
return result[:5]
}
}
// createTensor converts bytes jpeg image in a tensor object required as tensorflow model input
func (t *TensorFlow) createTensor(image []byte, imageFormat string) (*tf.Tensor, error) {
img, err := imaging.Decode(bytes.NewReader(image), imaging.AutoOrientation(true))
if err != nil {
return nil, err
}
width, height := 224, 224
img = imaging.Fill(img, width, height, imaging.Center, imaging.Lanczos)
2021-05-06 12:45:38 +02:00
return imageToTensor(img, width, height)
}
2021-05-06 12:45:38 +02:00
func imageToTensor(img image.Image, imageHeight, imageWidth int) (tfTensor *tf.Tensor, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("classify: %s (panic)\nstack: %s", r, debug.Stack())
}
}()
if imageHeight <= 0 || imageWidth <= 0 {
return tfTensor, fmt.Errorf("classify: image width and height must be > 0")
}
var tfImage [1][][][3]float32
for j := 0; j < imageHeight; j++ {
tfImage[0] = append(tfImage[0], make([][3]float32, imageWidth))
}
for i := 0; i < imageWidth; i++ {
for j := 0; j < imageHeight; j++ {
r, g, b, _ := img.At(i, j).RGBA()
2021-05-06 12:45:38 +02:00
tfImage[0][j][i][0] = convertValue(r)
tfImage[0][j][i][1] = convertValue(g)
tfImage[0][j][i][2] = convertValue(b)
}
}
return tf.NewTensor(tfImage)
}
// convertValue maps a 16-bit color channel value (as returned by
// color.Color.RGBA) to the range [-1, 1]. It first reduces the value to
// 8 bits (0..255) and then centers and scales it around 127.5.
func convertValue(value uint32) float32 {
	const midpoint = float32(127.5)

	channel := float32(value >> 8)

	return (channel - midpoint) / midpoint
}