github.com/team-ide/go-dialect@v1.9.20/gen_type_test.go (about)

     1  package main
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"github.com/tealeg/xlsx"
     7  	"github.com/team-ide/go-dialect/dialect"
     8  	"os"
     9  	"strings"
    10  	"testing"
    11  )
    12  
    13  func TestTypeParseGen(t *testing.T) {
    14  	err := dataTypeParse(`数据库类型.xlsx`, "dialect/mapping.column.type.go")
    15  	if err != nil {
    16  		panic(err)
    17  	}
    18  }
    19  
// databaseModel holds the column types read from one sheet of the
// column-type spreadsheet.
type databaseModel struct {
	// Name is the sheet name; dataTypeParse uses it to pick the generated
	// list variable name.
	Name string
	// DataTypes are the parsed data rows of the sheet, in sheet order.
	DataTypes []*dialect.ColumnTypeInfo
}
    24  
    25  func dataTypeParse(path string, outPath string) (err error) {
    26  	xlsxFForRead, err := xlsx.OpenFile(path)
    27  	if err != nil {
    28  		err = errors.New("excel [" + path + "] open error, " + err.Error())
    29  		return
    30  	}
    31  	sheets := xlsxFForRead.Sheets
    32  
    33  	var databases []*databaseModel
    34  
    35  	for _, sheet := range sheets {
    36  		database := &databaseModel{}
    37  		database.Name = sheet.Name
    38  
    39  		var titles []string
    40  
    41  		var RowMergeEnd = -1
    42  		var RowMergeCell = -1
    43  		var RowMergeValue string
    44  		for rowIndex, row := range sheet.Rows {
    45  
    46  			if rowIndex == 0 {
    47  				for _, cell := range row.Cells {
    48  					title := cell.Value
    49  					title = strings.TrimSpace(title)
    50  					titles = append(titles, title)
    51  				}
    52  				continue
    53  			}
    54  			var dataType = map[string]string{}
    55  			for cellIndex, cell := range row.Cells {
    56  				if cellIndex >= len(titles) {
    57  					break
    58  				}
    59  				title := titles[cellIndex]
    60  				if title == "" {
    61  					continue
    62  				}
    63  				value := cell.Value
    64  				value = strings.TrimSpace(value)
    65  				if cell.VMerge > 0 {
    66  					RowMergeCell = cellIndex
    67  					RowMergeEnd = rowIndex + cell.VMerge
    68  					RowMergeValue = value
    69  				}
    70  				if RowMergeCell == cellIndex {
    71  					if rowIndex <= RowMergeEnd {
    72  						value = RowMergeValue
    73  					} else {
    74  						RowMergeEnd = -1
    75  						RowMergeValue = ""
    76  					}
    77  				}
    78  				dataType[title] = value
    79  			}
    80  			if dataType["名称"] == "" {
    81  				continue
    82  			}
    83  			database.DataTypes = append(database.DataTypes, formatDataType(dataType))
    84  		}
    85  
    86  		databases = append(databases, database)
    87  	}
    88  
    89  	outFile, err := os.Create(outPath)
    90  	if err != nil {
    91  		return
    92  	}
    93  	_, err = outFile.WriteString(`package dialect
    94  
    95  import "strings"
    96  
    97  `)
    98  	if err != nil {
    99  		return
   100  	}
   101  	for _, one := range databases {
   102  		fmt.Println("-------- database [" + one.Name + "] start --------")
   103  
   104  		var code string
   105  		code += "// " + one.Name + " 数据库 字段类型" + "\n"
   106  		columnTypeListName := ""
   107  		var isMysql bool
   108  		var isShenTong bool
   109  		if strings.EqualFold(one.Name, "Mysql") {
   110  			columnTypeListName = "mysqlColumnTypeList"
   111  			isMysql = true
   112  		} else if strings.EqualFold(one.Name, "Oracle") {
   113  			columnTypeListName = "oracleColumnTypeList"
   114  		} else if strings.EqualFold(one.Name, "达梦") {
   115  			columnTypeListName = "dmColumnTypeList"
   116  		} else if strings.EqualFold(one.Name, "金仓") {
   117  			columnTypeListName = "kingBaseColumnTypeList"
   118  		} else if strings.EqualFold(one.Name, "神通") {
   119  			columnTypeListName = "shenTongColumnTypeList"
   120  			isShenTong = true
   121  		} else if strings.EqualFold(one.Name, "Sqlite") {
   122  			columnTypeListName = "sqliteColumnTypeList"
   123  		} else if strings.EqualFold(one.Name, "GBase") {
   124  			columnTypeListName = "gBaseColumnTypeList"
   125  		} else if strings.EqualFold(one.Name, "Postgresql") {
   126  			columnTypeListName = "postgresqlColumnTypeList"
   127  		} else if strings.EqualFold(one.Name, "DB2") {
   128  			columnTypeListName = "db2ColumnTypeList"
   129  		} else if strings.EqualFold(one.Name, "OpenGauss") {
   130  			columnTypeListName = "openGaussColumnTypeList"
   131  		}
   132  		code += "var " + columnTypeListName + " = []*ColumnTypeInfo{" + "\n"
   133  		for _, dataType := range one.DataTypes {
   134  			code += "\t" + "{"
   135  			code += "Name: `" + dataType.Name + "`, "
   136  			code += "Format: `" + dataType.Format + "`, "
   137  			if len(dataType.Matches) > 0 {
   138  				code += "Matches: []string{`" + strings.Join(dataType.Matches, "`, `") + "`}, "
   139  			}
   140  			if dataType.IsNumber {
   141  				code += "IsNumber: true, "
   142  			}
   143  			if dataType.IsInteger {
   144  				code += "IsInteger: true, "
   145  			}
   146  			if dataType.IsFloat {
   147  				code += "IsFloat: true, "
   148  			}
   149  			if dataType.IsBoolean {
   150  				code += "IsBoolean: true, "
   151  			}
   152  			if dataType.IsString {
   153  				code += "IsString: true, "
   154  			}
   155  			if dataType.IsBytes {
   156  				code += "IsBytes: true, "
   157  			}
   158  			if dataType.IsEnum {
   159  				code += "IsEnum: true, "
   160  			}
   161  			if dataType.IsDateTime {
   162  				code += "IsDateTime: true, "
   163  			}
   164  			if dataType.Comment != "" {
   165  				code += "Comment: `" + dataType.Comment + "`, "
   166  			}
   167  			var hasOtherMethod bool
   168  			if dataType.Name == "DATETIME" || dataType.Name == "TIMESTAMP" {
   169  				if isShenTong {
   170  				} else {
   171  					code = strings.TrimSuffix(code, " ")
   172  					hasOtherMethod = true
   173  					code += `
   174  		ColumnDefaultPack: func(param *ParamModel, column *ColumnModel) (columnDefaultPack string, err error) {
   175  			if strings.Contains(strings.ToLower(column.ColumnDefault), "current_timestamp") ||
   176  				strings.Contains(strings.ToLower(column.ColumnDefault), "0000-00-00 00:00:00") {
   177  				columnDefaultPack = "CURRENT_TIMESTAMP"
   178  			}
   179  `
   180  					if isMysql {
   181  						code += `
   182  			if strings.Contains(strings.ToLower(column.ColumnExtra), "on update current_timestamp") {
   183  				columnDefaultPack += " ON UPDATE CURRENT_TIMESTAMP"
   184  			}
   185  `
   186  					}
   187  					code += `
   188  			return
   189  		},
   190  `
   191  				}
   192  			} else if dataType.IsEnum {
   193  				if isMysql {
   194  					hasOtherMethod = true
   195  					code = strings.TrimSuffix(code, " ")
   196  					code += `
   197  		FullColumnByColumnType: func(columnType string, column *ColumnModel) (err error) {
   198  			if strings.Contains(columnType, "(") {
   199  				setStr := columnType[strings.Index(columnType, "(")+1 : strings.Index(columnType, ")")]
   200  				setStr = strings.ReplaceAll(setStr, "'", "")
   201  				column.ColumnEnums = strings.Split(setStr, ",")
   202  			}
   203  			return
   204  		},
   205  `
   206  				} else {
   207  				}
   208  			} else {
   209  
   210  			}
   211  			if hasOtherMethod {
   212  				code += "\t" + "}," + "\n"
   213  			} else {
   214  				code = code[0 : len(code)-2]
   215  				code += "}," + "\n"
   216  			}
   217  		}
   218  		code += "}" + "\n\n"
   219  		fmt.Println(code)
   220  		_, err = outFile.WriteString(code)
   221  		if err != nil {
   222  			return
   223  		}
   224  		fmt.Println("-------- database [" + one.Name + "] end --------")
   225  	}
   226  	return
   227  }
   228  
   229  func formatDataType(dataType map[string]string) (info *dialect.ColumnTypeInfo) {
   230  	info = &dialect.ColumnTypeInfo{}
   231  	name := dataType["名称"]
   232  	format := name
   233  	if strings.Contains(name, "(") {
   234  		nameStart := name[0:strings.Index(name, "(")]
   235  		nameEnd := name[strings.Index(name, ")")+1:]
   236  		inStr := name[strings.Index(name, "("):strings.Index(name, ")")]
   237  		inStr = strings.ReplaceAll(inStr, "(", "")
   238  		inStr = strings.ReplaceAll(inStr, ")", "")
   239  
   240  		ss := strings.Split(inStr, ",")
   241  		format = nameStart + "("
   242  		for _, s := range ss {
   243  			s = strings.TrimSpace(s)
   244  			if strings.EqualFold(s, "p") ||
   245  				strings.EqualFold(s, "precision") ||
   246  				strings.Contains(s, "精度") {
   247  				format += "$p, "
   248  			} else if strings.EqualFold(s, "s") ||
   249  				strings.EqualFold(s, "scale") ||
   250  				strings.Contains(s, "标度") ||
   251  				strings.Contains(s, "刻度") {
   252  				format += "$s, "
   253  			} else {
   254  				format += "$l, "
   255  			}
   256  		}
   257  		format = strings.TrimSuffix(format, ", ")
   258  
   259  		format += ")"
   260  		if strings.Contains(nameEnd, "(") {
   261  			endStart := nameEnd[0:strings.Index(nameEnd, "(")]
   262  			endEnd := nameEnd[strings.Index(nameEnd, ")")+1:]
   263  			inStr = nameEnd[strings.Index(nameEnd, "("):strings.Index(nameEnd, ")")]
   264  			inStr = strings.ReplaceAll(inStr, "(", "")
   265  			inStr = strings.ReplaceAll(inStr, ")", "")
   266  
   267  			ss = strings.Split(inStr, ",")
   268  			format += endStart + "("
   269  			for _, s := range ss {
   270  				s = strings.TrimSpace(s)
   271  				if strings.EqualFold(s, "p") ||
   272  					strings.EqualFold(s, "precision") ||
   273  					strings.Contains(s, "精度") {
   274  					if strings.Contains(s, "小数秒精度") {
   275  						format += "$s, "
   276  					} else {
   277  						format += "$p, "
   278  					}
   279  				} else if strings.EqualFold(s, "s") ||
   280  					strings.EqualFold(s, "scale") ||
   281  					strings.Contains(s, "标度") ||
   282  					strings.Contains(s, "刻度") {
   283  					format += "$s, "
   284  				} else {
   285  					format += "$l, "
   286  				}
   287  			}
   288  			format = strings.TrimSuffix(format, ", ")
   289  			format += ")" + endEnd
   290  			name = nameStart + endStart + endEnd
   291  		} else {
   292  			format += nameEnd
   293  			name = nameStart + nameEnd
   294  		}
   295  
   296  	}
   297  	var typeText = dataType["类型"]
   298  	if strings.Contains(typeText, "整型") {
   299  		info.IsInteger = true
   300  		info.IsNumber = true
   301  	} else if strings.Contains(typeText, "浮点") {
   302  		info.IsFloat = true
   303  		info.IsNumber = true
   304  	} else if strings.Contains(typeText, "定点") {
   305  		info.IsFloat = true
   306  		info.IsNumber = true
   307  	} else if strings.Contains(typeText, "数值") {
   308  		info.IsNumber = true
   309  	} else if strings.Contains(typeText, "字符") {
   310  		info.IsString = true
   311  	} else if strings.Contains(typeText, "二进制") {
   312  		info.IsBytes = true
   313  	} else if strings.Contains(typeText, "布尔") {
   314  		info.IsBoolean = true
   315  	} else if strings.Contains(typeText, "日期") {
   316  		info.IsDateTime = true
   317  	} else if strings.Contains(typeText, "枚举") {
   318  		info.IsEnum = true
   319  	}
   320  	info.Name = name
   321  	info.Format = format
   322  	info.Comment = dataType["说明"]
   323  	matchStr := dataType["匹配"]
   324  	matches := strings.Split(matchStr, "\n")
   325  	for _, match := range matches {
   326  		match = strings.TrimSpace(match)
   327  		if match == "" {
   328  			continue
   329  		}
   330  		if strings.EqualFold(match, "if not found") {
   331  			info.IfNotFound = true
   332  			continue
   333  		}
   334  		info.Matches = append(info.Matches, match)
   335  	}
   336  	return
   337  }