Created loadLabels function
This commit is contained in:
parent
73b16162ab
commit
d63f4ec09f
1 changed files with 24 additions and 19 deletions
|
@ -120,25 +120,8 @@ func (t *TensorFlow) Labels(img []byte) (result Labels, err error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (t *TensorFlow) loadModel() error {
|
||||
if t.model != nil {
|
||||
// Already loaded
|
||||
return nil
|
||||
}
|
||||
|
||||
savedModel := t.conf.TensorFlowModelPath()
|
||||
modelLabels := savedModel + "/labels.txt"
|
||||
|
||||
log.Infof("loading image classification model from \"%s\"", savedModel)
|
||||
|
||||
// Load model
|
||||
model, err := tf.LoadSavedModel(savedModel, []string{"photoprism"}, nil)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.model = model
|
||||
func (t *TensorFlow) loadLabels(path string) error {
|
||||
modelLabels := path + "/labels.txt"
|
||||
|
||||
log.Infof("loading classification labels from \"%s\"", modelLabels)
|
||||
|
||||
|
@ -165,6 +148,28 @@ func (t *TensorFlow) loadModel() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *TensorFlow) loadModel() error {
|
||||
if t.model != nil {
|
||||
// Already loaded
|
||||
return nil
|
||||
}
|
||||
|
||||
path := t.conf.TensorFlowModelPath()
|
||||
|
||||
log.Infof("loading image classification model from \"%s\"", path)
|
||||
|
||||
// Load model
|
||||
model, err := tf.LoadSavedModel(path, []string{"photoprism"}, nil)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.model = model
|
||||
|
||||
return t.loadLabels(path)
|
||||
}
|
||||
|
||||
func (t *TensorFlow) labelRule(label string) LabelRule {
|
||||
label = strings.ToLower(label)
|
||||
|
||||
|
|
Loading…
Reference in a new issue