github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/sqlx/gen/model.go (about) 1 package gen 2 3 import ( 4 "fmt" 5 "go/types" 6 "io" 7 "strings" 8 "text/template" 9 10 "golang.org/x/tools/go/loader" 11 12 "github.com/artisanhe/tools/codegen" 13 "github.com/artisanhe/tools/codegen/loaderx" 14 "github.com/artisanhe/tools/sqlx/builder" 15 ) 16 17 func NewModel(prog *loader.Program, typeName *types.TypeName, comments string, cfg *Config) *Model { 18 m := Model{} 19 m.Config = cfg 20 m.TypeName = typeName 21 22 m.Table = builder.T(nil, cfg.TableName) 23 24 p := prog.Package(typeName.Pkg().Path()) 25 26 forEachStructField(typeName.Type().Underlying().Underlying().(*types.Struct), func(structVal *types.Var, columnName string, tpe string) { 27 col := builder.Col(m.Table, columnName).Field(structVal.Name()).Type(tpe) 28 29 for id, o := range p.Defs { 30 if o == structVal { 31 doc := loaderx.CommentsOf(prog.Fset, id, p.Files...) 32 col.Comment = strings.Split(doc, "\n")[0] 33 } 34 } 35 36 m.AddColumn(col, structVal) 37 }) 38 39 m.HasSoftDelete = m.Table.F(m.FieldSoftDelete) != nil 40 m.HasCreatedAt = m.Table.F(m.FieldCreatedAt) != nil 41 m.HasUpdatedAt = m.Table.F(m.FieldUpdatedAt) != nil 42 43 m.Keys = parseKeysFromDoc(comments) 44 if m.HasSoftDelete { 45 m.Keys.PatchUniqueIndexesWithSoftDelete(m.FieldSoftDelete) 46 } 47 m.Keys.Bind(m.Table) 48 49 if autoIncrementCol := m.Table.AutoIncrement(); autoIncrementCol != nil { 50 m.HasAutoIncrement = true 51 m.FieldAutoIncrement = autoIncrementCol.FieldName 52 } 53 54 m.Importer = &codegen.Importer{} 55 56 m.Template = T().Funcs(template.FuncMap{ 57 "use": m.Importer.Use, 58 "dump": m.Importer.Sdump, 59 "var": codegen.ToLowerCamelCase, 60 }) 61 62 return &m 63 } 64 65 type Model struct { 66 *types.TypeName 67 *codegen.Importer 68 *Config 69 *Template 70 *Keys 71 *builder.Table 72 Fields map[string]*types.Var 73 FieldAutoIncrement string 74 HasSoftDelete bool 75 HasCreatedAt bool 76 HasUpdatedAt bool 77 HasAutoIncrement bool 78 } 79 80 func (m *Model) AddColumn(col *builder.Column, tpe *types.Var) { 81 m.Table.Columns.Add(col) 82 if m.Fields == nil { 83 m.Fields = map[string]*types.Var{} 84 } 85 m.Fields[col.FieldName] = tpe 86 } 87 88 func (m *Model) Render() string { 89 blocks := strings.Join( 90 []string{ 91 m.dataAndTable(), 92 m.interfaces(), 93 m.methodsForCRUD(), 94 m.methodsForList(), 95 }, 96 "\n", 97 ) 98 99 return fmt.Sprintf(` 100 package %s 101 102 %s 103 104 %s 105 `, 106 m.TypeName.Pkg().Name(), 107 m.Importer.String(), 108 blocks, 109 ) 110 } 111 112 func (m *Model) ParseTo(writer io.Writer, tpl string) { 113 m.Template.Parse(tpl).Execute(writer, m) 114 }