github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/gen_method/model.go (about)

     1  package gen_method
     2  
     3  import (
     4  	"fmt"
     5  	"go/build"
     6  	"go/types"
     7  	"os"
     8  	"reflect"
     9  	"sort"
    10  	"strings"
    11  
    12  	"golang.org/x/tools/imports"
    13  
    14  	"github.com/artisanhe/tools/codegen"
    15  	"github.com/artisanhe/tools/codegen/loaderx"
    16  )
    17  
    18  type Field struct {
    19  	Name                string
    20  	Type                string
    21  	DbFieldName         string
    22  	IndexName           string
    23  	IndexNumber         int
    24  	IsEnable            bool
    25  	IsCreateTime        bool
    26  	IsSpecifyIndexOrder bool
    27  }
    28  
    29  type Model struct {
    30  	Pkg                 *types.Package
    31  	Name                string
    32  	Fields              []*Field
    33  	IsDbModel           bool
    34  	UniqueIndex         map[string][]Field
    35  	NormalIndex         map[string][]Field
    36  	PrimaryIndex        []Field
    37  	HasCreateTimeField  bool
    38  	HasUpdateTimeField  bool
    39  	HasEnabledField     bool
    40  	EnabledFieldType    string
    41  	CreateTimeFieldType string
    42  	UpdateTimeFieldType string
    43  	DbCreateTimeField   string
    44  	DbEnabledField      string
    45  	DbUpdateTimeField   string
    46  	FuncMapContent      map[string]string
    47  	Deps                []string
    48  }
    49  
    50  func (model *Model) collectInfoFromStructType(tpeStruct *types.Struct) {
    51  	for i := 0; i < tpeStruct.NumFields(); i++ {
    52  		field := tpeStruct.Field(i)
    53  		tag := reflect.StructTag(tpeStruct.Tag(i))
    54  
    55  		sqlSettings := ParseTagSetting(tag.Get("sql"))
    56  		gormSettings := ParseTagSetting(tag.Get("gorm"))
    57  
    58  		if len(gormSettings) != 0 && len(sqlSettings) != 0 {
    59  			if _, ok := gormSettings["-"]; ok {
    60  				continue
    61  			}
    62  			model.IsDbModel = true
    63  		} else {
    64  			continue
    65  		}
    66  
    67  		var tmpField = Field{}
    68  		if dbFieldName, ok := gormSettings["COLUMN"]; ok {
    69  			tmpField.DbFieldName = dbFieldName[0]
    70  		} else {
    71  			tmpField.DbFieldName = tmpField.Name
    72  		}
    73  		tmpField.Name = field.Name()
    74  
    75  		pkgPath, method := loaderx.GetPkgImportPathAndExpose(field.Type().String())
    76  		if pkgPath != "" {
    77  			pkg, _ := build.Import(pkgPath, "", build.ImportComment)
    78  			tmpField.Type = fmt.Sprintf("%s.%s", pkg.Name, method)
    79  		} else {
    80  			tmpField.Type = method
    81  		}
    82  
    83  		if pkgPath != "" {
    84  			model.Deps = append(model.Deps, pkgPath)
    85  		}
    86  
    87  		model.Fields = append(model.Fields, &tmpField)
    88  		if tmpField.Name == "Enabled" || tmpField.DbFieldName == "F_enabled" {
    89  			// enabled field don't join into index slice
    90  			tmpField.IsEnable = true
    91  			model.HasEnabledField = true
    92  			model.EnabledFieldType = tmpField.Type
    93  			model.DbEnabledField = tmpField.DbFieldName
    94  			continue
    95  		} else if tmpField.Name == "UpdateTime" || tmpField.DbFieldName == "F_update_time" {
    96  			model.HasUpdateTimeField = true
    97  			model.UpdateTimeFieldType = tmpField.Type
    98  			model.DbUpdateTimeField = tmpField.DbFieldName
    99  		} else if tmpField.Name == "CreateTime" || tmpField.DbFieldName == "F_create_time" {
   100  			tmpField.IsCreateTime = true
   101  			model.HasCreateTimeField = true
   102  			model.CreateTimeFieldType = tmpField.Type
   103  			model.DbCreateTimeField = tmpField.DbFieldName
   104  		}
   105  
   106  		if _, ok := gormSettings["PRIMARY_KEY"]; ok {
   107  			model.PrimaryIndex = append(model.PrimaryIndex, tmpField)
   108  		}
   109  		if indexName, ok := sqlSettings["INDEX"]; ok && len(indexName) > 0 {
   110  			for _, index := range indexName {
   111  				tmpField.IndexName, tmpField.IndexNumber, tmpField.IsSpecifyIndexOrder = ParseIndex(index)
   112  				if _, ok := model.NormalIndex[tmpField.IndexName]; ok {
   113  					model.NormalIndex[tmpField.IndexName] = append(
   114  						model.NormalIndex[tmpField.IndexName],
   115  						tmpField)
   116  				} else {
   117  					model.NormalIndex[tmpField.IndexName] = []Field{tmpField}
   118  				}
   119  			}
   120  		}
   121  		if indexName, ok := sqlSettings["UNIQUE_INDEX"]; ok && len(indexName) > 0 {
   122  			for _, index := range indexName {
   123  				tmpField.IndexName, tmpField.IndexNumber, tmpField.IsSpecifyIndexOrder = ParseIndex(index)
   124  				if _, ok := model.UniqueIndex[tmpField.IndexName]; ok {
   125  					model.UniqueIndex[tmpField.IndexName] = append(
   126  						model.UniqueIndex[tmpField.IndexName],
   127  						tmpField)
   128  				} else {
   129  					model.UniqueIndex[tmpField.IndexName] = []Field{tmpField}
   130  				}
   131  			}
   132  		}
   133  	}
   134  }
   135  
   136  func (model *Model) Output(pkgName string, ignoreCreateTableNameFunc bool) {
   137  	if err := genTableNameFunc(model, pkgName, "git.chinawayltd.com/golib", ignoreCreateTableNameFunc); err != nil {
   138  		fmt.Printf("%s\n", err.Error())
   139  		os.Exit(1)
   140  	}
   141  
   142  	if err := genCreateFunc(model); err != nil {
   143  		fmt.Printf("%s\n", err.Error())
   144  		os.Exit(1)
   145  	}
   146  
   147  	model.genDBFieldFunc()
   148  	model.genCodeByNormalIndex()
   149  	model.genCodeByUniqueIndex()
   150  	model.genCodeByPrimaryKeyIndex()
   151  
   152  	if err := genFetchListFunc(model); err != nil {
   153  		fmt.Printf("%s\n", err.Error())
   154  		os.Exit(1)
   155  	}
   156  
   157  	model.GenerateFile()
   158  }
   159  
   160  func (model *Model) GenerateFile() {
   161  	var funcNameList []string
   162  
   163  	// first part of file
   164  	var tableName = "tableName"
   165  	for key := range model.FuncMapContent {
   166  		if key == tableName {
   167  			continue
   168  		}
   169  		funcNameList = append(funcNameList, key)
   170  	}
   171  
   172  	contents := []string{
   173  		model.FuncMapContent[tableName],
   174  	}
   175  
   176  	sort.Strings(funcNameList)
   177  
   178  	for _, funcName := range funcNameList {
   179  		contents = append(contents, model.FuncMapContent[funcName])
   180  	}
   181  
   182  	// p, _ := build.Import(model.Pkg.Path(), "", build.FindOnly)
   183  	// cwd, _ := os.Getwd()
   184  	// path, _ := filepath.Rel(cwd, p.Dir)
   185  
   186  	path := "."
   187  	filename := path + "/" + replaceUpperWithLowerAndUnderscore(model.Name) + ".go"
   188  	content := strings.Join(contents, "\n\n")
   189  	bytes, err := imports.Process(filename, []byte(content), nil)
   190  	if err != nil {
   191  		panic(err)
   192  	} else {
   193  		content = string(bytes)
   194  	}
   195  	codegen.WriteFile(codegen.GeneratedSuffix(filename), content)
   196  }
   197  
   198  func (model *Model) genCodeByUniqueIndex() {
   199  	for _, fieldList := range model.UniqueIndex {
   200  		sortFieldList := model.sortFieldsByIndexNumber(fieldList)
   201  		model.handleGenCodeForUniqueIndex(sortFieldList)
   202  	}
   203  }
   204  
   205  func (model *Model) genCodeByPrimaryKeyIndex() {
   206  	if len(model.PrimaryIndex) > 0 {
   207  		model.handleGenCodeForUniqueIndex(model.PrimaryIndex)
   208  	}
   209  }
   210  
   211  func (model *Model) genCodeByNormalIndex() {
   212  	for _, fieldList := range model.NormalIndex {
   213  		sortFieldList := model.sortFieldsByIndexNumber(fieldList)
   214  		baseInfoGenCode := fetchBaseInfoOfGenFuncForNormalIndex(sortFieldList)
   215  		if err := genFetchFuncByNormalIndex(model, baseInfoGenCode); err != nil {
   216  			fmt.Printf("%s\n", err.Error())
   217  			os.Exit(1)
   218  		}
   219  		model.handleGenFetchCodeBySubIndex(sortFieldList)
   220  	}
   221  }
   222  
   223  func (model *Model) genDBFieldFunc() {
   224  	var structFields, structFieldsAndValue, noUniqueIndexFields, dBAndStructField, structAndDbField []string
   225  	var primaryIndexMap = make(map[string]string)
   226  	for _, field := range model.PrimaryIndex {
   227  		primaryIndexMap[field.Name] = ""
   228  	}
   229  
   230  	for _, field := range model.Fields {
   231  		structFields = append(structFields, fmt.Sprintf("%s string", field.Name))
   232  		structFieldsAndValue = append(structFieldsAndValue, fmt.Sprintf("%s:\"%s\"", field.Name, field.DbFieldName))
   233  		dBAndStructField = append(dBAndStructField, fmt.Sprintf("\"%s\" : \"%s\",", field.DbFieldName, field.Name))
   234  		structAndDbField = append(structAndDbField, fmt.Sprintf("\"%s\" : \"%s\",", field.Name, field.DbFieldName))
   235  		if _, ok := model.UniqueIndex[field.IndexName]; !ok {
   236  			if _, ok := primaryIndexMap[field.Name]; !ok && !field.IsEnable && !field.IsCreateTime {
   237  				noUniqueIndexFields = append(noUniqueIndexFields, fmt.Sprintf("\"%s\"", field.DbFieldName))
   238  			}
   239  		}
   240  	}
   241  
   242  	if err := genDBFiledFunc(model, structFields, structFieldsAndValue, noUniqueIndexFields, dBAndStructField,
   243  		structAndDbField); err != nil {
   244  		fmt.Printf("%s\n", err.Error())
   245  		os.Exit(1)
   246  	}
   247  }
   248  
   249  func (model *Model) sortFieldsByIndexNumber(fields []Field) []Field {
   250  	if len(fields) == 0 && len(fields) == 1 {
   251  		return fields
   252  	}
   253  
   254  	if !IsSpecifyIndexSequence(fields) {
   255  		return fields
   256  	}
   257  
   258  	var markIndexNubmer = make(map[string]string)
   259  	var sortFieldSlice = []Field{}
   260  	for _, field := range fields {
   261  		if fieldName, ok := markIndexNubmer[fmt.Sprintf("%d", field.IndexNumber)]; ok {
   262  			fmt.Printf("Field[%s] and Field[%s] has same index number[%d] in Mode[%s]", fieldName, field.Name,
   263  				field.IndexNumber, model.Name)
   264  			os.Exit(1)
   265  		}
   266  		if field.IndexNumber >= 0 {
   267  			tmpSlice := make([]Field, field.IndexNumber+1)
   268  			tmpSlice[field.IndexNumber] = field
   269  			if len(sortFieldSlice) > field.IndexNumber+1 {
   270  				sortFieldSlice = append(tmpSlice, sortFieldSlice[field.IndexNumber+1:]...)
   271  			} else if len(sortFieldSlice) < field.IndexNumber+1 {
   272  				sortFieldSlice = append(sortFieldSlice, tmpSlice[len(sortFieldSlice):]...)
   273  			} else {
   274  				fmt.Printf("Field[%s] wrong index sequence, may be same index number.\n", field.Name)
   275  				os.Exit(1)
   276  			}
   277  		}
   278  	}
   279  
   280  	var notEmptySlice = []Field{}
   281  	for index, field := range sortFieldSlice {
   282  		if len(field.Name) > 0 {
   283  			notEmptySlice = append(notEmptySlice, sortFieldSlice[index])
   284  			//sortFieldSlice = append(sortFieldSlice[:index], sortFieldSlice[index+1:]...)
   285  		}
   286  	}
   287  
   288  	return notEmptySlice
   289  }
   290  
   291  func (model *Model) genBatchFetchFuncBySingleIndex(field Field) {
   292  	if err := genBatchFetchFunc(model, field.Name, field.DbFieldName, field.Type); err != nil {
   293  		fmt.Printf("%s\n", err.Error())
   294  		os.Exit(1)
   295  	}
   296  }
   297  
   298  func (model *Model) handleGenFetchCodeBySubIndex(fieldList []Field) {
   299  	if len(fieldList) == 1 {
   300  		model.genBatchFetchFuncBySingleIndex(fieldList[0])
   301  	} else if len(fieldList) > 1 {
   302  		// [x, y, z, e] Split to [x, y, z], [x, y], [x]
   303  		for i := 1; i < len(fieldList); i++ {
   304  			subSortFieldSlice := fieldList[:len(fieldList)-i]
   305  			baseInfoGenCode := fetchBaseInfoOfGenFuncForNormalIndex(subSortFieldSlice)
   306  			if err := genFetchFuncByNormalIndex(model, baseInfoGenCode); err != nil {
   307  				fmt.Printf("%s\n", err.Error())
   308  				os.Exit(1)
   309  			}
   310  
   311  			if len(subSortFieldSlice) == 1 {
   312  				model.genBatchFetchFuncBySingleIndex(subSortFieldSlice[0])
   313  			}
   314  
   315  		}
   316  	}
   317  }
   318  
   319  func (model *Model) handleGenCodeForUniqueIndex(sortFieldList []Field) {
   320  	baseInfoGenCode := fetchBaseInfoOfGenFuncForUniqueIndex(model, sortFieldList)
   321  	if err := genFetchFuncByUniqueIndex(model, baseInfoGenCode); err != nil {
   322  		fmt.Printf("%s\n", err.Error())
   323  		os.Exit(1)
   324  	}
   325  
   326  	model.handleGenFetchCodeBySubIndex(sortFieldList)
   327  
   328  	if err := genFetchForUpdateFuncByUniqueIndex(model, baseInfoGenCode); err != nil {
   329  		fmt.Printf("%s\n", err.Error())
   330  		os.Exit(1)
   331  	}
   332  
   333  	if err := genUpdateWithStructFuncByUniqueIndex(model, baseInfoGenCode); err != nil {
   334  		fmt.Printf("%s\n", err.Error())
   335  		os.Exit(1)
   336  	}
   337  
   338  	if err := genUpdateWithMapFuncByUniqueIndex(model, baseInfoGenCode); err != nil {
   339  		fmt.Printf("%s\n", err.Error())
   340  		os.Exit(1)
   341  	}
   342  
   343  	if err := genSoftDeleteFuncByUniqueIndex(model, baseInfoGenCode); err != nil {
   344  		fmt.Printf("%s\n", err.Error())
   345  		os.Exit(1)
   346  	}
   347  	if err := genPhysicsDeleteFuncByUniqueIndex(model, baseInfoGenCode); err != nil {
   348  		fmt.Printf("%s\n", err.Error())
   349  		os.Exit(1)
   350  	}
   351  }