github.com/team-ide/go-dialect@v1.9.20/dialect/dialect_mapping.go (about)

     1  package dialect
     2  
import (
	"errors"
	"fmt"
	"reflect"
	"strings"
)
     8  
     9  func NewMappingDialect(mapping *SqlMapping) (dia Dialect, err error) {
    10  	mappingDia := &mappingDialect{
    11  		SqlMapping: mapping,
    12  	}
    13  
    14  	err = mappingDia.init()
    15  	if err != nil {
    16  		return
    17  	}
    18  	dia = mappingDia
    19  	return
    20  }
    21  
// mappingDialect is the Dialect implementation returned by NewMappingDialect.
//
// Each *RootStatement field below is filled by init from the SqlMapping field
// with the same name. A field stays nil when that mapping template is empty
// after trimming whitespace.
//
// init visits these fields by reflection in declaration order. A
// *RootStatement field with no same-named SqlMapping field makes init fail,
// and the first such field in this order is the one reported.
type mappingDialect struct {
	*SqlMapping

	// Owner statements.
	OwnersSelect *RootStatement
	OwnerSelect  *RootStatement
	OwnerCreate  *RootStatement
	OwnerDelete  *RootStatement

	// Table statements.
	TablesSelect          *RootStatement
	TableSelect           *RootStatement
	TableCreate           *RootStatement
	TableCreateColumn     *RootStatement
	TableCreatePrimaryKey *RootStatement
	TableDelete           *RootStatement
	TableComment          *RootStatement
	TableRename           *RootStatement

	// Column statements.
	ColumnsSelect *RootStatement
	ColumnSelect  *RootStatement
	ColumnAdd     *RootStatement
	ColumnDelete  *RootStatement
	ColumnComment *RootStatement
	ColumnRename  *RootStatement
	ColumnUpdate  *RootStatement
	ColumnAfter   *RootStatement

	// Primary-key statements.
	PrimaryKeysSelect *RootStatement
	PrimaryKeyAdd     *RootStatement
	PrimaryKeyDelete  *RootStatement

	// Index statements.
	IndexesSelect *RootStatement
	IndexAdd      *RootStatement
	IndexDelete   *RootStatement
}
    56  
    57  func (this_ *mappingDialect) init() (err error) {
    58  	this_.SqlMapping.dialect = this_
    59  	rootStatementType := reflect.TypeOf(&RootStatement{})
    60  
    61  	mappingValue := reflect.ValueOf(this_.SqlMapping).Elem()
    62  	sqlStatementValue := reflect.ValueOf(this_).Elem()
    63  	sqlStatementType := reflect.TypeOf(this_).Elem()
    64  	var statement *RootStatement
    65  	for i := 0; i < sqlStatementValue.NumField(); i++ {
    66  		fieldValue := sqlStatementValue.Field(i)
    67  		fieldType := sqlStatementType.Field(i)
    68  		if fieldType.Type != rootStatementType {
    69  			continue
    70  		}
    71  		mappingField := mappingValue.FieldByName(fieldType.Name)
    72  		if mappingField.Kind() == reflect.Invalid {
    73  			err = errors.New("mapping field [" + fieldType.Name + "] is invalid")
    74  			return
    75  		}
    76  		sqlTemplate := strings.TrimSpace(mappingField.String())
    77  		if len(sqlTemplate) == 0 {
    78  			continue
    79  		}
    80  		statement, err = statementParse(sqlTemplate)
    81  		if err != nil {
    82  			return
    83  		}
    84  		fieldValue.Set(reflect.ValueOf(statement))
    85  	}
    86  	return
    87  }
    88  
// StatementScript carries the ParamModel and the Dialect behind the helper
// functions that NewStatementContext registers for SQL templates:
// sqlValuePack, columnNotNull, joins, equalFold and doubleQuotationMarksPack.
type StatementScript struct {
	*ParamModel
	Dialect
}
    93  
    94  func (this_ StatementScript) sqlValuePack(value interface{}) (res string) {
    95  
    96  	res = this_.SqlValuePack(this_.ParamModel, nil, value)
    97  	return
    98  }
    99  
   100  func (this_ StatementScript) doubleQuotationMarksPack(value interface{}) (res string) {
   101  	res = packingValue(nil, nil, "\"", "\\", value)
   102  	return
   103  }
   104  
   105  func (this_ StatementScript) columnNotNull(columnNotNull interface{}) (res string) {
   106  	if isTrue(columnNotNull) {
   107  		res = "NOT NULL"
   108  	}
   109  	return
   110  }
   111  
   112  func (this_ StatementScript) equalFold(arg1 interface{}, arg2 interface{}) bool {
   113  	if arg1 == arg2 {
   114  		return true
   115  	}
   116  	str1 := GetStringValue(arg1)
   117  	str2 := GetStringValue(arg2)
   118  	return strings.EqualFold(str1, str2)
   119  }
   120  
   121  func (this_ StatementScript) joins(joinList interface{}, joinObj interface{}) (res string) {
   122  	if joinList == nil {
   123  		return
   124  	}
   125  	list, ok := joinList.([]string)
   126  	if !ok {
   127  		objList := joinList.([]interface{})
   128  		for _, one := range objList {
   129  			list = append(list, GetStringValue(one))
   130  		}
   131  	}
   132  	res = strings.Join(list, GetStringValue(joinObj))
   133  	return
   134  }
   135  
   136  func (this_ *mappingDialect) NewStatementContext(param *ParamModel, dataList ...interface{}) (statementContext *StatementContext, err error) {
   137  	statementContext = NewStatementContext()
   138  
   139  	statementScript := &StatementScript{
   140  		ParamModel: param,
   141  		Dialect:    this_,
   142  	}
   143  	statementContext.AddMethod("sqlValuePack", statementScript.sqlValuePack)
   144  	statementContext.AddMethod("columnNotNull", statementScript.columnNotNull)
   145  	statementContext.AddMethod("joins", statementScript.joins)
   146  	statementContext.AddMethod("equalFold", statementScript.equalFold)
   147  	statementContext.AddMethod("doubleQuotationMarksPack", statementScript.doubleQuotationMarksPack)
   148  
   149  	if this_.MethodCache != nil {
   150  		for name, method := range this_.MethodCache {
   151  			statementContext.AddMethod(name, method)
   152  		}
   153  	}
   154  	if param != nil {
   155  		err = statementContext.SetJSONData(param)
   156  		if err != nil {
   157  			return
   158  		}
   159  		err = statementContext.SetJSONData(param.CustomData)
   160  		if err != nil {
   161  			return
   162  		}
   163  	}
   164  	for _, data := range dataList {
   165  		err = statementContext.SetJSONData(data)
   166  		if err != nil {
   167  			return
   168  		}
   169  	}
   170  
   171  	ownerNamePack := ""
   172  	ownerName, _ := statementContext.GetData("ownerName")
   173  	if ownerName != nil && ownerName != "" {
   174  		ownerNamePack = this_.OwnerNamePack(param, ownerName.(string))
   175  	}
   176  	statementContext.SetData("ownerNamePack", ownerNamePack)
   177  
   178  	tableNamePack := ""
   179  	tableName, _ := statementContext.GetData("tableName")
   180  	if tableName != nil && tableName != "" {
   181  		tableNamePack = this_.TableNamePack(param, tableName.(string))
   182  	}
   183  	statementContext.SetData("tableNamePack", tableNamePack)
   184  
   185  	oldTableNamePack := ""
   186  	oldTableName, _ := statementContext.GetData("oldTableName")
   187  	if oldTableName != nil && oldTableName != "" {
   188  		oldTableNamePack = this_.TableNamePack(param, oldTableName.(string))
   189  	}
   190  	statementContext.SetData("oldTableNamePack", oldTableNamePack)
   191  
   192  	columnNamePack := ""
   193  	columnName, _ := statementContext.GetData("columnName")
   194  	if columnName != nil && columnName != "" {
   195  		columnNamePack = this_.ColumnNamePack(param, columnName.(string))
   196  	}
   197  	statementContext.SetData("columnNamePack", columnNamePack)
   198  
   199  	oldColumnNamePack := ""
   200  	oldColumnName, _ := statementContext.GetData("oldColumnName")
   201  	if oldColumnName != nil && oldColumnName != "" {
   202  		oldColumnNamePack = this_.ColumnNamePack(param, oldColumnName.(string))
   203  	}
   204  	statementContext.SetData("oldColumnNamePack", oldColumnNamePack)
   205  
   206  	columnAfterColumnPack := ""
   207  	columnAfterColumn, _ := statementContext.GetData("columnAfterColumn")
   208  	if columnAfterColumn != nil && columnAfterColumn != "" {
   209  		columnAfterColumnPack = this_.ColumnNamePack(param, columnAfterColumn.(string))
   210  	}
   211  	statementContext.SetData("columnAfterColumnPack", columnAfterColumnPack)
   212  
   213  	columnNamesPack := ""
   214  	columnNames, _ := statementContext.GetData("columnNames")
   215  	if columnNames != nil {
   216  		list := columnNames.([]interface{})
   217  		var stringList []string
   218  		for _, one := range list {
   219  			stringList = append(stringList, one.(string))
   220  		}
   221  		columnNamesPack = this_.ColumnNamesPack(param, stringList)
   222  	}
   223  	statementContext.SetData("columnNamesPack", columnNamesPack)
   224  
   225  	primaryKeysPack := ""
   226  	primaryKeys, _ := statementContext.GetData("primaryKeys")
   227  	if primaryKeys != nil {
   228  		list := primaryKeys.([]interface{})
   229  		var stringList []string
   230  		for _, one := range list {
   231  			stringList = append(stringList, one.(string))
   232  		}
   233  		primaryKeysPack = this_.ColumnNamesPack(param, stringList)
   234  	}
   235  	statementContext.SetData("primaryKeysPack", primaryKeysPack)
   236  
   237  	indexNamePack := ""
   238  	indexName, _ := statementContext.GetData("indexName")
   239  	if indexName != nil && indexName != "" {
   240  		indexNamePack = this_.ColumnNamePack(param, indexName.(string))
   241  	}
   242  	statementContext.SetData("indexNamePack", indexNamePack)
   243  
   244  	return
   245  }
   246  
   247  func (this_ *mappingDialect) FormatSql(statement *RootStatement, param *ParamModel, dataList ...interface{}) (sqlList []string, err error) {
   248  	if statement == nil {
   249  		return
   250  	}
   251  	statementContext, err := this_.NewStatementContext(param, dataList...)
   252  	if err != nil {
   253  		return
   254  	}
   255  	sqlInfo, err := statement.Format(statementContext)
   256  	if err != nil {
   257  		return
   258  	}
   259  	sqlList = this_.SqlSplit(sqlInfo)
   260  
   261  	//fmt.Println("FormatSql sql data cache")
   262  	//fmt.Println(statementContext.dataCache)
   263  	//fmt.Println("FormatSql sql list")
   264  	//for _, sqlOne := range sqlList {
   265  	//	fmt.Println("sql:", sqlOne)
   266  	//}
   267  	return
   268  }