photoprism/cmd/tensorflowapi/tensorflowapi.go
2018-02-04 17:34:07 +01:00

138 lines
No EOL
3.1 KiB
Go

package main
import (
	"bufio"
	"bytes"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"sort"
	"strings"
	"time"

	"github.com/julienschmidt/httprouter"
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
// Response payload types for the /recognize endpoint.
type (
	// ClassifyResult is the JSON body returned for one classified upload:
	// the uploaded filename and its highest-scoring labels.
	ClassifyResult struct {
		Filename string        `json:"filename"`
		Labels   []LabelResult `json:"labels"`
	}

	// LabelResult is a single candidate label together with the probability
	// the model assigned to it.
	LabelResult struct {
		Label       string  `json:"label"`
		Probability float32 `json:"probability"`
	}
)
// graph is the imported model graph; main fills it via loadModel before the
// HTTP server starts.
var graph *tf.Graph

// labels holds the class names read by loadModel; findBestLabels pairs
// labels[i] with the i-th output probability.
var labels []string
// main loads the model and labels, then serves POST /recognize on :8080.
// The process exits if the model cannot be loaded or the server stops.
func main() {
	if err := loadModel(); err != nil {
		log.Fatal(err) // log.Fatal exits the process; nothing after it runs.
	}

	r := httprouter.New()
	r.POST("/recognize", recognizeHandler)

	srv := &http.Server{
		Addr:    ":8080",
		Handler: r,
		// Bound the time a client may take to send request headers so slow or
		// idle connections cannot hold server resources indefinitely. Body and
		// write timeouts are left unset because upload size and inference time
		// vary per request.
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Fatal(srv.ListenAndServe())
}
// loadModel imports the Inception graph definition into the package-level
// graph and appends every line of the label file to the package-level labels.
// It returns the first error from reading, importing, opening or scanning.
func loadModel() error {
	const (
		graphPath  = "/model/tensorflow_inception_graph.pb"
		labelsPath = "/model/imagenet_comp_graph_label_strings.txt"
	)

	serialized, err := ioutil.ReadFile(graphPath)
	if err != nil {
		return err
	}
	graph = tf.NewGraph()
	if err = graph.Import(serialized, ""); err != nil {
		return err
	}

	f, err := os.Open(labelsPath)
	if err != nil {
		return err
	}
	defer f.Close()

	// One label per line; findBestLabels relies on line i naming output i.
	lines := bufio.NewScanner(f)
	for lines.Scan() {
		labels = append(labels, lines.Text())
	}
	// Err is nil when scanning stopped at a clean EOF.
	return lines.Err()
}
// recognizeHandler handles POST /recognize. It reads the image from the
// multipart form field "image", runs it through the loaded graph, and responds
// with the uploaded filename and the best-scoring labels as JSON.
//
// Upload and decoding problems are reported as 400 Bad Request. Failures to
// create or run the TensorFlow session, or an unexpected output value, are
// reported as 500 for this request only.
func recognizeHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	// Read image. header is nil when FormFile fails, so err must be checked
	// before header is touched.
	imageFile, header, err := r.FormFile("image")
	if err != nil {
		responseError(w, "Could not read image", http.StatusBadRequest)
		return
	}
	defer imageFile.Close()

	// Copy image data to a buffer; a failed or truncated upload is a client error.
	var imageBuffer bytes.Buffer
	if _, err := io.Copy(&imageBuffer, imageFile); err != nil {
		responseError(w, "Could not read image", http.StatusBadRequest)
		return
	}

	// Use the file extension (text after the last '.') as the image format.
	// NOTE(review): assumes makeTensorFromImage's second argument is the image
	// format (e.g. "jpg", "png"); confirm against its definition.
	imageFormat := header.Filename
	if dot := strings.LastIndex(imageFormat, "."); dot >= 0 {
		imageFormat = imageFormat[dot+1:]
	}

	// Make tensor
	tensor, err := makeTensorFromImage(&imageBuffer, imageFormat)
	if err != nil {
		responseError(w, "Invalid image", http.StatusBadRequest)
		return
	}

	// Run inference. A session failure affects only this request, so report
	// it instead of terminating the whole server.
	session, err := tf.NewSession(graph, nil)
	if err != nil {
		log.Printf("creating tensorflow session: %v", err)
		responseError(w, "Could not run inference", http.StatusInternalServerError)
		return
	}
	defer session.Close()

	output, err := session.Run(
		map[tf.Output]*tf.Tensor{
			graph.Operation("input").Output(0): tensor,
		},
		[]tf.Output{
			graph.Operation("output").Output(0),
		},
		nil)
	if err != nil {
		responseError(w, "Could not run inference", http.StatusInternalServerError)
		return
	}

	// The first output row is scored against labels by findBestLabels. A
	// checked assertion turns an unexpected output type into a 500 instead of
	// a panic.
	batch, ok := output[0].Value().([][]float32)
	if !ok || len(batch) == 0 {
		responseError(w, "Could not run inference", http.StatusInternalServerError)
		return
	}

	// Return best labels
	responseJSON(w, ClassifyResult{
		Filename: header.Filename,
		Labels:   findBestLabels(batch[0]),
	})
}
// ByProbability implements sort.Interface over label results, ordering them
// from highest to lowest Probability.
type ByProbability []LabelResult

// Len reports how many results are being sorted.
func (s ByProbability) Len() int {
	return len(s)
}

// Swap exchanges the results at positions i and j.
func (s ByProbability) Swap(i, j int) {
	s[i], s[j] = s[j], s[i]
}

// Less places higher probabilities before lower ones.
func (s ByProbability) Less(i, j int) bool {
	return s[j].Probability < s[i].Probability
}
// findBestLabels pairs each probability with the label at the same index in
// labels and returns at most five pairs, highest probability first.
// Probabilities beyond the end of labels (and labels beyond the end of
// probabilities) are ignored. Fewer than five pairs are returned as-is rather
// than causing an out-of-range panic.
func findBestLabels(probabilities []float32) []LabelResult {
	const maxResults = 5

	// Only indices present in both slices can be paired.
	n := len(probabilities)
	if len(labels) < n {
		n = len(labels)
	}

	resultLabels := make([]LabelResult, 0, n)
	for i, p := range probabilities[:n] {
		resultLabels = append(resultLabels, LabelResult{Label: labels[i], Probability: p})
	}

	// Stable sort keeps equal probabilities in label order, so the response
	// is deterministic.
	sort.Stable(ByProbability(resultLabels))

	if len(resultLabels) > maxResults {
		resultLabels = resultLabels[:maxResults]
	}
	return resultLabels
}