github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/sqlx/gen/utils.go (about)

     1  package gen
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/types"
     7  	"io"
     8  	"reflect"
     9  	"regexp"
    10  	"strings"
    11  	"text/template"
    12  
    13  	"github.com/google/uuid"
    14  
    15  	"github.com/artisanhe/tools/codegen"
    16  	"github.com/artisanhe/tools/godash"
    17  	"github.com/artisanhe/tools/sqlx"
    18  	"github.com/artisanhe/tools/sqlx/builder"
    19  )
    20  
    21  var (
    22  	defRegexp = regexp.MustCompile(`@def ([^\n]+)`)
    23  )
    24  
    25  type Keys struct {
    26  	Primary       sqlx.FieldNames
    27  	Indexes       sqlx.Indexes
    28  	UniqueIndexes sqlx.Indexes
    29  }
    30  
    31  func (ks *Keys) PatchUniqueIndexesWithSoftDelete(softDeleteField string) {
    32  	if len(ks.UniqueIndexes) > 0 {
    33  		for name, fieldNames := range ks.UniqueIndexes {
    34  			ks.UniqueIndexes[name] = godash.StringUniq(append(fieldNames, softDeleteField))
    35  		}
    36  	}
    37  }
    38  
    39  func (ks *Keys) Bind(table *builder.Table) {
    40  	if len(ks.Primary) > 0 {
    41  		cols, err := CheckFields(table, ks.Primary...)
    42  		if err != nil {
    43  			panic(fmt.Errorf("%s, please check primary def", err.Error()))
    44  		}
    45  		ks.Primary = cols.FieldNames()
    46  		table.Keys.Add(builder.PrimaryKey().WithCols(cols.List()...))
    47  	}
    48  	if len(ks.Indexes) > 0 {
    49  		for name, fieldNames := range ks.Indexes {
    50  			cols, err := CheckFields(table, fieldNames...)
    51  			if err != nil {
    52  				panic(fmt.Errorf("%s, please check index def", err.Error()))
    53  			}
    54  			ks.Indexes[name] = cols.FieldNames()
    55  			table.Keys.Add(builder.Index(name).WithCols(cols.List()...))
    56  		}
    57  	}
    58  
    59  	if len(ks.UniqueIndexes) > 0 {
    60  		for name, fieldNames := range ks.UniqueIndexes {
    61  			cols, err := CheckFields(table, fieldNames...)
    62  			if err != nil {
    63  				panic(fmt.Errorf("%s, please check unique_index def", err.Error()))
    64  			}
    65  			ks.UniqueIndexes[name] = cols.FieldNames()
    66  			table.Keys.Add(builder.UniqueIndex(name).WithCols(cols.List()...))
    67  		}
    68  	}
    69  }
    70  
    71  func CheckFields(table *builder.Table, fieldNames ...string) (cols builder.Columns, err error) {
    72  	for _, fieldName := range fieldNames {
    73  		col := table.F(fieldName)
    74  		if col == nil {
    75  			err = fmt.Errorf("table %s has no field %s", table.Name, fieldName)
    76  			return
    77  		}
    78  		cols.Add(col)
    79  	}
    80  	return
    81  }
    82  
    83  func parseKeysFromDoc(doc string) *Keys {
    84  	ks := &Keys{}
    85  	matches := defRegexp.FindAllStringSubmatch(doc, -1)
    86  
    87  	for _, subMatch := range matches {
    88  		if len(subMatch) == 2 {
    89  			defs := defSplit(subMatch[1])
    90  
    91  			switch strings.ToLower(defs[0]) {
    92  			case "primary":
    93  				if len(defs) < 2 {
    94  					panic(fmt.Errorf("primary at lease 1 Field"))
    95  				}
    96  				ks.Primary = sqlx.FieldNames(defs[1:])
    97  			case "index":
    98  				if len(defs) < 3 {
    99  					panic(fmt.Errorf("index at lease 1 Field"))
   100  				}
   101  				if ks.Indexes == nil {
   102  					ks.Indexes = sqlx.Indexes{}
   103  				}
   104  				ks.Indexes[defs[1]] = sqlx.FieldNames(defs[2:])
   105  			case "unique_index":
   106  				if len(defs) < 3 {
   107  					panic(fmt.Errorf("unique Indexes at lease 1 Field"))
   108  				}
   109  				if ks.UniqueIndexes == nil {
   110  					ks.UniqueIndexes = sqlx.Indexes{}
   111  				}
   112  				ks.UniqueIndexes[defs[1]] = sqlx.FieldNames(defs[2:])
   113  			}
   114  		}
   115  	}
   116  	return ks
   117  }
   118  
   119  func defSplit(def string) (defs []string) {
   120  	vs := strings.Split(def, " ")
   121  	for _, s := range vs {
   122  		if s != "" {
   123  			defs = append(defs, s)
   124  		}
   125  	}
   126  	return
   127  }
   128  
   129  func toDefaultTableName(name string) string {
   130  	return codegen.ToLowerSnakeCase("t_" + name)
   131  }
   132  
   133  func forEachStructField(structType *types.Struct, fn func(fieldVar *types.Var, columnName string, tpe string)) {
   134  	for i := 0; i < structType.NumFields(); i++ {
   135  		field := structType.Field(i)
   136  		tag := structType.Tag(i)
   137  		if field.Exported() {
   138  			structTag := reflect.StructTag(tag)
   139  			fieldName, exists := structTag.Lookup("db")
   140  			if exists {
   141  				if fieldName != "-" {
   142  					fn(field, fieldName, structTag.Get("sql"))
   143  				}
   144  			} else if field.Anonymous() {
   145  				if nextStructType, ok := field.Type().Underlying().(*types.Struct); ok {
   146  					forEachStructField(nextStructType, fn)
   147  				}
   148  				continue
   149  			}
   150  		}
   151  	}
   152  }
   153  
   154  func T() *Template {
   155  	return &Template{}
   156  }
   157  
   158  type Template struct {
   159  	tpl     string
   160  	funcMap template.FuncMap
   161  }
   162  
   163  func (t Template) Funcs(funcMap template.FuncMap) *Template {
   164  	t.funcMap = funcMap
   165  	return &t
   166  }
   167  
   168  func (t Template) Parse(tpl string) *Template {
   169  	t.tpl = tpl
   170  	return &t
   171  }
   172  
   173  func (t *Template) Execute(wr io.Writer, data interface{}) {
   174  	tpl, parseErr := template.New(uuid.New().String()).Funcs(t.funcMap).Parse(t.tpl)
   175  	if parseErr != nil {
   176  		panic(fmt.Sprintf("template Prase failded: %s", parseErr.Error()))
   177  	}
   178  	err := tpl.Execute(wr, data)
   179  	if err != nil {
   180  		panic(fmt.Sprintf("template Execute failded: %s", err.Error()))
   181  	}
   182  }
   183  
   184  func (t *Template) Render(data interface{}) string {
   185  	buf := new(bytes.Buffer)
   186  	t.Execute(buf, data)
   187  	return buf.String()
   188  }