github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/sqlx/gen/tag_generator.go (about) 1 package gen 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/parser" 7 "go/token" 8 "go/types" 9 "sort" 10 "strconv" 11 "strings" 12 13 "github.com/sirupsen/logrus" 14 "golang.org/x/tools/go/loader" 15 16 "github.com/johnnyeven/libtools/codegen" 17 "github.com/johnnyeven/libtools/codegen/loaderx" 18 ) 19 20 type TagGenerator struct { 21 StructNames []string 22 pkgImportPath string 23 WithDefaults bool 24 program *loader.Program 25 outputs codegen.Outputs 26 } 27 28 func (g *TagGenerator) Load(cwd string) { 29 ldr := loader.Config{ 30 AllowErrors: true, 31 ParserMode: parser.ParseComments, 32 } 33 34 pkgImportPath := codegen.GetPackageImportPath(cwd) 35 ldr.Import(pkgImportPath) 36 37 p, err := ldr.Load() 38 if err != nil { 39 panic(err) 40 } 41 42 g.pkgImportPath = pkgImportPath 43 g.program = p 44 g.outputs = codegen.Outputs{} 45 } 46 47 func (g *TagGenerator) Pick() { 48 for pkg, pkgInfo := range g.program.AllPackages { 49 if pkg.Path() != g.pkgImportPath { 50 continue 51 } 52 for ident, obj := range pkgInfo.Defs { 53 if typeName, ok := obj.(*types.TypeName); ok { 54 for _, structName := range g.StructNames { 55 if typeName.Name() == structName { 56 if typeStruct, ok := typeName.Type().Underlying().(*types.Struct); ok { 57 modifyTag(ident.Obj.Decl.(*ast.TypeSpec).Type.(*ast.StructType), typeStruct, g.WithDefaults) 58 file := loaderx.FileOf(ident, pkgInfo.Files...) 59 g.outputs.Add(g.program.Fset.Position(file.Pos()).Filename, loaderx.StringifyAst(g.program.Fset, file)) 60 } 61 } 62 } 63 } 64 } 65 } 66 } 67 68 func toTags(tags map[string]string) (tag string) { 69 names := make([]string, 0) 70 for name := range tags { 71 names = append(names, name) 72 } 73 sort.Strings(names) 74 for _, name := range names { 75 tag += fmt.Sprintf("%s:%s ", name, strconv.Quote(tags[name])) 76 } 77 return strings.TrimSpace(tag) 78 } 79 80 func getTags(tag string) (tags map[string]string) { 81 tags = make(map[string]string) 82 for tag != "" { 83 i := 0 84 for i < len(tag) && tag[i] == ' ' { 85 i++ 86 } 87 tag = tag[i:] 88 if tag == "" { 89 break 90 } 91 i = 0 92 for i < len(tag) && tag[i] > ' ' && tag[i] != ':' && tag[i] != '"' && tag[i] != 0x7f { 93 i++ 94 } 95 if i == 0 || i+1 >= len(tag) || tag[i] != ':' || tag[i+1] != '"' { 96 break 97 } 98 name := string(tag[:i]) 99 tag = tag[i+1:] 100 101 // Scan quoted string to find value. 102 i = 1 103 for i < len(tag) && tag[i] != '"' { 104 if tag[i] == '\\' { 105 i++ 106 } 107 i++ 108 } 109 if i >= len(tag) { 110 break 111 } 112 qvalue := string(tag[:i+1]) 113 tag = tag[i+1:] 114 115 value, err := strconv.Unquote(qvalue) 116 if err != nil { 117 break 118 } 119 tags[name] = value 120 121 } 122 return 123 } 124 125 func modifyTag(structType *ast.StructType, typeStruct *types.Struct, withDefaults bool) { 126 for i := 0; i < typeStruct.NumFields(); i++ { 127 f := typeStruct.Field(i) 128 if f.Anonymous() { 129 continue 130 } 131 tags := getTags(typeStruct.Tag(i)) 132 astField := structType.Fields.List[i] 133 134 if tags["db"] == "" { 135 tags["db"] = fmt.Sprintf("F_%s", codegen.ToLowerSnakeCase(f.Name())) 136 } 137 if tags["json"] == "" { 138 tags["json"] = codegen.ToLowerCamelCase(f.Name()) 139 switch f.Type().(type) { 140 case *types.Basic: 141 if f.Type().(*types.Basic).Kind() == types.Uint64 { 142 tags["json"] = tags["json"] + ",string" 143 } 144 } 145 } 146 if tags["sql"] == "" { 147 tpe := f.Type() 148 switch codegen.DeVendor(tpe.String()) { 149 case "github.com/johnnyeven/libtools/timelib.MySQLDatetime": 150 tags["sql"] = "datetime NOT NULL" 151 case "github.com/johnnyeven/libtools/timelib.MySQLTimestamp": 152 tags["sql"] = toSqlFromKind(types.Typ[types.Int64].Kind(), withDefaults) 153 default: 154 tpe, err := IndirectType(tpe) 155 if err != nil { 156 logrus.Warnf("%s, make sure type of Field `%s` have sql.Valuer and sql.Scanner interface", err, f.Name()) 157 } 158 switch tpe.(type) { 159 case *types.Basic: 160 tags["sql"] = toSqlFromKind(tpe.(*types.Basic).Kind(), withDefaults) 161 default: 162 tags["sql"] = WithDefaults("varchar(255) NOT NULL", withDefaults, "") 163 } 164 } 165 } 166 astField.Tag = &ast.BasicLit{Kind: token.STRING, Value: "`" + toTags(tags) + "`"} 167 } 168 } 169 170 func IndirectType(tpe types.Type) (types.Type, error) { 171 switch tpe.(type) { 172 case *types.Basic: 173 return tpe.(*types.Basic), nil 174 case *types.Struct, *types.Slice, *types.Array, *types.Map: 175 return nil, fmt.Errorf("unsupport type %s", tpe) 176 case *types.Pointer: 177 return IndirectType(tpe.(*types.Pointer).Elem()) 178 default: 179 return IndirectType(tpe.Underlying()) 180 } 181 } 182 183 func WithDefaults(dataType string, withDefaults bool, defaultValue string) string { 184 if withDefaults { 185 return dataType + fmt.Sprintf(" DEFAULT '%s'", defaultValue) 186 } 187 return dataType 188 } 189 190 func toSqlFromKind(kind types.BasicKind, withDefaults bool) string { 191 switch kind { 192 case types.Bool: 193 return WithDefaults("tinyint(1) NOT NULL", withDefaults, "0") 194 case types.Int8: 195 return WithDefaults("tinyint NOT NULL", withDefaults, "0") 196 case types.Int16: 197 return WithDefaults("smallint NOT NULL", withDefaults, "0") 198 case types.Int, types.Int32: 199 return WithDefaults("int NOT NULL", withDefaults, "0") 200 case types.Int64: 201 return WithDefaults("bigint NOT NULL", withDefaults, "0") 202 case types.Uint8: 203 return WithDefaults("tinyint unsigned NOT NULL", withDefaults, "0") 204 case types.Uint16: 205 return WithDefaults("smallint unsigned NOT NULL", withDefaults, "0") 206 case types.Uint, types.Uint32: 207 return WithDefaults("int unsigned NOT NULL", withDefaults, "0") 208 case types.Uint64: 209 return WithDefaults("bigint unsigned NOT NULL", withDefaults, "0") 210 case types.Float32: 211 return WithDefaults("float NOT NULL", withDefaults, "0") 212 case types.Float64: 213 return WithDefaults("double NOT NULL", withDefaults, "0") 214 default: 215 // string 216 return WithDefaults("varchar(255) NOT NULL", withDefaults, "") 217 } 218 } 219 220 func (g *TagGenerator) Output(cwd string) codegen.Outputs { 221 return g.outputs 222 }