github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/sqlx/gen/tag_generator.go (about)

     1  package gen
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/parser"
     7  	"go/token"
     8  	"go/types"
     9  	"sort"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/sirupsen/logrus"
    14  	"golang.org/x/tools/go/loader"
    15  
    16  	"github.com/johnnyeven/libtools/codegen"
    17  	"github.com/johnnyeven/libtools/codegen/loaderx"
    18  )
    19  
// TagGenerator rewrites the struct tags of selected structs in one Go package
// so that every named field carries `db`, `json` and `sql` entries.
//
// Load must be called before Pick (Pick reads the program Load stores), and
// Output returns what Pick collected.
type TagGenerator struct {
	// StructNames lists the type names whose struct tags Pick rewrites.
	StructNames []string

	// pkgImportPath is the import path of the package loaded by Load.
	pkgImportPath string

	// WithDefaults makes generated sql tags end with a DEFAULT clause.
	WithDefaults bool

	// program is the loaded, type-checked program stored by Load.
	program *loader.Program

	// outputs collects the re-rendered source files produced by Pick.
	outputs codegen.Outputs
}
    27  
    28  func (g *TagGenerator) Load(cwd string) {
    29  	ldr := loader.Config{
    30  		AllowErrors: true,
    31  		ParserMode:  parser.ParseComments,
    32  	}
    33  
    34  	pkgImportPath := codegen.GetPackageImportPath(cwd)
    35  	ldr.Import(pkgImportPath)
    36  
    37  	p, err := ldr.Load()
    38  	if err != nil {
    39  		panic(err)
    40  	}
    41  
    42  	g.pkgImportPath = pkgImportPath
    43  	g.program = p
    44  	g.outputs = codegen.Outputs{}
    45  }
    46  
    47  func (g *TagGenerator) Pick() {
    48  	for pkg, pkgInfo := range g.program.AllPackages {
    49  		if pkg.Path() != g.pkgImportPath {
    50  			continue
    51  		}
    52  		for ident, obj := range pkgInfo.Defs {
    53  			if typeName, ok := obj.(*types.TypeName); ok {
    54  				for _, structName := range g.StructNames {
    55  					if typeName.Name() == structName {
    56  						if typeStruct, ok := typeName.Type().Underlying().(*types.Struct); ok {
    57  							modifyTag(ident.Obj.Decl.(*ast.TypeSpec).Type.(*ast.StructType), typeStruct, g.WithDefaults)
    58  							file := loaderx.FileOf(ident, pkgInfo.Files...)
    59  							g.outputs.Add(g.program.Fset.Position(file.Pos()).Filename, loaderx.StringifyAst(g.program.Fset, file))
    60  						}
    61  					}
    62  				}
    63  			}
    64  		}
    65  	}
    66  }
    67  
    68  func toTags(tags map[string]string) (tag string) {
    69  	names := make([]string, 0)
    70  	for name := range tags {
    71  		names = append(names, name)
    72  	}
    73  	sort.Strings(names)
    74  	for _, name := range names {
    75  		tag += fmt.Sprintf("%s:%s ", name, strconv.Quote(tags[name]))
    76  	}
    77  	return strings.TrimSpace(tag)
    78  }
    79  
    80  func getTags(tag string) (tags map[string]string) {
    81  	tags = make(map[string]string)
    82  	for tag != "" {
    83  		i := 0
    84  		for i < len(tag) && tag[i] == ' ' {
    85  			i++
    86  		}
    87  		tag = tag[i:]
    88  		if tag == "" {
    89  			break
    90  		}
    91  		i = 0
    92  		for i < len(tag) && tag[i] > ' ' && tag[i] != ':' && tag[i] != '"' && tag[i] != 0x7f {
    93  			i++
    94  		}
    95  		if i == 0 || i+1 >= len(tag) || tag[i] != ':' || tag[i+1] != '"' {
    96  			break
    97  		}
    98  		name := string(tag[:i])
    99  		tag = tag[i+1:]
   100  
   101  		// Scan quoted string to find value.
   102  		i = 1
   103  		for i < len(tag) && tag[i] != '"' {
   104  			if tag[i] == '\\' {
   105  				i++
   106  			}
   107  			i++
   108  		}
   109  		if i >= len(tag) {
   110  			break
   111  		}
   112  		qvalue := string(tag[:i+1])
   113  		tag = tag[i+1:]
   114  
   115  		value, err := strconv.Unquote(qvalue)
   116  		if err != nil {
   117  			break
   118  		}
   119  		tags[name] = value
   120  
   121  	}
   122  	return
   123  }
   124  
   125  func modifyTag(structType *ast.StructType, typeStruct *types.Struct, withDefaults bool) {
   126  	for i := 0; i < typeStruct.NumFields(); i++ {
   127  		f := typeStruct.Field(i)
   128  		if f.Anonymous() {
   129  			continue
   130  		}
   131  		tags := getTags(typeStruct.Tag(i))
   132  		astField := structType.Fields.List[i]
   133  
   134  		if tags["db"] == "" {
   135  			tags["db"] = fmt.Sprintf("F_%s", codegen.ToLowerSnakeCase(f.Name()))
   136  		}
   137  		if tags["json"] == "" {
   138  			tags["json"] = codegen.ToLowerCamelCase(f.Name())
   139  			switch f.Type().(type) {
   140  			case *types.Basic:
   141  				if f.Type().(*types.Basic).Kind() == types.Uint64 {
   142  					tags["json"] = tags["json"] + ",string"
   143  				}
   144  			}
   145  		}
   146  		if tags["sql"] == "" {
   147  			tpe := f.Type()
   148  			switch codegen.DeVendor(tpe.String()) {
   149  			case "github.com/johnnyeven/libtools/timelib.MySQLDatetime":
   150  				tags["sql"] = "datetime NOT NULL"
   151  			case "github.com/johnnyeven/libtools/timelib.MySQLTimestamp":
   152  				tags["sql"] = toSqlFromKind(types.Typ[types.Int64].Kind(), withDefaults)
   153  			default:
   154  				tpe, err := IndirectType(tpe)
   155  				if err != nil {
   156  					logrus.Warnf("%s, make sure type of Field `%s` have sql.Valuer and sql.Scanner interface", err, f.Name())
   157  				}
   158  				switch tpe.(type) {
   159  				case *types.Basic:
   160  					tags["sql"] = toSqlFromKind(tpe.(*types.Basic).Kind(), withDefaults)
   161  				default:
   162  					tags["sql"] = WithDefaults("varchar(255) NOT NULL", withDefaults, "")
   163  				}
   164  			}
   165  		}
   166  		astField.Tag = &ast.BasicLit{Kind: token.STRING, Value: "`" + toTags(tags) + "`"}
   167  	}
   168  }
   169  
   170  func IndirectType(tpe types.Type) (types.Type, error) {
   171  	switch tpe.(type) {
   172  	case *types.Basic:
   173  		return tpe.(*types.Basic), nil
   174  	case *types.Struct, *types.Slice, *types.Array, *types.Map:
   175  		return nil, fmt.Errorf("unsupport type %s", tpe)
   176  	case *types.Pointer:
   177  		return IndirectType(tpe.(*types.Pointer).Elem())
   178  	default:
   179  		return IndirectType(tpe.Underlying())
   180  	}
   181  }
   182  
   183  func WithDefaults(dataType string, withDefaults bool, defaultValue string) string {
   184  	if withDefaults {
   185  		return dataType + fmt.Sprintf(" DEFAULT '%s'", defaultValue)
   186  	}
   187  	return dataType
   188  }
   189  
   190  func toSqlFromKind(kind types.BasicKind, withDefaults bool) string {
   191  	switch kind {
   192  	case types.Bool:
   193  		return WithDefaults("tinyint(1) NOT NULL", withDefaults, "0")
   194  	case types.Int8:
   195  		return WithDefaults("tinyint NOT NULL", withDefaults, "0")
   196  	case types.Int16:
   197  		return WithDefaults("smallint NOT NULL", withDefaults, "0")
   198  	case types.Int, types.Int32:
   199  		return WithDefaults("int NOT NULL", withDefaults, "0")
   200  	case types.Int64:
   201  		return WithDefaults("bigint NOT NULL", withDefaults, "0")
   202  	case types.Uint8:
   203  		return WithDefaults("tinyint unsigned NOT NULL", withDefaults, "0")
   204  	case types.Uint16:
   205  		return WithDefaults("smallint unsigned NOT NULL", withDefaults, "0")
   206  	case types.Uint, types.Uint32:
   207  		return WithDefaults("int unsigned NOT NULL", withDefaults, "0")
   208  	case types.Uint64:
   209  		return WithDefaults("bigint unsigned NOT NULL", withDefaults, "0")
   210  	case types.Float32:
   211  		return WithDefaults("float NOT NULL", withDefaults, "0")
   212  	case types.Float64:
   213  		return WithDefaults("double NOT NULL", withDefaults, "0")
   214  	default:
   215  		// string
   216  		return WithDefaults("varchar(255) NOT NULL", withDefaults, "")
   217  	}
   218  }
   219  
   220  func (g *TagGenerator) Output(cwd string) codegen.Outputs {
   221  	return g.outputs
   222  }