266 lines
6.3 KiB
Go
266 lines
6.3 KiB
Go
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
|
|
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/format"
|
|
"go/parser"
|
|
"go/token"
|
|
"io"
|
|
"log"
|
|
"os"
|
|
"path"
|
|
"strings"
|
|
"text/template"
|
|
)
|
|
|
|
// WithTransactionComment is the marker looked for in the doc comment of a
// Store method; a method whose doc comment contains it is recorded with
// WithTransaction set.
const WithTransactionComment = "@withTransaction"

// Type names, spelled as they appear in source, that the template helpers
// compare result types against.
const (
	ErrorType  = "error"
	StringType = "string"
	IntType    = "int"
	Int32Type  = "int32"
	Int64Type  = "int64"
	BoolType   = "bool"
)
|
|
|
|
func isError(typeName string) bool {
|
|
return strings.Contains(typeName, ErrorType)
|
|
}
|
|
|
|
func isString(typeName string) bool {
|
|
return typeName == StringType
|
|
}
|
|
|
|
func isInt(typeName string) bool {
|
|
return typeName == IntType || typeName == Int32Type || typeName == Int64Type
|
|
}
|
|
|
|
func isBool(typeName string) bool {
|
|
return typeName == BoolType
|
|
}
|
|
|
|
func main() {
|
|
if err := buildTransactionalStore(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func buildTransactionalStore() error {
|
|
code, err := generateLayer("TransactionalStore", "transactional_store.go.tmpl")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
formatedCode, err := format.Source(code)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return os.WriteFile(path.Join("sqlstore/public_methods.go"), formatedCode, 0644) //nolint:gosec
|
|
}
|
|
|
|
// methodParam describes one named parameter of a Store method.
type methodParam struct {
	// Name is the parameter identifier. Type is the parameter's type
	// expression as text; a leading "..." marks a variadic parameter.
	Name, Type string
}
|
|
|
|
type methodData struct {
|
|
Params []methodParam
|
|
Results []string
|
|
WithTransaction bool
|
|
}
|
|
|
|
type storeMetadata struct {
|
|
Name string
|
|
Methods map[string]methodData
|
|
}
|
|
|
|
// blacklistedStoreMethodNames holds the names of Store methods that are left
// out when metadata is collected from the Store interface. Membership is what
// matters: the lookup in extractStoreMetadata tests for the presence of the
// key, not for the stored value.
var blacklistedStoreMethodNames = map[string]bool{
	"DBType":   true,
	"Shutdown": true,
}
|
|
|
|
func extractMethodMetadata(method *ast.Field, src []byte) methodData {
|
|
params := []methodParam{}
|
|
results := []string{}
|
|
withTransaction := false
|
|
ast.Inspect(method.Type, func(expr ast.Node) bool {
|
|
//nolint:gocritic
|
|
switch e := expr.(type) {
|
|
case *ast.FuncType:
|
|
if method.Doc != nil {
|
|
for _, comment := range method.Doc.List {
|
|
if strings.Contains(comment.Text, WithTransactionComment) {
|
|
withTransaction = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if e.Params != nil {
|
|
for _, param := range e.Params.List {
|
|
for _, paramName := range param.Names {
|
|
params = append(params, methodParam{Name: paramName.Name, Type: string(src[param.Type.Pos()-1 : param.Type.End()-1])})
|
|
}
|
|
}
|
|
}
|
|
if e.Results != nil {
|
|
for _, result := range e.Results.List {
|
|
results = append(results, string(src[result.Type.Pos()-1:result.Type.End()-1]))
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
return methodData{Params: params, Results: results, WithTransaction: withTransaction}
|
|
}
|
|
|
|
func extractStoreMetadata() (*storeMetadata, error) {
|
|
// Create the AST by parsing src.
|
|
fset := token.NewFileSet() // positions are relative to fset
|
|
|
|
file, err := os.Open("store.go")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to open store/store.go file: %w", err)
|
|
}
|
|
src, err := io.ReadAll(file)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
file.Close()
|
|
f, err := parser.ParseFile(fset, "", src, parser.AllErrors|parser.ParseComments)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
metadata := storeMetadata{Methods: map[string]methodData{}}
|
|
|
|
ast.Inspect(f, func(n ast.Node) bool {
|
|
//nolint:gocritic
|
|
switch x := n.(type) {
|
|
case *ast.TypeSpec:
|
|
if x.Name.Name == "Store" {
|
|
for _, method := range x.Type.(*ast.InterfaceType).Methods.List {
|
|
methodName := method.Names[0].Name
|
|
if _, ok := blacklistedStoreMethodNames[methodName]; ok {
|
|
continue
|
|
}
|
|
|
|
metadata.Methods[methodName] = extractMethodMetadata(method, src)
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
return &metadata, nil
|
|
}
|
|
|
|
func generateLayer(name, templateFile string) ([]byte, error) {
|
|
out := bytes.NewBufferString("")
|
|
metadata, err := extractStoreMetadata()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
metadata.Name = name
|
|
|
|
myFuncs := template.FuncMap{
|
|
"joinResultsForSignature": func(results []string) string {
|
|
if len(results) == 0 {
|
|
return ""
|
|
}
|
|
if len(results) == 1 {
|
|
return strings.Join(results, ", ")
|
|
}
|
|
return fmt.Sprintf("(%s)", strings.Join(results, ", "))
|
|
},
|
|
"genResultsVars": func(results []string, withNilError bool) string {
|
|
vars := []string{}
|
|
for i, typeName := range results {
|
|
switch {
|
|
case isError(typeName):
|
|
if withNilError {
|
|
vars = append(vars, "nil")
|
|
} else {
|
|
vars = append(vars, "err")
|
|
}
|
|
case i == 0:
|
|
vars = append(vars, "result")
|
|
default:
|
|
vars = append(vars, fmt.Sprintf("resultVar%d", i))
|
|
}
|
|
}
|
|
return strings.Join(vars, ", ")
|
|
},
|
|
"errorPresent": func(results []string) bool {
|
|
for _, typeName := range results {
|
|
if isError(typeName) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
},
|
|
"errorVar": func(results []string) string {
|
|
for _, typeName := range results {
|
|
if isError(typeName) {
|
|
return "err"
|
|
}
|
|
}
|
|
return ""
|
|
},
|
|
"joinParams": func(params []methodParam) string {
|
|
paramsNames := make([]string, 0, len(params))
|
|
for _, param := range params {
|
|
tParams := ""
|
|
if strings.HasPrefix(param.Type, "...") {
|
|
tParams = "..."
|
|
}
|
|
paramsNames = append(paramsNames, param.Name+tParams)
|
|
}
|
|
return strings.Join(paramsNames, ", ")
|
|
},
|
|
"joinParamsWithType": func(params []methodParam) string {
|
|
paramsWithType := []string{}
|
|
for _, param := range params {
|
|
switch param.Type {
|
|
case "Container":
|
|
paramsWithType = append(paramsWithType, fmt.Sprintf("%s store.%s", param.Name, param.Type))
|
|
default:
|
|
paramsWithType = append(paramsWithType, fmt.Sprintf("%s %s", param.Name, param.Type))
|
|
}
|
|
}
|
|
return strings.Join(paramsWithType, ", ")
|
|
},
|
|
"renameStoreMethod": func(methodName string) string {
|
|
return strings.ToLower(methodName[0:1]) + methodName[1:]
|
|
},
|
|
"genErrorResultsVars": func(results []string, errName string) string {
|
|
vars := []string{}
|
|
for _, typeName := range results {
|
|
switch {
|
|
case isError(typeName):
|
|
vars = append(vars, errName)
|
|
case isString(typeName):
|
|
vars = append(vars, "\"\"")
|
|
case isInt(typeName):
|
|
vars = append(vars, "0")
|
|
case isBool(typeName):
|
|
vars = append(vars, "false")
|
|
default:
|
|
vars = append(vars, "nil")
|
|
}
|
|
}
|
|
return strings.Join(vars, ", ")
|
|
},
|
|
}
|
|
|
|
t := template.Must(template.New(templateFile).Funcs(myFuncs).ParseFiles("generators/" + templateFile))
|
|
if err = t.Execute(out, metadata); err != nil {
|
|
return nil, err
|
|
}
|
|
return out.Bytes(), nil
|
|
}
|