Backend: Set NSFW flag while indexing
Signed-off-by: Michael Mayer <michael@liquidbytes.net>
This commit is contained in:
parent
78eae2f14e
commit
8cce9f7c8c
16 changed files with 75 additions and 27 deletions
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/event"
|
||||
"github.com/photoprism/photoprism/internal/form"
|
||||
"github.com/photoprism/photoprism/internal/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/photoprism"
|
||||
"github.com/photoprism/photoprism/internal/util"
|
||||
)
|
||||
|
@ -22,8 +23,9 @@ func initIndexer(conf *config.Config) {
|
|||
}
|
||||
|
||||
tensorFlow := photoprism.NewTensorFlow(conf)
|
||||
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
|
||||
|
||||
indexer = photoprism.NewIndexer(conf, tensorFlow)
|
||||
indexer = photoprism.NewIndexer(conf, tensorFlow, nsfwDetector)
|
||||
}
|
||||
|
||||
// POST /api/v1/index
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/photoprism"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
@ -40,8 +41,9 @@ func importAction(ctx *cli.Context) error {
|
|||
log.Infof("importing photos from %s", conf.ImportPath())
|
||||
|
||||
tensorFlow := photoprism.NewTensorFlow(conf)
|
||||
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
|
||||
|
||||
indexer := photoprism.NewIndexer(conf, tensorFlow)
|
||||
indexer := photoprism.NewIndexer(conf, tensorFlow, nsfwDetector)
|
||||
|
||||
converter := photoprism.NewConverter(conf)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/photoprism"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
@ -39,8 +40,9 @@ func indexAction(ctx *cli.Context) error {
|
|||
}
|
||||
|
||||
tensorFlow := photoprism.NewTensorFlow(conf)
|
||||
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
|
||||
|
||||
indexer := photoprism.NewIndexer(conf, tensorFlow)
|
||||
indexer := photoprism.NewIndexer(conf, tensorFlow, nsfwDetector)
|
||||
|
||||
options := photoprism.IndexerOptionsAll()
|
||||
files := indexer.IndexOriginals(options)
|
||||
|
|
|
@ -473,6 +473,11 @@ func (c *Config) TensorFlowModelPath() string {
|
|||
return c.ResourcesPath() + "/nasnet"
|
||||
}
|
||||
|
||||
// NSFWModelPath returns the NSFW tensorflow model path.
|
||||
func (c *Config) NSFWModelPath() string {
|
||||
return c.ResourcesPath() + "/nsfw"
|
||||
}
|
||||
|
||||
// HttpTemplatesPath returns the server templates path.
|
||||
func (c *Config) HttpTemplatesPath() string {
|
||||
return c.ResourcesPath() + "/templates"
|
||||
|
|
1
internal/nsfw/.gitignore
vendored
Normal file
1
internal/nsfw/.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
testdata/porn*
|
|
@ -27,21 +27,19 @@ func (l *Labels) IsSafe() bool {
|
|||
}
|
||||
|
||||
func (l *Labels) NSFW() bool {
|
||||
if l.Neutral > 0.25 && l.Porn < 0.75 {
|
||||
if l.Neutral > 0.25 {
|
||||
return false
|
||||
}
|
||||
if l.Porn > 0.4 {
|
||||
|
||||
if l.Porn > 0.75 {
|
||||
return true
|
||||
}
|
||||
if l.Sexy > 0.5 {
|
||||
if l.Sexy > 0.75 {
|
||||
return true
|
||||
}
|
||||
if l.Hentai > 0.75 {
|
||||
return true
|
||||
}
|
||||
if l.Drawing > 0.9 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -86,10 +86,11 @@ func TestNSFW(t *testing.T) {
|
|||
assert.GreaterOrEqual(t, l.Sexy, e.Sexy)
|
||||
}
|
||||
|
||||
isNSFW := strings.Contains(basename, "porn") || strings.Contains(basename, "hentai")
|
||||
isSafe := !(strings.Contains(basename, "porn") || strings.Contains(basename, "hentai"))
|
||||
|
||||
assert.Equal(t, isNSFW, l.NSFW())
|
||||
assert.Equal(t, !isNSFW, l.IsSafe())
|
||||
if isSafe {
|
||||
assert.True(t, l.IsSafe())
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
|
|
BIN
internal/nsfw/testdata/architecture.jpg
vendored
Normal file
BIN
internal/nsfw/testdata/architecture.jpg
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 171 KiB |
BIN
internal/nsfw/testdata/art.jpg
vendored
Normal file
BIN
internal/nsfw/testdata/art.jpg
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 118 KiB |
BIN
internal/nsfw/testdata/museum.jpg
vendored
Normal file
BIN
internal/nsfw/testdata/museum.jpg
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 135 KiB |
BIN
internal/nsfw/testdata/san-francisco.jpg
vendored
Normal file
BIN
internal/nsfw/testdata/san-francisco.jpg
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 161 KiB |
|
@ -4,6 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/nsfw"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -11,8 +12,9 @@ func TestNewImporter(t *testing.T) {
|
|||
conf := config.TestConfig()
|
||||
|
||||
tensorFlow := NewTensorFlow(conf)
|
||||
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
|
||||
|
||||
indexer := NewIndexer(conf, tensorFlow)
|
||||
indexer := NewIndexer(conf, tensorFlow, nsfwDetector)
|
||||
|
||||
converter := NewConverter(conf)
|
||||
|
||||
|
@ -27,8 +29,9 @@ func TestImporter_DestinationFilename(t *testing.T) {
|
|||
conf.InitializeTestData(t)
|
||||
|
||||
tensorFlow := NewTensorFlow(conf)
|
||||
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
|
||||
|
||||
indexer := NewIndexer(conf, tensorFlow)
|
||||
indexer := NewIndexer(conf, tensorFlow, nsfwDetector)
|
||||
|
||||
converter := NewConverter(conf)
|
||||
|
||||
|
@ -55,8 +58,9 @@ func TestImporter_ImportPhotosFromDirectory(t *testing.T) {
|
|||
conf.InitializeTestData(t)
|
||||
|
||||
tensorFlow := NewTensorFlow(conf)
|
||||
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
|
||||
|
||||
indexer := NewIndexer(conf, tensorFlow)
|
||||
indexer := NewIndexer(conf, tensorFlow, nsfwDetector)
|
||||
|
||||
converter := NewConverter(conf)
|
||||
|
||||
|
|
|
@ -7,22 +7,25 @@ import (
|
|||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/nsfw"
|
||||
)
|
||||
|
||||
// Indexer defines an indexer with originals path tensorflow and a db.
|
||||
type Indexer struct {
|
||||
conf *config.Config
|
||||
tensorFlow *TensorFlow
|
||||
db *gorm.DB
|
||||
conf *config.Config
|
||||
tensorFlow *TensorFlow
|
||||
nsfwDetector *nsfw.Detector
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewIndexer returns a new indexer.
|
||||
// TODO: Is it really necessary to return a pointer?
|
||||
func NewIndexer(conf *config.Config, tensorFlow *TensorFlow) *Indexer {
|
||||
func NewIndexer(conf *config.Config, tensorFlow *TensorFlow, nsfwDetector *nsfw.Detector) *Indexer {
|
||||
i := &Indexer{
|
||||
conf: conf,
|
||||
tensorFlow: tensorFlow,
|
||||
db: conf.Db(),
|
||||
conf: conf,
|
||||
tensorFlow: tensorFlow,
|
||||
nsfwDetector: nsfwDetector,
|
||||
db: conf.Db(),
|
||||
}
|
||||
|
||||
return i
|
||||
|
|
|
@ -2,6 +2,7 @@ package photoprism
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
@ -29,6 +30,7 @@ func (i *Indexer) indexMediaFile(m *MediaFile, o IndexerOptions) IndexResult {
|
|||
var exifData *Exif
|
||||
var photoQuery, fileQuery *gorm.DB
|
||||
var keywords []string
|
||||
var isNSFW bool
|
||||
|
||||
labels := Labels{}
|
||||
fileBase := m.Basename()
|
||||
|
@ -86,7 +88,8 @@ func (i *Indexer) indexMediaFile(m *MediaFile, o IndexerOptions) IndexResult {
|
|||
if file.FilePrimary {
|
||||
if fileChanged || o.UpdateKeywords || o.UpdateLabels || o.UpdateTitle {
|
||||
// Image classification labels
|
||||
labels = i.classifyImage(m)
|
||||
labels, isNSFW = i.classifyImage(m)
|
||||
photo.PhotoNSFW = isNSFW
|
||||
}
|
||||
|
||||
if fileChanged || o.UpdateExif {
|
||||
|
@ -225,7 +228,7 @@ func (i *Indexer) indexMediaFile(m *MediaFile, o IndexerOptions) IndexResult {
|
|||
}
|
||||
|
||||
// classifyImage returns all matching labels for a media file.
|
||||
func (i *Indexer) classifyImage(jpeg *MediaFile) (results Labels) {
|
||||
func (i *Indexer) classifyImage(jpeg *MediaFile) (results Labels, isNSFW bool) {
|
||||
start := time.Now()
|
||||
|
||||
var thumbs []string
|
||||
|
@ -256,6 +259,25 @@ func (i *Indexer) classifyImage(jpeg *MediaFile) (results Labels) {
|
|||
labels = append(labels, imageLabels...)
|
||||
}
|
||||
|
||||
if filename, err := jpeg.Thumbnail(i.thumbnailsPath(), "fit_720"); err != nil {
|
||||
log.Error(err)
|
||||
} else {
|
||||
if nsfwLabels, err := i.nsfwDetector.LabelsFromFile(filename); err != nil {
|
||||
log.Error(err)
|
||||
} else {
|
||||
log.Infof("nsfw: %+v", nsfwLabels)
|
||||
|
||||
if nsfwLabels.NSFW() {
|
||||
isNSFW = true
|
||||
}
|
||||
|
||||
if nsfwLabels.Sexy > 0.2 {
|
||||
uncertainty := 100 - int(math.Round(float64(nsfwLabels.Sexy*100)))
|
||||
labels = append(labels, Label{Name: "sexy", Source: "nsfw", Uncertainty: uncertainty, Priority: -1})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by priority and uncertainty
|
||||
sort.Sort(labels)
|
||||
|
||||
|
@ -271,11 +293,15 @@ func (i *Indexer) classifyImage(jpeg *MediaFile) (results Labels) {
|
|||
}
|
||||
}
|
||||
|
||||
if isNSFW {
|
||||
log.Info("index: image might contain sexually explicit content")
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
|
||||
log.Debugf("index: image classification took %s", elapsed)
|
||||
|
||||
return results
|
||||
return results, isNSFW
|
||||
}
|
||||
|
||||
func (i *Indexer) addLabels(photoId uint, labels Labels) {
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/nsfw"
|
||||
)
|
||||
|
||||
func TestIndexer_IndexAll(t *testing.T) {
|
||||
|
@ -16,8 +17,9 @@ func TestIndexer_IndexAll(t *testing.T) {
|
|||
conf.InitializeTestData(t)
|
||||
|
||||
tensorFlow := NewTensorFlow(conf)
|
||||
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
|
||||
|
||||
indexer := NewIndexer(conf, tensorFlow)
|
||||
indexer := NewIndexer(conf, tensorFlow, nsfwDetector)
|
||||
|
||||
converter := NewConverter(conf)
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/disintegration/imaging"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/internal/nsfw"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -66,8 +67,9 @@ func TestThumbnails_CreateThumbnailsFromOriginals(t *testing.T) {
|
|||
conf.InitializeTestData(t)
|
||||
|
||||
tensorFlow := NewTensorFlow(conf)
|
||||
nsfwDetector := nsfw.NewDetector(conf.NSFWModelPath())
|
||||
|
||||
indexer := NewIndexer(conf, tensorFlow)
|
||||
indexer := NewIndexer(conf, tensorFlow, nsfwDetector)
|
||||
|
||||
converter := NewConverter(conf)
|
||||
|
||||
|
|
Loading…
Reference in a new issue