Use label rules to optimize image classification

This commit is contained in:
Michael Mayer 2019-05-16 08:41:16 +02:00
parent 7eeab25ae1
commit 8124a8cde1
18 changed files with 279 additions and 56 deletions

View file

@ -7,7 +7,7 @@
/assets/server/public/build/*
/assets/testdata
/assets/backups
/assets/tensorflow
/assets/tensorflow/nasnet
Dockerfile
/photoprism
docker-compose*

2
.gitignore vendored
View file

@ -9,7 +9,7 @@
/frontend/tests/result.html
/assets/testdata
/assets/backups
/assets/tensorflow
/assets/tensorflow/nasnet
*.log
# Binaries for programs and plugins

View file

@ -1,4 +1,4 @@
FROM photoprism/development:20190509
FROM photoprism/development:20190516
# Set up project directory
WORKDIR "/go/src/github.com/photoprism/photoprism"

View file

@ -34,6 +34,8 @@ dep-go:
go build -v ./...
dep-tensorflow:
scripts/download-nasnet.sh
zip-nasnet:
(cd assets/tensorflow && zip -r nasnet.zip nasnet -x "*/.*" -x "*/version.txt")
build-js:
(cd frontend && env NODE_ENV=production npm run build)
build-go:

102
assets/tensorflow/rules.yml Normal file
View file

@ -0,0 +1,102 @@
# Label rules for the image classifier.
#
# Each top-level key is a label name; it is looked up using the lowercased
# label text reported by the model. An entry may set the following keys:
#
#   see:       use the rule of the referenced label instead of this one
#   tag:       text that replaces the label name in the result
#   threshold: probabilities below this value are dropped
#   priority:  labels with a higher priority are sorted first
#   synonyms:  additional terms attached to the resulting label
#
# Labels that have no entry in this file get a default threshold of 0.08.
tabby cat:
see: cat
tiger cat:
tag: tiger cat
priority: 3
synonyms:
- pussy
- animal
- miau
analog clock:
priority: 2
synonyms:
- chronograph
tiger:
priority: 3
synonyms:
- cat
- animal
cat:
tag: tabby cat
priority: 3
synonyms:
- pussy
- animal
- miau
racer:
tag: race car
priority: 1
seashore:
priority: 1
synonyms:
- beach
- waterfront
- coast
- water
lakeside:
synonyms:
- water
- waterfront
cardoon:
tag: flower
window shade:
tag: house front
synonyms:
- window
- building
banded gecko:
tag: animal
daisy:
tag: flower
stole:
threshold: 0.2
quilt:
threshold: 0.2
liner:
tag: ocean liner
synonyms:
- ship
- boat
solar dish:
threshold: 0.5
grasshopper:
priority: 1
velvet:
tag: velvet flower
hair slide:
tag: jewelry
threshold: 0.4
shower curtain:
tag: bathroom
threshold: 0.15
windsor tie:
tag: people
chainlink fence:
tag: fence
mitten:
tag: unknown
bubble:
tag: bubbles

View file

@ -2,6 +2,8 @@ FROM ubuntu:18.04
LABEL maintainer="Michael Mayer <michael@liquidbytes.net>"
ARG BUILD_TAG
ENV DEBIAN_FRONTEND noninteractive
# Configure apt-get
@ -92,12 +94,10 @@ ENV PATH $GOBIN:/usr/local/go/bin:$PATH
ENV GO111MODULE on
RUN mkdir -p "$GOPATH/src" "$GOPATH/bin" && chmod -R 777 "$GOPATH"
ENV BUILD_DATE `date -u +%Y%m%d`
# Download TensorFlow model and test files
RUN rm -rf /tmp/* && mkdir -p /tmp/photoprism
RUN wget "https://dl.photoprism.org/tensorflow/nasnet.zip?$BUILD_DATE" -O /tmp/photoprism/nasnet.zip
RUN wget "https://dl.photoprism.org/fixtures/testdata.zip?$BUILD_DATE" -O /tmp/photoprism/testdata.zip
RUN wget "https://dl.photoprism.org/tensorflow/nasnet.zip?${BUILD_TAG}" -O /tmp/photoprism/nasnet.zip
RUN wget "https://dl.photoprism.org/fixtures/testdata.zip?${BUILD_TAG}" -O /tmp/photoprism/testdata.zip
# Install goimports
RUN env GO111MODULE=off /usr/local/go/bin/go get golang.org/x/tools/cmd/goimports

View file

@ -1,4 +1,4 @@
FROM photoprism/development:20190509 as build
FROM photoprism/development:20190516 as build
# Set up project directory
WORKDIR "/go/src/github.com/photoprism/photoprism"

View file

@ -1,4 +1,4 @@
FROM photoprism/development:20190509
FROM photoprism/development:20190516
# Install Python and TensorFlow
RUN apt-get update && apt-get install -y --no-install-recommends \

View file

@ -158,7 +158,7 @@
<td>{{ props.item.TakenAt | moment('DD/MM/YYYY hh:mm:ss') }}</td>
<td>{{ props.item.LocCity }}</td>
<td>{{ props.item.LocCountry }}</td>
<td>{{ props.item.CameraModel }}</td>
<td>{{ props.item.CameraMake }} {{ props.item.CameraModel }}</td>
<td>{{ props.item.PhotoFavorite ? 'Yes' : 'No' }}</td>
</template>
</v-data-table>

View file

@ -5,6 +5,7 @@ import (
"strconv"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/util"
log "github.com/sirupsen/logrus"
"github.com/gin-gonic/gin"
@ -30,21 +31,24 @@ import (
func GetPhotos(router *gin.RouterGroup, conf *config.Config) {
router.GET("/photos", func(c *gin.Context) {
var form forms.PhotoSearchForm
search := photoprism.NewSearch(conf.OriginalsPath(), conf.Db())
err := c.MustBindWith(&form, binding.Form)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": util.UcFirst(err.Error())})
return
}
result, err := search.Photos(form)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
c.AbortWithStatusJSON(400, gin.H{"error": util.UcFirst(err.Error())})
return
}
c.Header("x-result-count", strconv.Itoa(form.Count))
c.Header("x-result-offset", strconv.Itoa(form.Offset))
c.JSON(http.StatusOK, result)
})
}
@ -57,6 +61,7 @@ func LikePhoto(router *gin.RouterGroup, conf *config.Config) {
router.POST("/photos/:id/like", func(c *gin.Context) {
search := photoprism.NewSearch(conf.OriginalsPath(), conf.Db())
photoID, err := strconv.ParseUint(c.Param("id"), 10, 64)
if err != nil {
log.Errorf("could not find image for id: %s", err.Error())
c.Data(http.StatusNotFound, "image", []byte(""))
@ -66,12 +71,13 @@ func LikePhoto(router *gin.RouterGroup, conf *config.Config) {
photo, err := search.FindPhotoByID(photoID)
if err != nil {
c.AbortWithStatusJSON(404, gin.H{"error": err.Error()})
c.AbortWithStatusJSON(404, gin.H{"error": util.UcFirst(err.Error())})
return
}
photo.PhotoFavorite = true
conf.Db().Save(&photo)
c.JSON(http.StatusOK, http.Response{})
})
}
@ -84,6 +90,7 @@ func DislikePhoto(router *gin.RouterGroup, conf *config.Config) {
router.DELETE("/photos/:id/like", func(c *gin.Context) {
search := photoprism.NewSearch(conf.OriginalsPath(), conf.Db())
id, err := strconv.ParseUint(c.Param("id"), 10, 64)
if err != nil {
log.Errorf("could not find image for id: %s", err.Error())
c.Data(http.StatusNotFound, "image", []byte(""))
@ -93,12 +100,13 @@ func DislikePhoto(router *gin.RouterGroup, conf *config.Config) {
photo, err := search.FindPhotoByID(id)
if err != nil {
c.AbortWithStatusJSON(404, gin.H{"error": err.Error()})
c.AbortWithStatusJSON(404, gin.H{"error": util.UcFirst(err.Error())})
return
}
photo.PhotoFavorite = false
conf.Db().Save(&photo)
c.JSON(http.StatusOK, http.Response{})
})
}

View file

@ -102,10 +102,10 @@ func (f *PhotoSearchForm) ParseQueryString() (result error) {
result = fmt.Errorf("not a bool value: %s", fieldName)
}
default:
result = fmt.Errorf("unsupported field type: %s", fieldName)
result = fmt.Errorf("unsupported type: %s", fieldName)
}
} else {
result = fmt.Errorf("unknown form field: %s", fieldName)
result = fmt.Errorf("unknown filter: %s", fieldName)
}
} else {
f.Query = string(bytes.ToLower(key))

View file

@ -2,6 +2,7 @@ package models
import (
"fmt"
"strings"
"github.com/gosimple/slug"
"github.com/jinzhu/gorm"
@ -20,16 +21,20 @@ type Camera struct {
}
func NewCamera(modelName string, makeName string) *Camera {
makeName = strings.TrimSpace(makeName)
if modelName == "" {
modelName = "Unknown"
} else if strings.HasPrefix(modelName, makeName) {
modelName = strings.TrimSpace(modelName[len(makeName):])
}
var cameraSlug string
if makeName != "" {
cameraSlug = slug.MakeLang(makeName+" "+modelName, "en")
cameraSlug = slug.Make(makeName + " " + modelName)
} else {
cameraSlug = slug.MakeLang(modelName, "en")
cameraSlug = slug.Make(modelName)
}
result := &Camera{

View file

@ -17,7 +17,7 @@ func (PhotoTag) TableName() string {
}
func (t *PhotoTag) FirstOrCreate(db *gorm.DB) *PhotoTag {
db.FirstOrCreate(t, "photo_id = ? AND tags_id = ?", t.PhotoID, t.TagID)
db.FirstOrCreate(t, "photo_id = ? AND tag_id = ?", t.PhotoID, t.TagID)
return t
}

View file

@ -4,14 +4,14 @@ import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"time"
"github.com/photoprism/photoprism/internal/config"
log "github.com/sirupsen/logrus"
"github.com/jinzhu/gorm"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/models"
log "github.com/sirupsen/logrus"
)
const (
@ -49,6 +49,8 @@ func (i *Indexer) thumbnailsPath() string {
// getImageTags returns all tags of a given mediafile. This function returns
// an empty list in the case of an error.
func (i *Indexer) getImageTags(jpeg *MediaFile) (results []*models.Tag) {
start := time.Now()
var thumbs []string
if jpeg.AspectRatio() == 1 {
@ -57,7 +59,7 @@ func (i *Indexer) getImageTags(jpeg *MediaFile) (results []*models.Tag) {
thumbs = []string{"tile_224", "left_224", "right_224"}
}
tagExists := make(map[string]bool)
var allLabels TensorFlowLabels
for _, thumb := range thumbs {
filename, err := jpeg.Thumbnail(i.thumbnailsPath(), thumb)
@ -67,23 +69,35 @@ func (i *Indexer) getImageTags(jpeg *MediaFile) (results []*models.Tag) {
continue
}
tags, err := i.tensorFlow.GetImageTagsFromFile(filename)
labels, err := i.tensorFlow.GetImageTagsFromFile(filename)
if err != nil {
log.Error(err)
continue
}
for _, tag := range tags {
if tag.Probability > 0.15 { // TODO: Use config variable
if _, ok := tagExists[tag.Label]; !ok {
results = i.appendTag(results, tag.Label)
tagExists[tag.Label] = true
}
}
allLabels = append(allLabels, labels...)
}
// Sort by priority first, then by probability (both descending)
sort.Sort(TensorFlowLabels(allLabels))
var max float32 = -1
for _, l := range allLabels {
if max == -1 {
max = l.Probability
}
if l.Probability > (max / 3) {
results = i.appendTag(results, l.Label)
}
}
elapsed := time.Since(start)
log.Infof("finding %+v labels for %s took %s", allLabels, jpeg.Filename(), elapsed)
return results
}
@ -142,15 +156,15 @@ func (i *Indexer) indexMediaFile(mediaFile *MediaFile) string {
tags = i.appendTag(tags, location.LocName)
tags = i.appendTag(tags, location.LocType)
if photo.PhotoTitle == "" && location.LocName != "" { // TODO: User defined title format
if photo.PhotoTitle == "" && location.LocName != "" && location.LocCity != "" { // TODO: User defined title format
if len(location.LocName) > 40 {
photo.PhotoTitle = fmt.Sprintf("%s / %s", strings.Title(location.LocName), mediaFile.DateCreated().Format("2006"))
} else {
photo.PhotoTitle = fmt.Sprintf("%s / %s / %s", strings.Title(location.LocName), location.LocCity, mediaFile.DateCreated().Format("2006"))
}
} else if photo.PhotoTitle == "" && location.LocCity != "" {
} else if photo.PhotoTitle == "" && location.LocCity != "" && location.LocCountry != "" {
photo.PhotoTitle = fmt.Sprintf("%s / %s / %s", location.LocCity, location.LocCountry, mediaFile.DateCreated().Format("2006"))
} else if photo.PhotoTitle == "" && location.LocCounty != "" {
} else if photo.PhotoTitle == "" && location.LocCounty != "" && location.LocCountry != "" {
photo.PhotoTitle = fmt.Sprintf("%s / %s / %s", location.LocCounty, location.LocCountry, mediaFile.DateCreated().Format("2006"))
}
} else {
@ -180,8 +194,6 @@ func (i *Indexer) indexMediaFile(mediaFile *MediaFile) string {
if photo.PhotoTitle == "" {
if len(photo.Tags) > 0 { // TODO: User defined title format
photo.PhotoTitle = fmt.Sprintf("%s / %s", strings.Title(photo.Tags[0].TagLabel), mediaFile.DateCreated().Format("2006"))
} else if photo.Country != nil && photo.Country.CountryName != "" {
photo.PhotoTitle = fmt.Sprintf("%s / %s", strings.Title(photo.Country.CountryName), mediaFile.DateCreated().Format("2006"))
} else if photo.Camera.String() != "" && photo.Camera.String() != "Unknown" {
photo.PhotoTitle = fmt.Sprintf("%s / %s", photo.Camera, mediaFile.DateCreated().Format("January 2006"))
} else {

View file

@ -4,45 +4,96 @@ import (
"bufio"
"bytes"
"errors"
"fmt"
"image"
"io/ioutil"
"math"
"os"
"sort"
"strings"
"github.com/disintegration/imaging"
"github.com/photoprism/photoprism/internal/util"
log "github.com/sirupsen/logrus"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
"gopkg.in/yaml.v2"
)
// TensorFlow is a tensorflow wrapper given a graph, labels and a modelPath.
type TensorFlow struct {
modelPath string
model *tf.SavedModel
labels []string
}
// NewTensorFlow returns a new TensorFlow.
func NewTensorFlow(tensorFlowModelPath string) *TensorFlow {
return &TensorFlow{modelPath: tensorFlowModelPath}
modelPath string
model *tf.SavedModel
labels []string
labelRules LabelRules
}
// TensorFlowLabel defines a JSON struct with label and probability.
type TensorFlowLabel struct {
Label string `json:"label"`
Probability float32 `json:"probability"`
Synonyms []string
Priority int
}
type LabelRule struct {
Tag string
See string
Threshold float32
Synonyms []string
Priority int
}
type LabelRules map[string]LabelRule
// NewTensorFlow returns a new TensorFlow.
func NewTensorFlow(tensorFlowModelPath string) *TensorFlow {
return &TensorFlow{modelPath: tensorFlowModelPath}
}
// Percent returns the label probability as a whole-number percentage,
// rounded to the nearest integer.
func (a *TensorFlowLabel) Percent() int {
	scaled := a.Probability * 100

	return int(math.Round(float64(scaled)))
}
// loadLabelRules reads "rules.yml" from the model path and parses it into
// t.labelRules. It does nothing when rules are already present. An error is
// logged and returned when the file is missing or cannot be read; a YAML
// parsing error is returned as is.
func (t *TensorFlow) loadLabelRules() (err error) {
	// A populated map means the rules were loaded by an earlier call.
	if len(t.labelRules) != 0 {
		return nil
	}

	t.labelRules = LabelRules{}

	rulesFile := t.modelPath + "/rules.yml"

	log.Debugf("loading label rules from \"%s\"", rulesFile)

	if found := util.Exists(rulesFile); !found {
		log.Errorf("label rules file not found: \"%s\"", rulesFile)
		return fmt.Errorf("label rules file not found: \"%s\"", rulesFile)
	}

	content, readErr := ioutil.ReadFile(rulesFile)

	if readErr != nil {
		log.Error(readErr)
		return readErr
	}

	return yaml.Unmarshal(content, t.labelRules)
}
// TensorFlowLabels is a slice of tensorflow labels.
type TensorFlowLabels []TensorFlowLabel
func (a TensorFlowLabels) Len() int { return len(a) }
func (a TensorFlowLabels) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a TensorFlowLabels) Less(i, j int) bool { return a[i].Probability > a[j].Probability }
func (a TensorFlowLabels) Len() int { return len(a) }
func (a TensorFlowLabels) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a TensorFlowLabels) Less(i, j int) bool {
if a[i].Priority == a[j].Priority {
return a[i].Probability > a[j].Probability
} else {
return a[i].Priority > a[j].Priority
}
}
// GetImageTagsFromFile returns tags for a jpeg image file.
func (t *TensorFlow) GetImageTagsFromFile(filename string) (result []TensorFlowLabel, err error) {
@ -139,9 +190,30 @@ func (t *TensorFlow) loadModel() error {
return nil
}
// labelRule returns the rule that applies to the given label.
//
// If the rule for a label has a "see" reference, the referenced rule is used
// instead; references are followed until a rule without "see" is found.
// When no rule exists for the label (or for a referenced label), a default
// rule with a threshold of 0.08 is returned. The same default is returned,
// and an error is logged, when the "see" references form a cycle.
func (t *TensorFlow) labelRule(label string) LabelRule {
	if err := t.loadLabelRules(); err != nil {
		log.Error(err)
	}

	current := label

	// A chain of references without a cycle can visit each rule at most once,
	// plus one final lookup of a missing label. Limiting the number of hops
	// therefore guards against cyclic references, which would otherwise
	// never terminate.
	for hops := 0; hops <= len(t.labelRules); hops++ {
		rule, ok := t.labelRules[current]

		if !ok {
			return LabelRule{Threshold: 0.08}
		}

		if rule.See == "" {
			return rule
		}

		current = rule.See
	}

	log.Errorf("label rule for \"%s\" has a cyclic \"see\" reference", label)

	return LabelRule{Threshold: 0.08}
}
func (t *TensorFlow) findBestLabels(probabilities []float32) []TensorFlowLabel {
if err := t.loadLabelRules(); err != nil {
log.Error(err)
}
// Make a list of label/probability pairs
var result []TensorFlowLabel
for i, p := range probabilities {
if i >= len(t.labels) {
break
@ -151,18 +223,28 @@ func (t *TensorFlow) findBestLabels(probabilities []float32) []TensorFlowLabel {
continue
}
result = append(result, TensorFlowLabel{Label: t.labels[i], Probability: p})
labelText := strings.ToLower(t.labels[i])
rule := t.labelRule(labelText)
if p < rule.Threshold {
continue
}
if rule.Tag != "" {
labelText = rule.Tag
}
result = append(result, TensorFlowLabel{Label: labelText, Probability: p, Synonyms: rule.Synonyms, Priority: rule.Priority})
}
// Sort by priority first, then by probability (both descending)
sort.Sort(TensorFlowLabels(result))
l := len(result)
if l >= 5 {
return result[:5]
} else {
if l := len(result); l < 5 {
return result[:l]
} else {
return result[:5]
}
}

12
internal/util/strings.go Normal file
View file

@ -0,0 +1,12 @@
package util
import (
	"unicode"
	"unicode/utf8"
)
// UcFirst returns str with its first character converted to upper case.
//
// The first character is decoded as a complete UTF-8 rune, so multi-byte
// characters such as "é" are handled correctly. An empty string is returned
// unchanged. A leading byte that is not valid UTF-8 is replaced by U+FFFD.
func UcFirst(str string) string {
	if str == "" {
		return ""
	}

	// size is the byte width of the first rune; the remainder must start
	// after all of its bytes, not after a fixed offset of one byte.
	first, size := utf8.DecodeRuneInString(str)

	return string(unicode.ToUpper(first)) + str[size:]
}

View file

@ -5,6 +5,6 @@ if [[ -z $1 ]] || [[ -z $2 ]]; then
exit 1
else
echo "Building 'photoprism/$1:$2'...";
docker build -t photoprism/$1:latest -t photoprism/$1:$2 -f docker/$1/Dockerfile .
docker build --build-arg BUILD_TAG=$2 -t photoprism/$1:latest -t photoprism/$1:$2 -f docker/$1/Dockerfile .
echo "Done"
fi
fi

View file

@ -6,7 +6,7 @@ MODEL_NAME="NASNet Mobile"
MODEL_URL="https://dl.photoprism.org/tensorflow/nasnet.zip?$TODAY"
MODEL_PATH="assets/tensorflow/nasnet"
MODEL_ZIP="/tmp/photoprism/nasnet.zip"
MODEL_HASH="6a9450f89afa56b4539c0d7188f108f083c10fc9 $MODEL_ZIP"
MODEL_HASH="cb893eaa93d59eca9e63ab10f76ae60519ecee24 $MODEL_ZIP"
MODEL_VERSION="$MODEL_PATH/version.txt"
MODEL_BACKUP="assets/backups/nasnet-$TODAY"