Created loadLabels function

This commit is contained in:
Theresa Gresch 2019-07-17 11:53:33 +02:00
parent 73b16162ab
commit d63f4ec09f

View file

@ -120,25 +120,8 @@ func (t *TensorFlow) Labels(img []byte) (result Labels, err error) {
return result, nil
}
func (t *TensorFlow) loadModel() error {
if t.model != nil {
// Already loaded
return nil
}
savedModel := t.conf.TensorFlowModelPath()
modelLabels := savedModel + "/labels.txt"
log.Infof("loading image classification model from \"%s\"", savedModel)
// Load model
model, err := tf.LoadSavedModel(savedModel, []string{"photoprism"}, nil)
if err != nil {
return err
}
t.model = model
func (t *TensorFlow) loadLabels(path string) error {
modelLabels := path + "/labels.txt"
log.Infof("loading classification labels from \"%s\"", modelLabels)
@ -165,6 +148,28 @@ func (t *TensorFlow) loadModel() error {
return nil
}
func (t *TensorFlow) loadModel() error {
if t.model != nil {
// Already loaded
return nil
}
path := t.conf.TensorFlowModelPath()
log.Infof("loading image classification model from \"%s\"", path)
// Load model
model, err := tf.LoadSavedModel(path, []string{"photoprism"}, nil)
if err != nil {
return err
}
t.model = model
return t.loadLabels(path)
}
func (t *TensorFlow) labelRule(label string) LabelRule {
label = strings.ToLower(label)