138 lines
3.1 KiB
Go
138 lines
3.1 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"bytes"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"log"
|
||
|
"net/http"
|
||
|
"os"
|
||
|
"sort"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/julienschmidt/httprouter"
|
||
|
tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
||
|
)
|
||
|
|
||
|
type ClassifyResult struct {
|
||
|
Filename string `json:"filename"`
|
||
|
Labels []LabelResult `json:"labels"`
|
||
|
}
|
||
|
|
||
|
// LabelResult pairs a class label with the probability assigned to it.
type LabelResult struct {
	// Label is the human-readable class name.
	Label string `json:"label"`
	// Probability is the score reported for Label.
	Probability float32 `json:"probability"`
}
|
||
|
|
||
|
var (
|
||
|
graph *tf.Graph
|
||
|
labels []string
|
||
|
)
|
||
|
|
||
|
func main() {
|
||
|
if err := loadModel(); err != nil {
|
||
|
log.Fatal(err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
r := httprouter.New()
|
||
|
r.POST("/recognize", recognizeHandler)
|
||
|
log.Fatal(http.ListenAndServe(":8080", r))
|
||
|
}
|
||
|
|
||
|
func loadModel() error {
|
||
|
// Load inception model
|
||
|
model, err := ioutil.ReadFile("/model/tensorflow_inception_graph.pb")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
graph = tf.NewGraph()
|
||
|
if err := graph.Import(model, ""); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
// Load labels
|
||
|
labelsFile, err := os.Open("/model/imagenet_comp_graph_label_strings.txt")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer labelsFile.Close()
|
||
|
scanner := bufio.NewScanner(labelsFile)
|
||
|
// Labels are separated by newlines
|
||
|
for scanner.Scan() {
|
||
|
labels = append(labels, scanner.Text())
|
||
|
}
|
||
|
if err := scanner.Err(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func recognizeHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
||
|
// Read image
|
||
|
imageFile, header, err := r.FormFile("image")
|
||
|
// Will contain filename and extension
|
||
|
imageName := strings.Split(header.Filename, ".")
|
||
|
if err != nil {
|
||
|
responseError(w, "Could not read image", http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
defer imageFile.Close()
|
||
|
var imageBuffer bytes.Buffer
|
||
|
// Copy image data to a buffer
|
||
|
io.Copy(&imageBuffer, imageFile)
|
||
|
|
||
|
// ...
|
||
|
// Make tensor
|
||
|
tensor, err := makeTensorFromImage(&imageBuffer, imageName[:1][0])
|
||
|
if err != nil {
|
||
|
responseError(w, "Invalid image", http.StatusBadRequest)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Run inference
|
||
|
session, err := tf.NewSession(graph, nil)
|
||
|
if err != nil {
|
||
|
log.Fatal(err)
|
||
|
}
|
||
|
defer session.Close()
|
||
|
output, err := session.Run(
|
||
|
map[tf.Output]*tf.Tensor{
|
||
|
graph.Operation("input").Output(0): tensor,
|
||
|
},
|
||
|
[]tf.Output{
|
||
|
graph.Operation("output").Output(0),
|
||
|
},
|
||
|
nil)
|
||
|
if err != nil {
|
||
|
responseError(w, "Could not run inference", http.StatusInternalServerError)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Return best labels
|
||
|
responseJSON(w, ClassifyResult{
|
||
|
Filename: header.Filename,
|
||
|
Labels: findBestLabels(output[0].Value().([][]float32)[0]),
|
||
|
})
|
||
|
}
|
||
|
|
||
|
type ByProbability []LabelResult
|
||
|
|
||
|
func (a ByProbability) Len() int { return len(a) }
|
||
|
func (a ByProbability) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||
|
func (a ByProbability) Less(i, j int) bool { return a[i].Probability > a[j].Probability }
|
||
|
|
||
|
func findBestLabels(probabilities []float32) []LabelResult {
|
||
|
// Make a list of label/probability pairs
|
||
|
var resultLabels []LabelResult
|
||
|
for i, p := range probabilities {
|
||
|
if i >= len(labels) {
|
||
|
break
|
||
|
}
|
||
|
resultLabels = append(resultLabels, LabelResult{Label: labels[i], Probability: p})
|
||
|
}
|
||
|
// Sort by probability
|
||
|
sort.Sort(ByProbability(resultLabels))
|
||
|
// Return top 5 labels
|
||
|
return resultLabels[:5]
|
||
|
}
|