Search: Make special character escaping compatible with SQLite #1994

This commit is contained in:
Michael Mayer 2022-03-28 17:36:59 +02:00
parent e693fad8dc
commit 9e46a66f24
7 changed files with 124 additions and 69 deletions

View File

@ -10,6 +10,11 @@ import (
"github.com/jinzhu/inflection"
)
// Like escapes a string for use in a query.
func Like(s string) string {
return strings.Trim(sanitize.SqlString(s), " |&*%")
}
// LikeAny returns a single where condition matching the search words.
func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
if s == "" {
@ -44,9 +49,9 @@ func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
for _, w := range words {
if wildcardThreshold > 0 && len(w) >= wildcardThreshold {
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", col, SqlLike(w)))
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", col, Like(w)))
} else {
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(w)))
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, Like(w)))
}
if !keywords || !txt.ContainsASCIILetters(w) {
@ -56,7 +61,7 @@ func LikeAny(col, s string, keywords, exact bool) (wheres []string) {
singular := inflection.Singular(w)
if singular != w {
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(singular)))
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s'", col, Like(singular)))
}
}
@ -103,9 +108,9 @@ func LikeAll(col, s string, keywords, exact bool) (wheres []string) {
for _, w := range words {
if wildcardThreshold > 0 && len(w) >= wildcardThreshold {
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s%%'", col, SqlLike(w)))
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s%%'", col, Like(w)))
} else {
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s'", col, SqlLike(w)))
wheres = append(wheres, fmt.Sprintf("%s LIKE '%s'", col, Like(w)))
}
}
@ -140,9 +145,9 @@ func LikeAllNames(cols Cols, s string) (wheres []string) {
for _, c := range cols {
if strings.Contains(w, txt.Space) {
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", c, SqlLike(w)))
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%s%%'", c, Like(w)))
} else {
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%%%s%%'", c, SqlLike(w)))
orWheres = append(orWheres, fmt.Sprintf("%s LIKE '%%%s%%'", c, Like(w)))
}
}
}
@ -189,7 +194,7 @@ func AnySlug(col, search, sep string) (where string) {
}
for _, w := range words {
wheres = append(wheres, fmt.Sprintf("%s = '%s'", col, SqlLike(w)))
wheres = append(wheres, fmt.Sprintf("%s = '%s'", col, Like(w)))
}
return strings.Join(wheres, " OR ")

View File

@ -12,6 +12,24 @@ import (
"github.com/stretchr/testify/assert"
)
func TestLike(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
assert.Equal(t, "", Like(""))
})
t.Run("Special", func(t *testing.T) {
s := " ' \" \t \n %_''"
exp := "'' \"\" %_''''"
result := Like(s)
t.Logf("String..: %s", s)
t.Logf("Expected: %s", exp)
t.Logf("Result..: %s", result)
assert.Equal(t, exp, result)
})
t.Run("Alnum", func(t *testing.T) {
assert.Equal(t, "123ABCabc", Like(" 123ABCabc%* "))
})
}
func TestLikeAny(t *testing.T) {
t.Run("and_or_search", func(t *testing.T) {
if w := LikeAny("k.keyword", "table spoon & usa | img json", true, false); len(w) != 2 {
@ -119,11 +137,11 @@ func TestLikeAnyWord(t *testing.T) {
}
})
t.Run("EscapeSql", func(t *testing.T) {
if w := LikeAnyWord("k.keyword", "table% | 'spoon' & \"usa"); len(w) != 2 {
if w := LikeAnyWord("k.keyword", "table% | 'spoon' & \"us'a"); len(w) != 2 {
t.Fatalf("two where conditions expected: %#v", w)
} else {
assert.Equal(t, "k.keyword LIKE 'spoon%' OR k.keyword LIKE 'table%'", w[0])
assert.Equal(t, "k.keyword LIKE '\\\"usa%'", w[1])
assert.Equal(t, "k.keyword LIKE '\"\"us''a%'", w[1])
}
})
}

View File

@ -99,8 +99,7 @@ func TestPhotosFilterAlbums(t *testing.T) {
}
assert.Equal(t, len(photos), 0)
})
// TODO should not throw error
/*t.Run("albums middle '", func(t *testing.T) {
t.Run("AlbumsSingleQuote", func(t *testing.T) {
var f form.SearchPhotos
f.Albums = "Father's Day"
@ -113,7 +112,7 @@ func TestPhotosFilterAlbums(t *testing.T) {
}
assert.Greater(t, len(photos), 0)
})*/
})
t.Run("albums end '", func(t *testing.T) {
var f form.SearchPhotos
@ -190,8 +189,8 @@ func TestPhotosFilterAlbums(t *testing.T) {
if err != nil {
t.Fatal(err)
}
// TODO: Needs review, variable number of results.
// TODO: Needs review, variable number of results.
assert.GreaterOrEqual(t, len(photos), 0)
})
t.Run("albums end |", func(t *testing.T) {
@ -340,8 +339,7 @@ func TestPhotosQueryAlbums(t *testing.T) {
}
assert.Equal(t, len(photos), 0)
})
//TODO should not throw error
/*t.Run("albums middle '", func(t *testing.T) {
t.Run("AlbumsQuerySingleQuote", func(t *testing.T) {
var f form.SearchPhotos
f.Query = "albums:\"Father's Day\""
@ -354,7 +352,7 @@ func TestPhotosQueryAlbums(t *testing.T) {
}
assert.Greater(t, len(photos), 0)
})*/
})
t.Run("albums end '", func(t *testing.T) {
var f form.SearchPhotos
@ -431,8 +429,8 @@ func TestPhotosQueryAlbums(t *testing.T) {
if err != nil {
t.Fatal(err)
}
// TODO: Needs review, variable number of results.
// TODO: Needs review, variable number of results.
assert.GreaterOrEqual(t, len(photos), 0)
})
t.Run("albums end |", func(t *testing.T) {

View File

@ -1,12 +0,0 @@
package search
import (
"strings"
"github.com/photoprism/photoprism/pkg/sanitize"
)
// SqlLike escapes a string for use in an SQL query.
func SqlLike(s string) string {
return strings.Trim(sanitize.SqlString(s), " |&*%")
}

View File

@ -1,25 +0,0 @@
package search
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSqlLike(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
assert.Equal(t, "", SqlLike(""))
})
t.Run("Special", func(t *testing.T) {
s := "' \" \t \n %_''"
exp := "\\' \\\" %\\_\\'\\'"
result := SqlLike(s)
t.Logf("String..: %s", s)
t.Logf("Expected: %s", exp)
t.Logf("Result..: %s", result)
assert.Equal(t, exp, result)
})
t.Run("Alnum", func(t *testing.T) {
assert.Equal(t, "123ABCabc", SqlLike(" 123ABCabc%* "))
})
}

View File

@ -1,41 +1,53 @@
package sanitize
import (
"bytes"
)
// SqlSpecial checks if the byte must be escaped/omitted in SQL.
func SqlSpecial(b byte) (special bool, omit bool) {
if b < 32 {
return true, true
}
// sqlSpecialBytes contains special bytes to escape in SQL search queries.
// see https://mariadb.com/kb/en/string-literals/
var sqlSpecialBytes = []byte{34, 39, 92, 95} // ", ', \, _
switch b {
case '"', '\'', '\\':
return true, false
default:
return false, false
}
}
// SqlString escapes a string for use in an SQL query.
func SqlString(s string) string {
var i int
for i = 0; i < len(s); i++ {
if bytes.Contains(sqlSpecialBytes, []byte{s[i]}) {
if found, _ := SqlSpecial(s[i]); found {
break
}
}
// No special characters found, return original string.
// Return if no special characters were found.
if i >= len(s) {
return s
}
b := make([]byte, 2*len(s)-i)
copy(b, s[:i])
j := i
for ; i < len(s); i++ {
if s[i] < 31 {
// Ignore control chars.
if special, omit := SqlSpecial(s[i]); omit {
// Omit control characters.
continue
}
if bytes.Contains(sqlSpecialBytes, []byte{s[i]}) {
b[j] = '\\'
} else if special {
// Escape other special characters.
// see https://mariadb.com/kb/en/string-literals/
b[j] = s[i]
j++
}
b[j] = s[i]
j++
}
return string(b[:j])
}

View File

@ -6,13 +6,72 @@ import (
"github.com/stretchr/testify/assert"
)
func TestSqlSpecial(t *testing.T) {
t.Run("Special", func(t *testing.T) {
if s, o := SqlSpecial(1); !s {
t.Error("char is special")
} else if !o {
t.Error("\" must be omitted")
}
if s, o := SqlSpecial(31); !s {
t.Error("char is special")
} else if !o {
t.Error("\" must be omitted")
}
if s, o := SqlSpecial('\\'); !s {
t.Error("\\ is special")
} else if o {
t.Error("\\ must not be omitted")
}
if s, o := SqlSpecial('\''); !s {
t.Error("' is special")
} else if o {
t.Error("' must not be omitted")
}
if s, o := SqlSpecial('"'); !s {
t.Error("\" is special")
} else if o {
t.Error("\" must not be omitted")
}
})
t.Run("NotSpecial", func(t *testing.T) {
if s, o := SqlSpecial(32); s {
t.Error("space is not special")
} else if o {
t.Error("space must not be omitted")
}
if s, o := SqlSpecial('A'); s {
t.Error("A is not special")
} else if o {
t.Error("A must not be omitted")
}
if s, o := SqlSpecial('a'); s {
t.Error("a is not special")
} else if o {
t.Error("a must not be omitted")
}
if s, o := SqlSpecial('_'); s {
t.Error("_ is not special")
} else if o {
t.Error("_ must not be omitted")
}
})
}
func TestSqlString(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
assert.Equal(t, "", SqlString(""))
})
t.Run("Special", func(t *testing.T) {
s := "' \" \t \n %_''"
exp := "\\' \\\" %\\_\\'\\'"
exp := "'' \"\" %_''''"
result := SqlString(s)
t.Logf("String..: %s", s)
t.Logf("Expected: %s", exp)