Use nasnet mobile instead of inception v1 for image classification
This commit is contained in:
parent
a82696f067
commit
818019a7ec
13 changed files with 141 additions and 105 deletions
|
@ -5,10 +5,11 @@
|
|||
/frontend/node_modules/*
|
||||
/assets/server/public/build/*
|
||||
/assets/testdata
|
||||
/assets/tensorflow
|
||||
Dockerfile
|
||||
/photoprism
|
||||
docker-compose*
|
||||
/coverage.*
|
||||
.dockerignore
|
||||
.idea
|
||||
.DS_Store
|
||||
.DS_Store
|
||||
|
|
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -6,6 +6,7 @@
|
|||
/frontend/node_modules/*
|
||||
/frontend/tests/result.html
|
||||
/assets/testdata
|
||||
/assets/tensorflow
|
||||
*.log
|
||||
|
||||
# Binaries for programs and plugins
|
||||
|
@ -39,4 +40,4 @@ Thumbs.db
|
|||
.c9revisions
|
||||
.settings
|
||||
.swp
|
||||
.tmp
|
||||
.tmp
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
FROM photoprism/development:20190418
|
||||
FROM photoprism/development:20190430
|
||||
|
||||
# Set up project directory
|
||||
WORKDIR "/go/src/github.com/photoprism/photoprism"
|
||||
|
|
2
Makefile
2
Makefile
|
@ -43,7 +43,7 @@ test-coverage:
|
|||
clean:
|
||||
rm -f $(BINARY_NAME)
|
||||
download:
|
||||
scripts/download-inception.sh
|
||||
scripts/download-assets.sh
|
||||
deploy-photoprism:
|
||||
scripts/docker-build.sh photoprism $(DOCKER_TAG)
|
||||
scripts/docker-push.sh photoprism $(DOCKER_TAG)
|
||||
|
|
2
assets/tensorflow/.gitignore
vendored
2
assets/tensorflow/.gitignore
vendored
|
@ -1,2 +0,0 @@
|
|||
*
|
||||
!.gitignore
|
|
@ -86,8 +86,8 @@ RUN mkdir -p "$GOPATH/src" "$GOPATH/bin" && chmod -R 777 "$GOPATH"
|
|||
|
||||
# Download TensorFlow model and test files
|
||||
RUN rm -rf /tmp/* && mkdir -p /tmp/photoprism
|
||||
RUN wget "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip" -O /tmp/photoprism/inception.zip
|
||||
RUN wget "https://www.dropbox.com/s/na9p9wwt98l7m5b/import.zip?dl=1" -O /tmp/photoprism/testdata.zip
|
||||
RUN wget "https://dl.photoprism.org/tensorflow/nasnet.zip" -O /tmp/photoprism/nasnet.zip
|
||||
RUN wget "https://dl.photoprism.org/fixtures/test.zip" -O /tmp/photoprism/testdata.zip
|
||||
|
||||
# Install goimports
|
||||
RUN env GO111MODULE=off /usr/local/go/bin/go get golang.org/x/tools/cmd/goimports
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
FROM photoprism/development:20190418 as build
|
||||
FROM photoprism/development:20190430 as build
|
||||
|
||||
# Set up project directory
|
||||
WORKDIR "/go/src/github.com/photoprism/photoprism"
|
||||
|
|
2
go.mod
2
go.mod
|
@ -64,7 +64,7 @@ require (
|
|||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 // indirect
|
||||
github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3
|
||||
github.com/stretchr/testify v1.2.2
|
||||
github.com/tensorflow/tensorflow v1.12.0
|
||||
github.com/tensorflow/tensorflow v1.13.1
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 // indirect
|
||||
github.com/twinj/uuid v1.0.0 // indirect
|
||||
github.com/unrolled/render v0.0.0-20181210145518-4c664cb3ad2f // indirect
|
||||
|
|
2
go.sum
2
go.sum
|
@ -224,6 +224,8 @@ github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1
|
|||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/tensorflow/tensorflow v1.12.0 h1:fT4okrN4BkpgotWmDwS56wM6BdkRpTL0lLMzvkM+bLo=
|
||||
github.com/tensorflow/tensorflow v1.12.0/go.mod h1:itOSERT4trABok4UOoG+X4BoKds9F3rIsySdn+Lvu90=
|
||||
github.com/tensorflow/tensorflow v1.13.1 h1:ygn0+ztXusm6RGVP4Od5IF+8h5sAgD5qbeTvqYyMnjo=
|
||||
github.com/tensorflow/tensorflow v1.13.1/go.mod h1:itOSERT4trABok4UOoG+X4BoKds9F3rIsySdn+Lvu90=
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 h1:lYIiVDtZnyTWlNwiAxLj0bbpTcx1BWCFhXjfsvmPdNc=
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
|
||||
github.com/twinj/uuid v0.0.0-20150629100731-70cac2bcd273/go.mod h1:mMgcE1RHFUFqe5AfiwlINXisXfDGro23fWdPUfOMjRY=
|
||||
|
|
|
@ -2,20 +2,24 @@ package photoprism
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"image"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math"
|
||||
"os"
|
||||
"sort"
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
||||
"github.com/tensorflow/tensorflow/tensorflow/go/op"
|
||||
)
|
||||
|
||||
// TensorFlow is a tensorflow wrapper given a graph, labels and a modelPath.
|
||||
type TensorFlow struct {
|
||||
modelPath string
|
||||
graph *tf.Graph
|
||||
model *tf.SavedModel
|
||||
labels []string
|
||||
}
|
||||
|
||||
|
@ -30,6 +34,10 @@ type TensorFlowLabel struct {
|
|||
Probability float32 `json:"probability"`
|
||||
}
|
||||
|
||||
func (a *TensorFlowLabel) Percent() int {
|
||||
return int(math.Round(float64(a.Probability * 100)))
|
||||
}
|
||||
|
||||
// TensorFlowLabels is a slice of tensorflow labels.
|
||||
type TensorFlowLabels []TensorFlowLabel
|
||||
|
||||
|
@ -37,6 +45,12 @@ func (a TensorFlowLabels) Len() int { return len(a) }
|
|||
func (a TensorFlowLabels) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a TensorFlowLabels) Less(i, j int) bool { return a[i].Probability > a[j].Probability }
|
||||
|
||||
func (t *TensorFlow) closeSession(s *tf.Session) {
|
||||
if err := s.Close(); err != nil {
|
||||
log.Print(err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetImageTagsFromFile returns tags for a jpeg image file.
|
||||
func (t *TensorFlow) GetImageTagsFromFile(filename string) (result []TensorFlowLabel, err error) {
|
||||
imageBuffer, err := ioutil.ReadFile(filename)
|
||||
|
@ -45,43 +59,41 @@ func (t *TensorFlow) GetImageTagsFromFile(filename string) (result []TensorFlowL
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return t.GetImageTags(string(imageBuffer))
|
||||
return t.GetImageTags(imageBuffer)
|
||||
}
|
||||
|
||||
// GetImageTags returns tags for jpeg image data.
|
||||
func (t *TensorFlow) GetImageTags(image string) (result []TensorFlowLabel, err error) {
|
||||
func (t *TensorFlow) GetImageTags(img []byte) (result []TensorFlowLabel, err error) {
|
||||
if err := t.loadModel(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Make tensor
|
||||
tensor, err := t.makeTensorFromImage(image, "jpeg")
|
||||
tensor, err := t.makeTensorFromImage(img, "jpeg")
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid image")
|
||||
}
|
||||
|
||||
// Run inference
|
||||
session, err := tf.NewSession(t.graph, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.Run(
|
||||
output, err := t.model.Session.Run(
|
||||
map[tf.Output]*tf.Tensor{
|
||||
t.graph.Operation("input").Output(0): tensor,
|
||||
t.model.Graph.Operation("input_1").Output(0): tensor,
|
||||
},
|
||||
[]tf.Output{
|
||||
t.graph.Operation("output").Output(0),
|
||||
t.model.Graph.Operation("predictions/Softmax").Output(0),
|
||||
},
|
||||
nil)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("could not run inference")
|
||||
return result, errors.New("could not run inference")
|
||||
}
|
||||
|
||||
if len(output) < 1 {
|
||||
return result, errors.New("result is empty")
|
||||
}
|
||||
|
||||
|
||||
// Return best labels
|
||||
return t.findBestLabels(output[0].Value().([][]float32)[0]), nil
|
||||
}
|
||||
|
@ -92,22 +104,22 @@ func (t *TensorFlow) loadModel() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Load inception model
|
||||
model, err := ioutil.ReadFile(t.modelPath + "/inception/tensorflow_inception_graph.pb")
|
||||
// Load model
|
||||
model, err := tf.LoadSavedModel(t.modelPath + "/nasnet", []string{"photoprism"}, nil)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.graph = tf.NewGraph()
|
||||
if err := t.graph.Import(model, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.model = model
|
||||
|
||||
// Load labels
|
||||
labelsFile, err := os.Open(t.modelPath + "/inception/imagenet_comp_graph_label_strings.txt")
|
||||
labelsFile, err := os.Open(t.modelPath + "/nasnet/labels.txt")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer labelsFile.Close()
|
||||
|
||||
scanner := bufio.NewScanner(labelsFile)
|
||||
|
||||
// Labels are separated by newlines
|
||||
|
@ -137,61 +149,39 @@ func (t *TensorFlow) findBestLabels(probabilities []float32) []TensorFlowLabel {
|
|||
return resultLabels[:5]
|
||||
}
|
||||
|
||||
func (t *TensorFlow) makeTensorFromImage(image string, imageFormat string) (*tf.Tensor, error) {
|
||||
tensor, err := tf.NewTensor(image)
|
||||
func (t *TensorFlow) makeTensorFromImage(image []byte, imageFormat string) (*tf.Tensor, error) {
|
||||
img, err:= imaging.Decode(bytes.NewReader(image), imaging.AutoOrientation(true))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
graph, input, output, err := t.makeTransformImageGraph(imageFormat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session, err := tf.NewSession(graph, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer session.Close()
|
||||
normalized, err := session.Run(
|
||||
map[tf.Output]*tf.Tensor{input: tensor},
|
||||
[]tf.Output{output},
|
||||
nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return normalized[0], nil
|
||||
|
||||
width, height := 224, 224
|
||||
|
||||
img = imaging.Fill(img, width, height, imaging.Center, imaging.CatmullRom)
|
||||
|
||||
return imageToTensorTF(img, width, height)
|
||||
}
|
||||
|
||||
// Creates a graph to decode, resize and normalize an image
|
||||
func (t *TensorFlow) makeTransformImageGraph(imageFormat string) (
|
||||
graph *tf.Graph, input, output tf.Output, err error) {
|
||||
const (
|
||||
H, W = 224, 224
|
||||
Mean = float32(117)
|
||||
Scale = float32(1)
|
||||
)
|
||||
s := op.NewScope()
|
||||
input = op.Placeholder(s, tf.String)
|
||||
// Decode PNG or JPEG
|
||||
var decode tf.Output
|
||||
if imageFormat == "png" {
|
||||
decode = op.DecodePng(s, input, op.DecodePngChannels(3))
|
||||
} else {
|
||||
decode = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
|
||||
}
|
||||
// Div and Sub perform (value-Mean)/Scale for each pixel
|
||||
output = op.Div(s,
|
||||
op.Sub(s,
|
||||
// Resize to 224x224 with bi-linear interpolation
|
||||
op.ResizeBilinear(s,
|
||||
// Create a batch containing a single image
|
||||
op.ExpandDims(s,
|
||||
// Use decoded pixel values
|
||||
op.Cast(s, decode, tf.Float),
|
||||
op.Const(s.SubScope("make_batch"), int32(0))),
|
||||
op.Const(s.SubScope("size"), []int32{H, W})),
|
||||
op.Const(s.SubScope("mean"), Mean)),
|
||||
op.Const(s.SubScope("scale"), Scale))
|
||||
graph, err = s.Finalize()
|
||||
func imageToTensorTF(img image.Image, imageHeight, imageWidth int) (*tf.Tensor, error) {
|
||||
var tfImage [1][][][3]float32
|
||||
|
||||
return graph, input, output, err
|
||||
for j := 0; j < imageHeight; j++ {
|
||||
tfImage[0] = append(tfImage[0], make([][3]float32, imageWidth))
|
||||
}
|
||||
|
||||
for i := 0; i < imageWidth; i++ {
|
||||
for j := 0; j < imageHeight; j++ {
|
||||
r, g, b, _ := img.At(i, j).RGBA()
|
||||
tfImage[0][j][i][0] = convertTF(r)
|
||||
tfImage[0][j][i][1] = convertTF(g)
|
||||
tfImage[0][j][i][2] = convertTF(b)
|
||||
}
|
||||
}
|
||||
|
||||
return tf.NewTensor(tfImage)
|
||||
}
|
||||
|
||||
// convertTF converts a colour channel value to a float32 model input. The
// value is shifted down by 8 bits, then centred and scaled around 127.5, so
// a 16-bit input (0..0xFFFF) maps onto the range [-1, 1].
func convertTF(value uint32) float32 {
	const half = float32(127.5)

	channel := float32(value >> 8)

	return (channel - half) / half
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package photoprism
|
|||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/test"
|
||||
|
@ -18,15 +17,24 @@ func TestTensorFlow_GetImageTagsFromFile(t *testing.T) {
|
|||
|
||||
result, err := tensorFlow.GetImageTagsFromFile(conf.ImportPath() + "/iphone/IMG_6788.JPG")
|
||||
|
||||
assert.NotNil(t, result)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if err != nil {
|
||||
t.Log(err.Error())
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
assert.NotNil(t, result)
|
||||
assert.IsType(t, []TensorFlowLabel{}, result)
|
||||
assert.Equal(t, 5, len(result))
|
||||
|
||||
assert.Equal(t, "tabby", result[0].Label)
|
||||
t.Log(result)
|
||||
|
||||
assert.Equal(t, "tabby, cat", result[0].Label)
|
||||
assert.Equal(t, "tiger cat", result[1].Label)
|
||||
|
||||
assert.Equal(t, float64(0.165), math.Round(float64(result[1].Probability)*1000)/1000)
|
||||
assert.Equal(t, 63, result[0].Percent())
|
||||
assert.Equal(t, 16, result[1].Percent())
|
||||
}
|
||||
|
||||
func TestTensorFlow_GetImageTags(t *testing.T) {
|
||||
|
@ -43,16 +51,52 @@ func TestTensorFlow_GetImageTags(t *testing.T) {
|
|||
if imageBuffer, err := ioutil.ReadFile(conf.ImportPath() + "/iphone/IMG_6788.JPG"); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
result, err := tensorFlow.GetImageTags(string(imageBuffer))
|
||||
result, err := tensorFlow.GetImageTags(imageBuffer)
|
||||
|
||||
t.Log(result)
|
||||
|
||||
assert.NotNil(t, result)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.IsType(t, []TensorFlowLabel{}, result)
|
||||
assert.Equal(t, 5, len(result))
|
||||
|
||||
assert.Equal(t, "tabby", result[0].Label)
|
||||
assert.Equal(t, "tabby, cat", result[0].Label)
|
||||
assert.Equal(t, "tiger cat", result[1].Label)
|
||||
|
||||
assert.Equal(t, float64(0.165), math.Round(float64(result[1].Probability)*1000)/1000)
|
||||
assert.Equal(t, 63, result[0].Percent())
|
||||
assert.Equal(t, 16, result[1].Percent())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensorFlow_GetImageTags_Dog(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
|
||||
conf := test.NewConfig()
|
||||
|
||||
conf.InitializeTestData(t)
|
||||
|
||||
tensorFlow := NewTensorFlow(conf.TensorFlowModelPath())
|
||||
|
||||
if imageBuffer, err := ioutil.ReadFile(conf.ImportPath() + "/dog.jpg"); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
result, err := tensorFlow.GetImageTags(imageBuffer)
|
||||
|
||||
t.Log(result)
|
||||
|
||||
assert.NotNil(t, result)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.IsType(t, []TensorFlowLabel{}, result)
|
||||
assert.Equal(t, 5, len(result))
|
||||
|
||||
assert.Equal(t, "belt", result[0].Label)
|
||||
assert.Equal(t, "basenji, dog", result[1].Label)
|
||||
|
||||
assert.Equal(t, 13, result[1].Percent())
|
||||
assert.Equal(t, 13, result[1].Percent())
|
||||
}
|
||||
}
|
||||
|
|
15
scripts/download-assets.sh
Executable file
15
scripts/download-assets.sh
Executable file
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/env bash

# Downloads the NASNet TensorFlow model archive and unpacks it into
# assets/tensorflow, unless the saved model is already present there.

# Abort on the first failing command so a failed mkdir/unzip is reported
# instead of being silently ignored.
set -e

FILENAME="/tmp/photoprism/nasnet.zip"

if [[ ! -e assets/tensorflow/nasnet/saved_model.pb ]]; then
  if [[ ! -e "${FILENAME}" ]]; then
    mkdir -p /tmp/photoprism

    # wget -O creates the output file even when the download fails. Fetch
    # into a temporary file and move it into place only on success, so a
    # broken archive is never picked up as "already downloaded" later.
    if ! wget "https://dl.photoprism.org/tensorflow/nasnet.zip" -O "${FILENAME}.part"; then
      rm -f "${FILENAME}.part"
      echo "Failed to download TensorFlow model." >&2
      exit 1
    fi

    mv "${FILENAME}.part" "${FILENAME}"
  fi

  mkdir -p assets/tensorflow
  unzip "${FILENAME}" -d assets/tensorflow
else
  echo "TensorFlow model already downloaded."
fi
|
|
@ -1,15 +0,0 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
FILENAME="/tmp/photoprism/inception.zip"
|
||||
|
||||
if [[ ! -e assets/tensorflow/inception/tensorflow_inception_graph.pb ]]; then
|
||||
if [[ ! -e ${FILENAME} ]]; then
|
||||
mkdir -p /tmp/photoprism
|
||||
wget "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip" -O ${FILENAME}
|
||||
fi
|
||||
|
||||
mkdir -p assets/tensorflow/inception
|
||||
unzip ${FILENAME} -d assets/tensorflow/inception
|
||||
else
|
||||
echo "TensorFlow Inception V1 model already downloaded."
|
||||
fi
|
Loading…
Reference in a new issue