github.com/kotovmak/go-admin@v1.1.1/modules/db/dialect/dialect.go (about)

     1  // Copyright 2019 GoAdmin Core Team. All rights reserved.
     2  // Use of this source code is governed by a Apache-2.0 style
     3  // license that can be found in the LICENSE file.
     4  
     5  package dialect
     6  
import (
	"sort"
	"strings"

	"github.com/kotovmak/go-admin/modules/config"
)
    12  
    13  // Dialect is methods set of different driver.
    14  type Dialect interface {
    15  	// GetName get dialect's name
    16  	GetName() string
    17  
    18  	// ShowColumns show columns of specified table
    19  	ShowColumns(table string) string
    20  
    21  	// ShowTables show tables of database
    22  	ShowTables() string
    23  
    24  	// Insert
    25  	Insert(comp *SQLComponent) string
    26  
    27  	// Delete
    28  	Delete(comp *SQLComponent) string
    29  
    30  	// Update
    31  	Update(comp *SQLComponent) string
    32  
    33  	// Select
    34  	Select(comp *SQLComponent) string
    35  
    36  	// GetDelimiter return the delimiter of Dialect.
    37  	GetDelimiter() string
    38  }
    39  
    40  // GetDialect return the default Dialect.
    41  func GetDialect() Dialect {
    42  	return GetDialectByDriver(config.GetDatabases().GetDefault().Driver)
    43  }
    44  
    45  // GetDialectByDriver return the Dialect of given driver.
    46  func GetDialectByDriver(driver string) Dialect {
    47  	switch driver {
    48  	case "mysql":
    49  		return mysql{
    50  			commonDialect: commonDialect{delimiter: "`", delimiter2: "`"},
    51  		}
    52  	case "mssql":
    53  		return mssql{
    54  			commonDialect: commonDialect{delimiter: "[", delimiter2: "]"},
    55  		}
    56  	case "postgresql":
    57  		return postgresql{
    58  			commonDialect: commonDialect{delimiter: `"`, delimiter2: `"`},
    59  		}
    60  	case "sqlite":
    61  		return sqlite{
    62  			commonDialect: commonDialect{delimiter: "`", delimiter2: "`"},
    63  		}
    64  	case "oceanbase":
    65  		return oceanbase{
    66  			commonDialect: commonDialect{delimiter: "`", delimiter2: "`"},
    67  		}
    68  	default:
    69  		return commonDialect{delimiter: "`", delimiter2: "`"}
    70  	}
    71  }
    72  
// H is a shorthand for map[string]interface{}. In this file it is the type of
// SQLComponent.Values, where each key is a column name and each value is the
// data bound to that column in insert/update statements.
type H map[string]interface{}
    75  
    76  // SQLComponent is a sql components set.
    77  type SQLComponent struct {
    78  	Fields     []string
    79  	Functions  []string
    80  	TableName  string
    81  	Wheres     []Where
    82  	Leftjoins  []Join
    83  	Args       []interface{}
    84  	Order      string
    85  	Offset     string
    86  	Limit      string
    87  	WhereRaws  string
    88  	UpdateRaws []RawUpdate
    89  	Group      string
    90  	Statement  string
    91  	Values     H
    92  }
    93  
    94  // Where contains the operation and field.
    95  type Where struct {
    96  	Operation string
    97  	Field     string
    98  	Qmark     string
    99  }
   100  
   101  // Join contains the table and field and operation.
   102  type Join struct {
   103  	Table     string
   104  	FieldA    string
   105  	Operation string
   106  	FieldB    string
   107  }
   108  
// RawUpdate is a raw assignment for the set clause of an update statement.
type RawUpdate struct {
	// Expression is spliced into the set clause verbatim (no quoting or
	// escaping).
	Expression string
	// Args are bound to the placeholders inside Expression; prepareUpdate
	// places them after the arguments produced from Values.
	Args       []interface{}
}
   114  
   115  // *******************************
   116  // internal helper functions
   117  // *******************************
   118  
   119  func (sql *SQLComponent) getLimit() string {
   120  	if sql.Limit == "" {
   121  		return ""
   122  	}
   123  	return " limit " + sql.Limit + " "
   124  }
   125  
   126  func (sql *SQLComponent) getOffset() string {
   127  	if sql.Offset == "" {
   128  		return ""
   129  	}
   130  	return " offset " + sql.Offset + " "
   131  }
   132  
   133  func (sql *SQLComponent) getOrderBy() string {
   134  	if sql.Order == "" {
   135  		return ""
   136  	}
   137  	return " order by " + sql.Order + " "
   138  }
   139  
   140  func (sql *SQLComponent) getGroupBy() string {
   141  	if sql.Group == "" {
   142  		return ""
   143  	}
   144  	return " group by " + sql.Group + " "
   145  }
   146  
   147  func (sql *SQLComponent) getJoins(delimiter, delimiter2 string) string {
   148  	if len(sql.Leftjoins) == 0 {
   149  		return ""
   150  	}
   151  	joins := ""
   152  	for _, join := range sql.Leftjoins {
   153  		joins += " left join " + wrap(delimiter, delimiter2, join.Table) + " on " +
   154  			sql.processLeftJoinField(join.FieldA, delimiter, delimiter2) + " " + join.Operation + " " +
   155  			sql.processLeftJoinField(join.FieldB, delimiter, delimiter2) + " "
   156  	}
   157  	return joins
   158  }
   159  
   160  func (sql *SQLComponent) processLeftJoinField(field, delimiter, delimiter2 string) string {
   161  	arr := strings.Split(field, ".")
   162  	if len(arr) > 0 {
   163  		return delimiter + arr[0] + delimiter2 + "." + delimiter + arr[1] + delimiter2
   164  	}
   165  	return field
   166  }
   167  
   168  func (sql *SQLComponent) getFields(delimiter, delimiter2 string) string {
   169  	if len(sql.Fields) == 0 {
   170  		return "*"
   171  	}
   172  	fields := ""
   173  	if len(sql.Leftjoins) == 0 {
   174  		for k, field := range sql.Fields {
   175  			if sql.Functions[k] != "" {
   176  				fields += sql.Functions[k] + "(" + wrap(delimiter, delimiter2, field) + "),"
   177  			} else {
   178  				fields += wrap(delimiter, delimiter2, field) + ","
   179  			}
   180  		}
   181  	} else {
   182  		for _, field := range sql.Fields {
   183  			arr := strings.Split(field, ".")
   184  			if len(arr) > 1 {
   185  				fields += wrap(delimiter, delimiter2, arr[0]) + "." + wrap(delimiter, delimiter2, arr[1]) + ","
   186  			} else {
   187  				fields += wrap(delimiter, delimiter2, field) + ","
   188  			}
   189  		}
   190  	}
   191  	return fields[:len(fields)-1]
   192  }
   193  
   194  func wrap(delimiter, delimiter2, field string) string {
   195  	if field == "*" {
   196  		return "*"
   197  	}
   198  	return delimiter + field + delimiter2
   199  }
   200  
   201  func (sql *SQLComponent) getWheres(delimiter, delimiter2 string) string {
   202  	if len(sql.Wheres) == 0 {
   203  		if sql.WhereRaws != "" {
   204  			return " where " + sql.WhereRaws
   205  		}
   206  		return ""
   207  	}
   208  	wheres := " where "
   209  	var arr []string
   210  	for _, where := range sql.Wheres {
   211  		arr = strings.Split(where.Field, ".")
   212  		if len(arr) > 1 {
   213  			wheres += arr[0] + "." + wrap(delimiter, delimiter2, arr[1]) + " " + where.Operation + " " + where.Qmark + " and "
   214  		} else {
   215  			wheres += wrap(delimiter, delimiter2, where.Field) + " " + where.Operation + " " + where.Qmark + " and "
   216  		}
   217  	}
   218  
   219  	if sql.WhereRaws != "" {
   220  		return wheres + sql.WhereRaws
   221  	}
   222  	return wheres[:len(wheres)-5]
   223  }
   224  
   225  func (sql *SQLComponent) prepareUpdate(delimiter, delimiter2 string) {
   226  	fields := ""
   227  	args := make([]interface{}, 0)
   228  
   229  	if len(sql.Values) != 0 {
   230  
   231  		for key, value := range sql.Values {
   232  			fields += wrap(delimiter, delimiter2, key) + " = ?, "
   233  			args = append(args, value)
   234  		}
   235  
   236  		if len(sql.UpdateRaws) == 0 {
   237  			fields = fields[:len(fields)-2]
   238  		} else {
   239  			for i := 0; i < len(sql.UpdateRaws); i++ {
   240  				if i == len(sql.UpdateRaws)-1 {
   241  					fields += sql.UpdateRaws[i].Expression + " "
   242  				} else {
   243  					fields += sql.UpdateRaws[i].Expression + ","
   244  				}
   245  				args = append(args, sql.UpdateRaws[i].Args...)
   246  			}
   247  		}
   248  
   249  		sql.Args = append(args, sql.Args...)
   250  	} else {
   251  		if len(sql.UpdateRaws) == 0 {
   252  			panic("prepareUpdate: wrong parameter")
   253  		} else {
   254  			for i := 0; i < len(sql.UpdateRaws); i++ {
   255  				if i == len(sql.UpdateRaws)-1 {
   256  					fields += sql.UpdateRaws[i].Expression + " "
   257  				} else {
   258  					fields += sql.UpdateRaws[i].Expression + ","
   259  				}
   260  				args = append(args, sql.UpdateRaws[i].Args...)
   261  			}
   262  		}
   263  		sql.Args = append(args, sql.Args...)
   264  	}
   265  
   266  	sql.Statement = "update " + delimiter + sql.TableName + delimiter2 + " set " + fields + sql.getWheres(delimiter, delimiter2)
   267  }
   268  
   269  func (sql *SQLComponent) prepareInsert(delimiter, delimiter2 string) {
   270  	fields := " ("
   271  	quesMark := "("
   272  
   273  	for key, value := range sql.Values {
   274  		fields += wrap(delimiter, delimiter2, key) + ","
   275  		quesMark += "?,"
   276  		sql.Args = append(sql.Args, value)
   277  	}
   278  	fields = fields[:len(fields)-1] + ")"
   279  	quesMark = quesMark[:len(quesMark)-1] + ")"
   280  
   281  	sql.Statement = "insert into " + delimiter + sql.TableName + delimiter2 + fields + " values " + quesMark
   282  }