Adds a generator that wraps store methods with transactions and migrates implementations to use transactions (#1440)

* Adds a generator that wraps store methods with transactions and migrates implementations to use transactions

* Remove OpenTracing parameters from the generator

* Remove unused template methods

* Generate transactional methods only for those labelled as so

* Fix linter
This commit is contained in:
Miguel de la Cruz 2021-10-22 12:48:53 +02:00 committed by GitHub
parent 2415c9f28b
commit 8666bc833a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 699 additions and 149 deletions

View file

@ -0,0 +1,264 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package main
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io/ioutil"
"log"
"os"
"path"
"strings"
"text/template"
)
const (
WithTransactionComment = "@withTransaction"
ErrorType = "error"
StringType = "string"
IntType = "int"
Int32Type = "int32"
Int64Type = "int64"
BoolType = "bool"
)
func isError(typeName string) bool {
return strings.Contains(typeName, ErrorType)
}
func isString(typeName string) bool {
return typeName == StringType
}
func isInt(typeName string) bool {
return typeName == IntType || typeName == Int32Type || typeName == Int64Type
}
func isBool(typeName string) bool {
return typeName == BoolType
}
func main() {
if err := buildTransactionalStore(); err != nil {
log.Fatal(err)
}
}
func buildTransactionalStore() error {
code, err := generateLayer("TransactionalStore", "transactional_store.go.tmpl")
if err != nil {
return err
}
formatedCode, err := format.Source(code)
if err != nil {
return err
}
return ioutil.WriteFile(path.Join("sqlstore/public_methods.go"), formatedCode, 0644) //nolint:gosec
}
type methodParam struct {
Name string
Type string
}
type methodData struct {
Params []methodParam
Results []string
WithTransaction bool
}
type storeMetadata struct {
Name string
Methods map[string]methodData
}
var blacklistedStoreMethodNames = map[string]bool{
"Shutdown": true,
}
func extractMethodMetadata(method *ast.Field, src []byte) methodData {
params := []methodParam{}
results := []string{}
withTransaction := false
ast.Inspect(method.Type, func(expr ast.Node) bool {
//nolint:gocritic
switch e := expr.(type) {
case *ast.FuncType:
if method.Doc != nil {
for _, comment := range method.Doc.List {
if strings.Contains(comment.Text, WithTransactionComment) {
withTransaction = true
break
}
}
}
if e.Params != nil {
for _, param := range e.Params.List {
for _, paramName := range param.Names {
params = append(params, methodParam{Name: paramName.Name, Type: string(src[param.Type.Pos()-1 : param.Type.End()-1])})
}
}
}
if e.Results != nil {
for _, result := range e.Results.List {
results = append(results, string(src[result.Type.Pos()-1:result.Type.End()-1]))
}
}
}
return true
})
return methodData{Params: params, Results: results, WithTransaction: withTransaction}
}
func extractStoreMetadata() (*storeMetadata, error) {
// Create the AST by parsing src.
fset := token.NewFileSet() // positions are relative to fset
file, err := os.Open("store.go")
if err != nil {
return nil, fmt.Errorf("unable to open store/store.go file: %w", err)
}
src, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
}
file.Close()
f, err := parser.ParseFile(fset, "", src, parser.AllErrors|parser.ParseComments)
if err != nil {
return nil, err
}
metadata := storeMetadata{Methods: map[string]methodData{}}
ast.Inspect(f, func(n ast.Node) bool {
//nolint:gocritic
switch x := n.(type) {
case *ast.TypeSpec:
if x.Name.Name == "Store" {
for _, method := range x.Type.(*ast.InterfaceType).Methods.List {
methodName := method.Names[0].Name
if _, ok := blacklistedStoreMethodNames[methodName]; ok {
continue
}
metadata.Methods[methodName] = extractMethodMetadata(method, src)
}
}
}
return true
})
return &metadata, nil
}
func generateLayer(name, templateFile string) ([]byte, error) {
out := bytes.NewBufferString("")
metadata, err := extractStoreMetadata()
if err != nil {
return nil, err
}
metadata.Name = name
myFuncs := template.FuncMap{
"joinResultsForSignature": func(results []string) string {
if len(results) == 0 {
return ""
}
if len(results) == 1 {
return strings.Join(results, ", ")
}
return fmt.Sprintf("(%s)", strings.Join(results, ", "))
},
"genResultsVars": func(results []string, withNilError bool) string {
vars := []string{}
for i, typeName := range results {
switch {
case isError(typeName):
if withNilError {
vars = append(vars, "nil")
} else {
vars = append(vars, "err")
}
case i == 0:
vars = append(vars, "result")
default:
vars = append(vars, fmt.Sprintf("resultVar%d", i))
}
}
return strings.Join(vars, ", ")
},
"errorPresent": func(results []string) bool {
for _, typeName := range results {
if isError(typeName) {
return true
}
}
return false
},
"errorVar": func(results []string) string {
for _, typeName := range results {
if isError(typeName) {
return "err"
}
}
return ""
},
"joinParams": func(params []methodParam) string {
paramsNames := make([]string, 0, len(params))
for _, param := range params {
tParams := ""
if strings.HasPrefix(param.Type, "...") {
tParams = "..."
}
paramsNames = append(paramsNames, param.Name+tParams)
}
return strings.Join(paramsNames, ", ")
},
"joinParamsWithType": func(params []methodParam) string {
paramsWithType := []string{}
for _, param := range params {
switch param.Type {
case "Container":
paramsWithType = append(paramsWithType, fmt.Sprintf("%s store.%s", param.Name, param.Type))
default:
paramsWithType = append(paramsWithType, fmt.Sprintf("%s %s", param.Name, param.Type))
}
}
return strings.Join(paramsWithType, ", ")
},
"renameStoreMethod": func(methodName string) string {
return strings.ToLower(methodName[0:1]) + methodName[1:]
},
"genErrorResultsVars": func(results []string, errName string) string {
vars := []string{}
for _, typeName := range results {
switch {
case isError(typeName):
vars = append(vars, errName)
case isString(typeName):
vars = append(vars, "\"\"")
case isInt(typeName):
vars = append(vars, "0")
case isBool(typeName):
vars = append(vars, "false")
default:
vars = append(vars, "nil")
}
}
return strings.Join(vars, ", ")
},
}
t := template.Must(template.New(templateFile).Funcs(myFuncs).ParseFiles("generators/" + templateFile))
if err = t.Execute(out, metadata); err != nil {
return nil, err
}
return out.Bytes(), nil
}

View file

@ -0,0 +1,58 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
// Code generated by "make generate" from the Store interface
// DO NOT EDIT
// To add a public method, create an entry in the Store interface,
// prefix it with a @withTransaction comment if you need it to be
// transactional and then add a private method in the store itself
// with db sq.BaseRunner as the first parameter before running `make
// generate`
package sqlstore
import (
"context"
"github.com/mattermost/focalboard/server/model"
"github.com/mattermost/focalboard/server/services/store"
"github.com/mattermost/mattermost-server/v6/shared/mlog"
)
{{range $index, $element := .Methods}}
func (s *SQLStore) {{$index}}({{$element.Params | joinParamsWithType}}) {{$element.Results | joinResultsForSignature}} {
{{- if $element.WithTransaction}}
tx, txErr := s.db.BeginTx(context.Background(), nil)
if txErr != nil {
return {{ genErrorResultsVars $element.Results "txErr"}}
}
{{- if $element.Results | len | eq 0}}
s.{{$index | renameStoreMethod}}(tx, {{$element.Params | joinParams}})
if err := tx.Commit(); err != nil {
return {{ genErrorResultsVars $element.Results "err"}}
}
{{else}}
{{genResultsVars $element.Results false }} := s.{{$index | renameStoreMethod}}(tx, {{$element.Params | joinParams}})
{{- if $element.Results | errorPresent }}
if {{$element.Results | errorVar}} != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
s.logger.Error("transaction rollback error", mlog.Err(rollbackErr), mlog.String("methodName", "{{$index}}"))
}
return {{ genErrorResultsVars $element.Results "err"}}
}
{{end}}
if err := tx.Commit(); err != nil {
return {{ genErrorResultsVars $element.Results "err"}}
}
return {{ genResultsVars $element.Results true -}}
{{end}}
{{else}}
return s.{{$index | renameStoreMethod}}(s.db, {{$element.Params | joinParams}})
{{end}}
}
{{end}}

View file

@ -1,7 +1,6 @@
package sqlstore
import (
"context"
"database/sql"
"encoding/json"
"fmt"
@ -49,8 +48,8 @@ func (s *SQLStore) blockFields() []string {
}
}
func (s *SQLStore) GetBlocksWithParentAndType(c store.Container, parentID string, blockType string) ([]model.Block, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getBlocksWithParentAndType(db sq.BaseRunner, c store.Container, parentID string, blockType string) ([]model.Block, error) {
query := s.getQueryBuilder(db).
Select(s.blockFields()...).
From(s.tablePrefix + "blocks").
Where(sq.Eq{"COALESCE(workspace_id, '0')": c.WorkspaceID}).
@ -68,8 +67,8 @@ func (s *SQLStore) GetBlocksWithParentAndType(c store.Container, parentID string
return s.blocksFromRows(rows)
}
func (s *SQLStore) GetBlocksWithParent(c store.Container, parentID string) ([]model.Block, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getBlocksWithParent(db sq.BaseRunner, c store.Container, parentID string) ([]model.Block, error) {
query := s.getQueryBuilder(db).
Select(s.blockFields()...).
From(s.tablePrefix + "blocks").
Where(sq.Eq{"parent_id": parentID}).
@ -86,8 +85,8 @@ func (s *SQLStore) GetBlocksWithParent(c store.Container, parentID string) ([]mo
return s.blocksFromRows(rows)
}
func (s *SQLStore) GetBlocksWithRootID(c store.Container, rootID string) ([]model.Block, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getBlocksWithRootID(db sq.BaseRunner, c store.Container, rootID string) ([]model.Block, error) {
query := s.getQueryBuilder(db).
Select(s.blockFields()...).
From(s.tablePrefix + "blocks").
Where(sq.Eq{"root_id": rootID}).
@ -104,8 +103,8 @@ func (s *SQLStore) GetBlocksWithRootID(c store.Container, rootID string) ([]mode
return s.blocksFromRows(rows)
}
func (s *SQLStore) GetBlocksWithType(c store.Container, blockType string) ([]model.Block, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getBlocksWithType(db sq.BaseRunner, c store.Container, blockType string) ([]model.Block, error) {
query := s.getQueryBuilder(db).
Select(s.blockFields()...).
From(s.tablePrefix + "blocks").
Where(sq.Eq{"type": blockType}).
@ -123,8 +122,8 @@ func (s *SQLStore) GetBlocksWithType(c store.Container, blockType string) ([]mod
}
// GetSubTree2 returns blocks within 2 levels of the given blockID.
func (s *SQLStore) GetSubTree2(c store.Container, blockID string) ([]model.Block, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getSubTree2(db sq.BaseRunner, c store.Container, blockID string) ([]model.Block, error) {
query := s.getQueryBuilder(db).
Select(s.blockFields()...).
From(s.tablePrefix + "blocks").
Where(sq.Or{sq.Eq{"id": blockID}, sq.Eq{"parent_id": blockID}}).
@ -142,9 +141,9 @@ func (s *SQLStore) GetSubTree2(c store.Container, blockID string) ([]model.Block
}
// GetSubTree3 returns blocks within 3 levels of the given blockID.
func (s *SQLStore) GetSubTree3(c store.Container, blockID string) ([]model.Block, error) {
func (s *SQLStore) getSubTree3(db sq.BaseRunner, c store.Container, blockID string) ([]model.Block, error) {
// This first subquery returns repeated blocks
query := s.getQueryBuilder().Select(
query := s.getQueryBuilder(db).Select(
"l3.id",
"l3.parent_id",
"l3.root_id",
@ -182,8 +181,8 @@ func (s *SQLStore) GetSubTree3(c store.Container, blockID string) ([]model.Block
return s.blocksFromRows(rows)
}
func (s *SQLStore) GetAllBlocks(c store.Container) ([]model.Block, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getAllBlocks(db sq.BaseRunner, c store.Container) ([]model.Block, error) {
query := s.getQueryBuilder(db).
Select(s.blockFields()...).
From(s.tablePrefix + "blocks").
Where(sq.Eq{"coalesce(workspace_id, '0')": c.WorkspaceID})
@ -246,8 +245,8 @@ func (s *SQLStore) blocksFromRows(rows *sql.Rows) ([]model.Block, error) {
return results, nil
}
func (s *SQLStore) GetRootID(c store.Container, blockID string) (string, error) {
query := s.getQueryBuilder().Select("root_id").
func (s *SQLStore) getRootID(db sq.BaseRunner, c store.Container, blockID string) (string, error) {
query := s.getQueryBuilder(db).Select("root_id").
From(s.tablePrefix + "blocks").
Where(sq.Eq{"id": blockID}).
Where(sq.Eq{"coalesce(workspace_id, '0')": c.WorkspaceID})
@ -264,8 +263,8 @@ func (s *SQLStore) GetRootID(c store.Container, blockID string) (string, error)
return rootID, nil
}
func (s *SQLStore) GetParentID(c store.Container, blockID string) (string, error) {
query := s.getQueryBuilder().Select("parent_id").
func (s *SQLStore) getParentID(db sq.BaseRunner, c store.Container, blockID string) (string, error) {
query := s.getQueryBuilder(db).Select("parent_id").
From(s.tablePrefix + "blocks").
Where(sq.Eq{"id": blockID}).
Where(sq.Eq{"coalesce(workspace_id, '0')": c.WorkspaceID})
@ -282,7 +281,7 @@ func (s *SQLStore) GetParentID(c store.Container, blockID string) (string, error
return parentID, nil
}
func (s *SQLStore) InsertBlock(c store.Container, block *model.Block, userID string) error {
func (s *SQLStore) insertBlock(db sq.BaseRunner, c store.Container, block *model.Block, userID string) error {
if block.RootID == "" {
return RootIDNilError{}
}
@ -292,18 +291,12 @@ func (s *SQLStore) InsertBlock(c store.Container, block *model.Block, userID str
return err
}
existingBlock, err := s.GetBlock(c, block.ID)
existingBlock, err := s.getBlock(db, c, block.ID)
if err != nil {
return err
}
ctx := context.Background()
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
insertQuery := s.getQueryBuilder().Insert("").
insertQuery := s.getQueryBuilder(db).Insert("").
Columns(
"workspace_id",
"id",
@ -341,7 +334,7 @@ func (s *SQLStore) InsertBlock(c store.Container, block *model.Block, userID str
if existingBlock != nil {
// block with ID exists, so this is an update operation
query := s.getQueryBuilder().Update(s.tablePrefix+"blocks").
query := s.getQueryBuilder(db).Update(s.tablePrefix+"blocks").
Where(sq.Eq{"id": block.ID}).
Where(sq.Eq{"COALESCE(workspace_id, '0')": c.WorkspaceID}).
Set("parent_id", block.ParentID).
@ -354,18 +347,9 @@ func (s *SQLStore) InsertBlock(c store.Container, block *model.Block, userID str
Set("update_at", block.UpdateAt).
Set("delete_at", block.DeleteAt)
q, args, err2 := query.ToSql()
if err2 != nil {
s.logger.Error("InsertBlock error converting update query object to SQL", mlog.Err(err2))
return err2
}
if _, err2 := tx.Exec(q, args...); err2 != nil {
s.logger.Error(`InsertBlock error occurred while updating existing block`, mlog.String("blockID", block.ID), mlog.Err(err2))
if rollbackErr := tx.Rollback(); rollbackErr != nil {
s.logger.Warn("Transaction rollback error", mlog.Err(rollbackErr))
}
return err2
if _, err := query.Exec(); err != nil {
s.logger.Error(`InsertBlock error occurred while updating existing block`, mlog.String("blockID", block.ID), mlog.Err(err))
return err
}
} else {
block.CreatedBy = userID
@ -378,38 +362,23 @@ func (s *SQLStore) InsertBlock(c store.Container, block *model.Block, userID str
insertQueryValues["update_at"] = block.UpdateAt
insertQueryValues["modified_by"] = block.ModifiedBy
query := insertQuery.SetMap(insertQueryValues)
_, err = sq.ExecContextWith(ctx, tx, query.Into(s.tablePrefix+"blocks"))
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
s.logger.Warn("Transaction rollback error", mlog.Err(rollbackErr))
}
query := insertQuery.SetMap(insertQueryValues).Into(s.tablePrefix + "blocks")
if _, err := query.Exec(); err != nil {
return err
}
}
// writing block history
query := insertQuery.SetMap(insertQueryValues)
_, err = sq.ExecContextWith(ctx, tx, query.Into(s.tablePrefix+"blocks_history"))
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
s.logger.Warn("Transaction rollback error", mlog.Err(rollbackErr))
}
return err
}
err = tx.Commit()
if err != nil {
query := insertQuery.SetMap(insertQueryValues).Into(s.tablePrefix + "blocks_history")
if _, err := query.Exec(); err != nil {
return err
}
return nil
}
func (s *SQLStore) PatchBlock(c store.Container, blockID string, blockPatch *model.BlockPatch, userID string) error {
existingBlock, err := s.GetBlock(c, blockID)
func (s *SQLStore) patchBlock(db sq.BaseRunner, c store.Container, blockID string, blockPatch *model.BlockPatch, userID string) error {
existingBlock, err := s.getBlock(db, c, blockID)
if err != nil {
return err
}
@ -418,18 +387,12 @@ func (s *SQLStore) PatchBlock(c store.Container, blockID string, blockPatch *mod
}
block := blockPatch.Patch(existingBlock)
return s.InsertBlock(c, block, userID)
return s.insertBlock(db, c, block, userID)
}
func (s *SQLStore) DeleteBlock(c store.Container, blockID string, modifiedBy string) error {
ctx := context.Background()
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
func (s *SQLStore) deleteBlock(db sq.BaseRunner, c store.Container, blockID string, modifiedBy string) error {
now := utils.GetMillis()
insertQuery := s.getQueryBuilder().Insert(s.tablePrefix+"blocks_history").
insertQuery := s.getQueryBuilder(db).Insert(s.tablePrefix+"blocks_history").
Columns(
"workspace_id",
"id",
@ -445,37 +408,24 @@ func (s *SQLStore) DeleteBlock(c store.Container, blockID string, modifiedBy str
now,
)
_, err = sq.ExecContextWith(ctx, tx, insertQuery)
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
s.logger.Warn("Transaction rollback error", mlog.Err(rollbackErr))
}
if _, err := insertQuery.Exec(); err != nil {
return err
}
deleteQuery := s.getQueryBuilder().
deleteQuery := s.getQueryBuilder(db).
Delete(s.tablePrefix + "blocks").
Where(sq.Eq{"id": blockID}).
Where(sq.Eq{"COALESCE(workspace_id, '0')": c.WorkspaceID})
_, err = sq.ExecContextWith(ctx, tx, deleteQuery)
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
s.logger.Warn("Transaction rollback error", mlog.Err(rollbackErr))
}
return err
}
err = tx.Commit()
if err != nil {
if _, err := deleteQuery.Exec(); err != nil {
return err
}
return nil
}
func (s *SQLStore) GetBlockCountsByType() (map[string]int64, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getBlockCountsByType(db sq.BaseRunner) (map[string]int64, error) {
query := s.getQueryBuilder(db).
Select(
"type",
"COUNT(*) AS count",
@ -507,8 +457,8 @@ func (s *SQLStore) GetBlockCountsByType() (map[string]int64, error) {
return m, nil
}
func (s *SQLStore) GetBlock(c store.Container, blockID string) (*model.Block, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getBlock(db sq.BaseRunner, c store.Container, blockID string) (*model.Block, error) {
query := s.getQueryBuilder(db).
Select(s.blockFields()...).
From(s.tablePrefix + "blocks").
Where(sq.Eq{"id": blockID}).

View file

@ -57,7 +57,7 @@ func (s *SQLStore) importInitialTemplates() error {
// isInitializationNeeded returns true if the blocks table is empty.
func (s *SQLStore) isInitializationNeeded() (bool, error) {
query := s.getQueryBuilder().
query := s.getQueryBuilder(s.db).
Select("count(*)").
From(s.tablePrefix + "blocks").
Where(sq.Eq{"COALESCE(workspace_id, '0')": "0"})

View file

@ -0,0 +1,270 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
// Code generated by "make generate" from the Store interface
// DO NOT EDIT
// To add a public method, create an entry in the Store interface,
// prefix it with a @withTransaction comment if you need it to be
// transactional and then add a private method in the store itself
// with db sq.BaseRunner as the first parameter before running `make
// generate`
package sqlstore
import (
"context"
"github.com/mattermost/focalboard/server/model"
"github.com/mattermost/focalboard/server/services/store"
"github.com/mattermost/mattermost-server/v6/shared/mlog"
)
func (s *SQLStore) CleanUpSessions(expireTime int64) error {
return s.cleanUpSessions(s.db, expireTime)
}
func (s *SQLStore) CreateSession(session *model.Session) error {
return s.createSession(s.db, session)
}
func (s *SQLStore) CreateUser(user *model.User) error {
return s.createUser(s.db, user)
}
func (s *SQLStore) DeleteBlock(c store.Container, blockID string, modifiedBy string) error {
tx, txErr := s.db.BeginTx(context.Background(), nil)
if txErr != nil {
return txErr
}
err := s.deleteBlock(tx, c, blockID, modifiedBy)
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
s.logger.Error("transaction rollback error", mlog.Err(rollbackErr), mlog.String("methodName", "DeleteBlock"))
}
return err
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func (s *SQLStore) DeleteSession(sessionID string) error {
return s.deleteSession(s.db, sessionID)
}
func (s *SQLStore) GetActiveUserCount(updatedSecondsAgo int64) (int, error) {
return s.getActiveUserCount(s.db, updatedSecondsAgo)
}
func (s *SQLStore) GetAllBlocks(c store.Container) ([]model.Block, error) {
return s.getAllBlocks(s.db, c)
}
func (s *SQLStore) GetBlock(c store.Container, blockID string) (*model.Block, error) {
return s.getBlock(s.db, c, blockID)
}
func (s *SQLStore) GetBlockCountsByType() (map[string]int64, error) {
return s.getBlockCountsByType(s.db)
}
func (s *SQLStore) GetBlocksWithParent(c store.Container, parentID string) ([]model.Block, error) {
return s.getBlocksWithParent(s.db, c, parentID)
}
func (s *SQLStore) GetBlocksWithParentAndType(c store.Container, parentID string, blockType string) ([]model.Block, error) {
return s.getBlocksWithParentAndType(s.db, c, parentID, blockType)
}
func (s *SQLStore) GetBlocksWithRootID(c store.Container, rootID string) ([]model.Block, error) {
return s.getBlocksWithRootID(s.db, c, rootID)
}
func (s *SQLStore) GetBlocksWithType(c store.Container, blockType string) ([]model.Block, error) {
return s.getBlocksWithType(s.db, c, blockType)
}
func (s *SQLStore) GetParentID(c store.Container, blockID string) (string, error) {
return s.getParentID(s.db, c, blockID)
}
func (s *SQLStore) GetRegisteredUserCount() (int, error) {
return s.getRegisteredUserCount(s.db)
}
func (s *SQLStore) GetRootID(c store.Container, blockID string) (string, error) {
return s.getRootID(s.db, c, blockID)
}
func (s *SQLStore) GetSession(token string, expireTime int64) (*model.Session, error) {
return s.getSession(s.db, token, expireTime)
}
func (s *SQLStore) GetSharing(c store.Container, rootID string) (*model.Sharing, error) {
return s.getSharing(s.db, c, rootID)
}
func (s *SQLStore) GetSubTree2(c store.Container, blockID string) ([]model.Block, error) {
return s.getSubTree2(s.db, c, blockID)
}
func (s *SQLStore) GetSubTree3(c store.Container, blockID string) ([]model.Block, error) {
return s.getSubTree3(s.db, c, blockID)
}
func (s *SQLStore) GetSystemSettings() (map[string]string, error) {
return s.getSystemSettings(s.db)
}
func (s *SQLStore) GetUserByEmail(email string) (*model.User, error) {
return s.getUserByEmail(s.db, email)
}
func (s *SQLStore) GetUserByID(userID string) (*model.User, error) {
return s.getUserByID(s.db, userID)
}
func (s *SQLStore) GetUserByUsername(username string) (*model.User, error) {
return s.getUserByUsername(s.db, username)
}
func (s *SQLStore) GetUserWorkspaces(userID string) ([]model.UserWorkspace, error) {
return s.getUserWorkspaces(s.db, userID)
}
func (s *SQLStore) GetUsersByWorkspace(workspaceID string) ([]*model.User, error) {
return s.getUsersByWorkspace(s.db, workspaceID)
}
func (s *SQLStore) GetWorkspace(ID string) (*model.Workspace, error) {
return s.getWorkspace(s.db, ID)
}
func (s *SQLStore) GetWorkspaceCount() (int64, error) {
return s.getWorkspaceCount(s.db)
}
func (s *SQLStore) HasWorkspaceAccess(userID string, workspaceID string) (bool, error) {
return s.hasWorkspaceAccess(s.db, userID, workspaceID)
}
func (s *SQLStore) InsertBlock(c store.Container, block *model.Block, userID string) error {
tx, txErr := s.db.BeginTx(context.Background(), nil)
if txErr != nil {
return txErr
}
err := s.insertBlock(tx, c, block, userID)
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
s.logger.Error("transaction rollback error", mlog.Err(rollbackErr), mlog.String("methodName", "InsertBlock"))
}
return err
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func (s *SQLStore) PatchBlock(c store.Container, blockID string, blockPatch *model.BlockPatch, userID string) error {
tx, txErr := s.db.BeginTx(context.Background(), nil)
if txErr != nil {
return txErr
}
err := s.patchBlock(tx, c, blockID, blockPatch, userID)
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
s.logger.Error("transaction rollback error", mlog.Err(rollbackErr), mlog.String("methodName", "PatchBlock"))
}
return err
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func (s *SQLStore) RefreshSession(session *model.Session) error {
return s.refreshSession(s.db, session)
}
func (s *SQLStore) SetSystemSetting(key string, value string) error {
return s.setSystemSetting(s.db, key, value)
}
func (s *SQLStore) UpdateSession(session *model.Session) error {
return s.updateSession(s.db, session)
}
func (s *SQLStore) UpdateUser(user *model.User) error {
return s.updateUser(s.db, user)
}
func (s *SQLStore) UpdateUserPassword(username string, password string) error {
return s.updateUserPassword(s.db, username, password)
}
func (s *SQLStore) UpdateUserPasswordByID(userID string, password string) error {
return s.updateUserPasswordByID(s.db, userID, password)
}
func (s *SQLStore) UpsertSharing(c store.Container, sharing model.Sharing) error {
return s.upsertSharing(s.db, c, sharing)
}
func (s *SQLStore) UpsertWorkspaceSettings(workspace model.Workspace) error {
return s.upsertWorkspaceSettings(s.db, workspace)
}
func (s *SQLStore) UpsertWorkspaceSignupToken(workspace model.Workspace) error {
return s.upsertWorkspaceSignupToken(s.db, workspace)
}

View file

@ -9,8 +9,8 @@ import (
)
// GetActiveUserCount returns the number of users with active sessions within N seconds ago.
func (s *SQLStore) GetActiveUserCount(updatedSecondsAgo int64) (int, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getActiveUserCount(db sq.BaseRunner, updatedSecondsAgo int64) (int, error) {
query := s.getQueryBuilder(db).
Select("count(distinct user_id)").
From(s.tablePrefix + "sessions").
Where(sq.Gt{"update_at": utils.GetMillis() - utils.SecondsToMillis(updatedSecondsAgo)})
@ -26,8 +26,8 @@ func (s *SQLStore) GetActiveUserCount(updatedSecondsAgo int64) (int, error) {
return count, nil
}
func (s *SQLStore) GetSession(token string, expireTimeSeconds int64) (*model.Session, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getSession(db sq.BaseRunner, token string, expireTimeSeconds int64) (*model.Session, error) {
query := s.getQueryBuilder(db).
Select("id", "token", "user_id", "auth_service", "props").
From(s.tablePrefix + "sessions").
Where(sq.Eq{"token": token}).
@ -50,7 +50,7 @@ func (s *SQLStore) GetSession(token string, expireTimeSeconds int64) (*model.Ses
return &session, nil
}
func (s *SQLStore) CreateSession(session *model.Session) error {
func (s *SQLStore) createSession(db sq.BaseRunner, session *model.Session) error {
now := utils.GetMillis()
propsBytes, err := json.Marshal(session.Props)
@ -58,7 +58,7 @@ func (s *SQLStore) CreateSession(session *model.Session) error {
return err
}
query := s.getQueryBuilder().Insert(s.tablePrefix+"sessions").
query := s.getQueryBuilder(db).Insert(s.tablePrefix+"sessions").
Columns("id", "token", "user_id", "auth_service", "props", "create_at", "update_at").
Values(session.ID, session.Token, session.UserID, session.AuthService, propsBytes, now, now)
@ -66,10 +66,10 @@ func (s *SQLStore) CreateSession(session *model.Session) error {
return err
}
func (s *SQLStore) RefreshSession(session *model.Session) error {
func (s *SQLStore) refreshSession(db sq.BaseRunner, session *model.Session) error {
now := utils.GetMillis()
query := s.getQueryBuilder().Update(s.tablePrefix+"sessions").
query := s.getQueryBuilder(db).Update(s.tablePrefix+"sessions").
Where(sq.Eq{"token": session.Token}).
Set("update_at", now)
@ -77,7 +77,7 @@ func (s *SQLStore) RefreshSession(session *model.Session) error {
return err
}
func (s *SQLStore) UpdateSession(session *model.Session) error {
func (s *SQLStore) updateSession(db sq.BaseRunner, session *model.Session) error {
now := utils.GetMillis()
propsBytes, err := json.Marshal(session.Props)
@ -85,7 +85,7 @@ func (s *SQLStore) UpdateSession(session *model.Session) error {
return err
}
query := s.getQueryBuilder().Update(s.tablePrefix+"sessions").
query := s.getQueryBuilder(db).Update(s.tablePrefix+"sessions").
Where(sq.Eq{"token": session.Token}).
Set("update_at", now).
Set("props", propsBytes)
@ -94,16 +94,16 @@ func (s *SQLStore) UpdateSession(session *model.Session) error {
return err
}
func (s *SQLStore) DeleteSession(sessionID string) error {
query := s.getQueryBuilder().Delete(s.tablePrefix + "sessions").
func (s *SQLStore) deleteSession(db sq.BaseRunner, sessionID string) error {
query := s.getQueryBuilder(db).Delete(s.tablePrefix + "sessions").
Where(sq.Eq{"id": sessionID})
_, err := query.Exec()
return err
}
func (s *SQLStore) CleanUpSessions(expireTimeSeconds int64) error {
query := s.getQueryBuilder().Delete(s.tablePrefix + "sessions").
func (s *SQLStore) cleanUpSessions(db sq.BaseRunner, expireTimeSeconds int64) error {
query := s.getQueryBuilder(db).Delete(s.tablePrefix + "sessions").
Where(sq.Lt{"update_at": utils.GetMillis() - utils.SecondsToMillis(expireTimeSeconds)})
_, err := query.Exec()

View file

@ -8,10 +8,10 @@ import (
sq "github.com/Masterminds/squirrel"
)
func (s *SQLStore) UpsertSharing(c store.Container, sharing model.Sharing) error {
func (s *SQLStore) upsertSharing(db sq.BaseRunner, _ store.Container, sharing model.Sharing) error {
now := utils.GetMillis()
query := s.getQueryBuilder().
query := s.getQueryBuilder(db).
Insert(s.tablePrefix+"sharing").
Columns(
"id",
@ -32,7 +32,7 @@ func (s *SQLStore) UpsertSharing(c store.Container, sharing model.Sharing) error
sharing.Enabled, sharing.Token, sharing.ModifiedBy, now)
} else {
query = query.Suffix(
`ON CONFLICT (id)
`ON CONFLICT (id)
DO UPDATE SET enabled = EXCLUDED.enabled, token = EXCLUDED.token, modified_by = EXCLUDED.modified_by, update_at = EXCLUDED.update_at`,
)
}
@ -41,8 +41,8 @@ func (s *SQLStore) UpsertSharing(c store.Container, sharing model.Sharing) error
return err
}
func (s *SQLStore) GetSharing(c store.Container, rootID string) (*model.Sharing, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getSharing(db sq.BaseRunner, _ store.Container, rootID string) (*model.Sharing, error) {
query := s.getQueryBuilder(db).
Select(
"id",
"enabled",

View file

@ -66,13 +66,13 @@ func (s *SQLStore) DBHandle() *sql.DB {
return s.db
}
func (s *SQLStore) getQueryBuilder() sq.StatementBuilderType {
func (s *SQLStore) getQueryBuilder(db sq.BaseRunner) sq.StatementBuilderType {
builder := sq.StatementBuilder
if s.dbType == postgresDBType || s.dbType == sqliteDBType {
builder = builder.PlaceholderFormat(sq.Dollar)
}
return builder.RunWith(s.db)
return builder.RunWith(db)
}
func (s *SQLStore) escapeField(fieldName string) string { //nolint:unparam

View file

@ -1,7 +1,11 @@
package sqlstore
func (s *SQLStore) GetSystemSettings() (map[string]string, error) {
query := s.getQueryBuilder().Select("*").From(s.tablePrefix + "system_settings")
import (
sq "github.com/Masterminds/squirrel"
)
func (s *SQLStore) getSystemSettings(db sq.BaseRunner) (map[string]string, error) {
query := s.getQueryBuilder(db).Select("*").From(s.tablePrefix + "system_settings")
rows, err := query.Query()
if err != nil {
@ -26,8 +30,8 @@ func (s *SQLStore) GetSystemSettings() (map[string]string, error) {
return results, nil
}
func (s *SQLStore) SetSystemSetting(id, value string) error {
query := s.getQueryBuilder().Insert(s.tablePrefix+"system_settings").Columns("id", "value").Values(id, value)
func (s *SQLStore) setSystemSetting(db sq.BaseRunner, id, value string) error {
query := s.getQueryBuilder(db).Insert(s.tablePrefix+"system_settings").Columns("id", "value").Values(id, value)
if s.dbType == mysqlDBType {
query = query.Suffix("ON DUPLICATE KEY UPDATE value = ?", value)

View file

@ -20,8 +20,8 @@ func (unf UserNotFoundError) Error() string {
return fmt.Sprintf("user not found (%s)", unf.id)
}
func (s *SQLStore) GetRegisteredUserCount() (int, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getRegisteredUserCount(db sq.BaseRunner) (int, error) {
query := s.getQueryBuilder(db).
Select("count(*)").
From(s.tablePrefix + "users").
Where(sq.Eq{"delete_at": 0})
@ -36,8 +36,8 @@ func (s *SQLStore) GetRegisteredUserCount() (int, error) {
return count, nil
}
func (s *SQLStore) getUserByCondition(condition sq.Eq) (*model.User, error) {
users, err := s.getUsersByCondition(condition)
func (s *SQLStore) getUserByCondition(db sq.BaseRunner, condition sq.Eq) (*model.User, error) {
users, err := s.getUsersByCondition(db, condition)
if err != nil {
return nil, err
}
@ -49,8 +49,8 @@ func (s *SQLStore) getUserByCondition(condition sq.Eq) (*model.User, error) {
return users[0], nil
}
func (s *SQLStore) getUsersByCondition(condition sq.Eq) ([]*model.User, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getUsersByCondition(db sq.BaseRunner, condition sq.Eq) ([]*model.User, error) {
query := s.getQueryBuilder(db).
Select(
"id",
"username",
@ -86,19 +86,19 @@ func (s *SQLStore) getUsersByCondition(condition sq.Eq) ([]*model.User, error) {
return users, nil
}
func (s *SQLStore) GetUserByID(userID string) (*model.User, error) {
return s.getUserByCondition(sq.Eq{"id": userID})
func (s *SQLStore) getUserByID(db sq.BaseRunner, userID string) (*model.User, error) {
return s.getUserByCondition(db, sq.Eq{"id": userID})
}
func (s *SQLStore) GetUserByEmail(email string) (*model.User, error) {
return s.getUserByCondition(sq.Eq{"email": email})
func (s *SQLStore) getUserByEmail(db sq.BaseRunner, email string) (*model.User, error) {
return s.getUserByCondition(db, sq.Eq{"email": email})
}
func (s *SQLStore) GetUserByUsername(username string) (*model.User, error) {
return s.getUserByCondition(sq.Eq{"username": username})
func (s *SQLStore) getUserByUsername(db sq.BaseRunner, username string) (*model.User, error) {
return s.getUserByCondition(db, sq.Eq{"username": username})
}
func (s *SQLStore) CreateUser(user *model.User) error {
func (s *SQLStore) createUser(db sq.BaseRunner, user *model.User) error {
now := utils.GetMillis()
propsBytes, err := json.Marshal(user.Props)
@ -106,7 +106,7 @@ func (s *SQLStore) CreateUser(user *model.User) error {
return err
}
query := s.getQueryBuilder().Insert(s.tablePrefix+"users").
query := s.getQueryBuilder(db).Insert(s.tablePrefix+"users").
Columns("id", "username", "email", "password", "mfa_secret", "auth_service", "auth_data", "props", "create_at", "update_at", "delete_at").
Values(user.ID, user.Username, user.Email, user.Password, user.MfaSecret, user.AuthService, user.AuthData, propsBytes, now, now, 0)
@ -114,7 +114,7 @@ func (s *SQLStore) CreateUser(user *model.User) error {
return err
}
func (s *SQLStore) UpdateUser(user *model.User) error {
func (s *SQLStore) updateUser(db sq.BaseRunner, user *model.User) error {
now := utils.GetMillis()
propsBytes, err := json.Marshal(user.Props)
@ -122,7 +122,7 @@ func (s *SQLStore) UpdateUser(user *model.User) error {
return err
}
query := s.getQueryBuilder().Update(s.tablePrefix+"users").
query := s.getQueryBuilder(db).Update(s.tablePrefix+"users").
Set("username", user.Username).
Set("email", user.Email).
Set("props", propsBytes).
@ -146,10 +146,10 @@ func (s *SQLStore) UpdateUser(user *model.User) error {
return nil
}
func (s *SQLStore) UpdateUserPassword(username, password string) error {
func (s *SQLStore) updateUserPassword(db sq.BaseRunner, username, password string) error {
now := utils.GetMillis()
query := s.getQueryBuilder().Update(s.tablePrefix+"users").
query := s.getQueryBuilder(db).Update(s.tablePrefix+"users").
Set("password", password).
Set("update_at", now).
Where(sq.Eq{"username": username})
@ -171,10 +171,10 @@ func (s *SQLStore) UpdateUserPassword(username, password string) error {
return nil
}
func (s *SQLStore) UpdateUserPasswordByID(userID, password string) error {
func (s *SQLStore) updateUserPasswordByID(db sq.BaseRunner, userID, password string) error {
now := utils.GetMillis()
query := s.getQueryBuilder().Update(s.tablePrefix+"users").
query := s.getQueryBuilder(db).Update(s.tablePrefix+"users").
Set("password", password).
Set("update_at", now).
Where(sq.Eq{"id": userID})
@ -196,8 +196,8 @@ func (s *SQLStore) UpdateUserPasswordByID(userID, password string) error {
return nil
}
func (s *SQLStore) GetUsersByWorkspace(workspaceID string) ([]*model.User, error) {
return s.getUsersByCondition(nil)
func (s *SQLStore) getUsersByWorkspace(db sq.BaseRunner, _ string) ([]*model.User, error) {
return s.getUsersByCondition(db, nil)
}
func (s *SQLStore) usersFromRows(rows *sql.Rows) ([]*model.User, error) {

View file

@ -17,10 +17,10 @@ var (
errUnsupportedOperation = errors.New("unsupported operation")
)
func (s *SQLStore) UpsertWorkspaceSignupToken(workspace model.Workspace) error {
func (s *SQLStore) upsertWorkspaceSignupToken(db sq.BaseRunner, workspace model.Workspace) error {
now := utils.GetMillis()
query := s.getQueryBuilder().
query := s.getQueryBuilder(db).
Insert(s.tablePrefix+"workspaces").
Columns(
"id",
@ -48,7 +48,7 @@ func (s *SQLStore) UpsertWorkspaceSignupToken(workspace model.Workspace) error {
return err
}
func (s *SQLStore) UpsertWorkspaceSettings(workspace model.Workspace) error {
func (s *SQLStore) upsertWorkspaceSettings(db sq.BaseRunner, workspace model.Workspace) error {
now := utils.GetMillis()
signupToken := utils.NewID(utils.IDTypeToken)
@ -57,7 +57,7 @@ func (s *SQLStore) UpsertWorkspaceSettings(workspace model.Workspace) error {
return err
}
query := s.getQueryBuilder().
query := s.getQueryBuilder(db).
Insert(s.tablePrefix+"workspaces").
Columns(
"id",
@ -86,10 +86,10 @@ func (s *SQLStore) UpsertWorkspaceSettings(workspace model.Workspace) error {
return err
}
func (s *SQLStore) GetWorkspace(id string) (*model.Workspace, error) {
func (s *SQLStore) getWorkspace(db sq.BaseRunner, id string) (*model.Workspace, error) {
var settingsJSON string
query := s.getQueryBuilder().
query := s.getQueryBuilder(db).
Select(
"id",
"signup_token",
@ -122,12 +122,12 @@ func (s *SQLStore) GetWorkspace(id string) (*model.Workspace, error) {
return &workspace, nil
}
func (s *SQLStore) HasWorkspaceAccess(userID string, workspaceID string) (bool, error) {
func (s *SQLStore) hasWorkspaceAccess(db sq.BaseRunner, userID string, workspaceID string) (bool, error) {
return true, nil
}
func (s *SQLStore) GetWorkspaceCount() (int64, error) {
query := s.getQueryBuilder().
func (s *SQLStore) getWorkspaceCount(db sq.BaseRunner) (int64, error) {
query := s.getQueryBuilder(db).
Select(
"COUNT(*) AS count",
).
@ -151,6 +151,6 @@ func (s *SQLStore) GetWorkspaceCount() (int64, error) {
return count, nil
}
func (s *SQLStore) GetUserWorkspaces(userID string) ([]model.UserWorkspace, error) {
func (s *SQLStore) getUserWorkspaces(_ sq.BaseRunner, _ string) ([]model.UserWorkspace, error) {
return nil, fmt.Errorf("GetUserWorkspaces %w", errUnsupportedOperation)
}

View file

@ -1,4 +1,5 @@
//go:generate mockgen --build_flags=--mod=mod -destination=mockstore/mockstore.go -package mockstore . Store
//go:generate go run ./generators/main.go
package store
import "github.com/mattermost/focalboard/server/model"
@ -20,10 +21,13 @@ type Store interface {
GetAllBlocks(c Container) ([]model.Block, error)
GetRootID(c Container, blockID string) (string, error)
GetParentID(c Container, blockID string) (string, error)
// @withTransaction
InsertBlock(c Container, block *model.Block, userID string) error
// @withTransaction
DeleteBlock(c Container, blockID string, modifiedBy string) error
GetBlockCountsByType() (map[string]int64, error)
GetBlock(c Container, blockID string) (*model.Block, error)
// @withTransaction
PatchBlock(c Container, blockID string, blockPatch *model.BlockPatch, userID string) error
Shutdown() error