diff --git a/internal/photoprism/tensorflow.go b/internal/photoprism/tensorflow.go index 0417750eb..ef1f9524a 100644 --- a/internal/photoprism/tensorflow.go +++ b/internal/photoprism/tensorflow.go @@ -120,25 +120,8 @@ func (t *TensorFlow) Labels(img []byte) (result Labels, err error) { return result, nil } -func (t *TensorFlow) loadModel() error { - if t.model != nil { - // Already loaded - return nil - } - - savedModel := t.conf.TensorFlowModelPath() - modelLabels := savedModel + "/labels.txt" - - log.Infof("loading image classification model from \"%s\"", savedModel) - - // Load model - model, err := tf.LoadSavedModel(savedModel, []string{"photoprism"}, nil) - - if err != nil { - return err - } - - t.model = model +func (t *TensorFlow) loadLabels(path string) error { + modelLabels := path + "/labels.txt" log.Infof("loading classification labels from \"%s\"", modelLabels) @@ -165,6 +148,28 @@ func (t *TensorFlow) loadModel() error { return nil } +func (t *TensorFlow) loadModel() error { + if t.model != nil { + // Already loaded + return nil + } + + path := t.conf.TensorFlowModelPath() + + log.Infof("loading image classification model from \"%s\"", path) + + // Load model + model, err := tf.LoadSavedModel(path, []string{"photoprism"}, nil) + + if err != nil { + return err + } + + t.model = model + + return t.loadLabels(path) +} + func (t *TensorFlow) labelRule(label string) LabelRule { label = strings.ToLower(label)