github.com/kotovmak/go-admin@v1.1.1/modules/db/mssql.go (about)

     1  // Copyright 2019 GoAdmin Core Team. All rights reserved.
     2  // Use of this source code is governed by a Apache-2.0 style
     3  // license that can be found in the LICENSE file.
     4  
     5  package db
     6  
     7  import (
     8  	"database/sql"
     9  	"fmt"
    10  	"regexp"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/kotovmak/go-admin/modules/config"
    15  )
    16  
    17  // Mssql is a Connection of mssql.
    18  type Mssql struct {
    19  	Base
    20  }
    21  
    22  // GetMssqlDB return the global mssql connection.
    23  func GetMssqlDB() *Mssql {
    24  	return &Mssql{
    25  		Base: Base{
    26  			DbList: make(map[string]*sql.DB),
    27  		},
    28  	}
    29  }
    30  
    31  // GetDelimiter implements the method Connection.GetDelimiter.
    32  func (db *Mssql) GetDelimiter() string {
    33  	return "["
    34  }
    35  
    36  // GetDelimiter2 implements the method Connection.GetDelimiter2.
    37  func (db *Mssql) GetDelimiter2() string {
    38  	return "]"
    39  }
    40  
    41  // GetDelimiters implements the method Connection.GetDelimiters.
    42  func (db *Mssql) GetDelimiters() []string {
    43  	return []string{"[", "]"}
    44  }
    45  
    46  // Name implements the method Connection.Name.
    47  func (db *Mssql) Name() string {
    48  	return "mssql"
    49  }
    50  
// TODO: clean up and optimize.
    52  
    53  func replaceStringFunc(pattern, src string, rpl func(s string) string) (string, error) {
    54  
    55  	r, err := regexp.Compile(pattern)
    56  	if err != nil {
    57  		return "", err
    58  	}
    59  
    60  	bytes := r.ReplaceAllFunc([]byte(src), func(bytes []byte) []byte {
    61  		return []byte(rpl(string(bytes)))
    62  	})
    63  
    64  	return string(bytes), nil
    65  }
    66  
    67  func replace(pattern string, replace, src []byte) ([]byte, error) {
    68  
    69  	r, err := regexp.Compile(pattern)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	return r.ReplaceAll(src, replace), nil
    75  }
    76  
    77  func replaceString(pattern, rep, src string) (string, error) {
    78  	r, e := replace(pattern, []byte(rep), []byte(src))
    79  	return string(r), e
    80  }
    81  
    82  func matchAllString(pattern string, src string) ([][]string, error) {
    83  	r, err := regexp.Compile(pattern)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  	return r.FindAllStringSubmatch(src, -1), nil
    88  }
    89  
    90  func isMatch(pattern string, src []byte) bool {
    91  	r, err := regexp.Compile(pattern)
    92  	if err != nil {
    93  		return false
    94  	}
    95  	return r.Match(src)
    96  }
    97  
    98  func isMatchString(pattern string, src string) bool {
    99  	return isMatch(pattern, []byte(src))
   100  }
   101  
   102  func matchString(pattern string, src string) ([]string, error) {
   103  	r, err := regexp.Compile(pattern)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	return r.FindStringSubmatch(src), nil
   108  }
   109  
   110  // 从Gf框架复制
   111  // 在执行sql之前对sql进行进一步处理
   112  func (db *Mssql) handleSqlBeforeExec(query string) string {
   113  	index := 0
   114  	str, _ := replaceStringFunc("\\?", query, func(s string) string {
   115  		index++
   116  		return fmt.Sprintf("@p%d", index)
   117  	})
   118  
   119  	str, _ = replaceString("\"", "", str)
   120  
   121  	return db.parseSql(str)
   122  }
   123  
   124  // 将MYSQL的SQL语法转换为MSSQL的语法
   125  // 1.由于mssql不支持limit写法所以需要对mysql中的limit用法做转换
   126  func (db *Mssql) parseSql(sql string) string {
   127  	//下面的正则表达式匹配出SELECT和INSERT的关键字后分别做不同的处理,如有LIMIT则将LIMIT的关键字也匹配出
   128  	patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))`
   129  	if !isMatchString(patten, sql) {
   130  		//fmt.Println("not matched..")
   131  		return sql
   132  	}
   133  
   134  	res, err := matchAllString(patten, sql)
   135  	if err != nil {
   136  		//fmt.Println("MatchString error.", err)
   137  		return ""
   138  	}
   139  
   140  	index := 0
   141  	keyword := strings.TrimSpace(res[index][0])
   142  	keyword = strings.ToUpper(keyword)
   143  
   144  	index++
   145  	switch keyword {
   146  	case "SELECT":
   147  		//不含LIMIT关键字则不处理
   148  		if len(res) < 2 || (!strings.HasPrefix(res[index][0], "LIMIT") && !strings.HasPrefix(res[index][0], "limit")) {
   149  			break
   150  		}
   151  
   152  		//不含LIMIT则不处理
   153  		if !isMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) {
   154  			break
   155  		}
   156  
   157  		//判断SQL中是否含有order by
   158  		selectStr := ""
   159  		orderbyStr := ""
   160  		haveOrderby := isMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql)
   161  		if haveOrderby {
   162  			//取order by 前面的字符串
   163  			queryExpr, _ := matchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql)
   164  
   165  			if len(queryExpr) != 4 || !strings.EqualFold(queryExpr[1], "SELECT") || !strings.EqualFold(queryExpr[3], "ORDER BY") {
   166  				break
   167  			}
   168  			selectStr = queryExpr[2]
   169  
   170  			//取order by表达式的值
   171  			orderbyExpr, _ := matchString("((?i)ORDER BY)(.+)((?i)LIMIT)", sql)
   172  			if len(orderbyExpr) != 4 || !strings.EqualFold(orderbyExpr[1], "ORDER BY") || !strings.EqualFold(orderbyExpr[3], "LIMIT") {
   173  				break
   174  			}
   175  			orderbyStr = orderbyExpr[2]
   176  		} else {
   177  			queryExpr, _ := matchString("((?i)SELECT)(.+)((?i)LIMIT)", sql)
   178  			if len(queryExpr) != 4 || !strings.EqualFold(queryExpr[1], "SELECT") || !strings.EqualFold(queryExpr[3], "LIMIT") {
   179  				break
   180  			}
   181  			selectStr = queryExpr[2]
   182  		}
   183  
   184  		//取limit后面的取值范围
   185  		first, limit := 0, 0
   186  		for i := 1; i < len(res[index]); i++ {
   187  			if strings.TrimSpace(res[index][i]) == "" {
   188  				continue
   189  			}
   190  
   191  			if strings.HasPrefix(res[index][i], "LIMIT") || strings.HasPrefix(res[index][i], "limit") {
   192  				first, _ = strconv.Atoi(res[index][i+1])
   193  				limit, _ = strconv.Atoi(res[index][i+2])
   194  				break
   195  			}
   196  		}
   197  
   198  		if haveOrderby {
   199  			sql = fmt.Sprintf("SELECT * FROM (SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s   ) as TMP_ WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d", orderbyStr, selectStr, first, limit)
   200  		} else {
   201  			if first == 0 {
   202  				first = limit
   203  			} else {
   204  				first = limit - first
   205  			}
   206  			sql = fmt.Sprintf("SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ", first, limit, selectStr)
   207  		}
   208  	default:
   209  	}
   210  	return sql
   211  }
   212  
   213  // QueryWithConnection implements the method Connection.QueryWithConnection.
   214  func (db *Mssql) QueryWithConnection(con string, query string, args ...interface{}) ([]map[string]interface{}, error) {
   215  	query = db.handleSqlBeforeExec(query)
   216  	return CommonQuery(db.DbList[con], query, args...)
   217  }
   218  
   219  // ExecWithConnection implements the method Connection.ExecWithConnection.
   220  func (db *Mssql) ExecWithConnection(con string, query string, args ...interface{}) (sql.Result, error) {
   221  	query = db.handleSqlBeforeExec(query)
   222  	return CommonExec(db.DbList[con], query, args...)
   223  }
   224  
   225  // Query implements the method Connection.Query.
   226  func (db *Mssql) Query(query string, args ...interface{}) ([]map[string]interface{}, error) {
   227  	query = db.handleSqlBeforeExec(query)
   228  	return CommonQuery(db.DbList["default"], query, args...)
   229  }
   230  
   231  // Exec implements the method Connection.Exec.
   232  func (db *Mssql) Exec(query string, args ...interface{}) (sql.Result, error) {
   233  	query = db.handleSqlBeforeExec(query)
   234  	return CommonExec(db.DbList["default"], query, args...)
   235  }
   236  
   237  func (db *Mssql) QueryWith(tx *sql.Tx, conn, query string, args ...interface{}) ([]map[string]interface{}, error) {
   238  	if tx != nil {
   239  		return db.QueryWithTx(tx, query, args...)
   240  	}
   241  	return db.QueryWithConnection(conn, query, args...)
   242  }
   243  
   244  func (db *Mssql) ExecWith(tx *sql.Tx, conn, query string, args ...interface{}) (sql.Result, error) {
   245  	if tx != nil {
   246  		return db.ExecWithTx(tx, query, args...)
   247  	}
   248  	return db.ExecWithConnection(conn, query, args...)
   249  }
   250  
   251  // InitDB implements the method Connection.InitDB.
   252  func (db *Mssql) InitDB(cfgs map[string]config.Database) Connection {
   253  	db.Configs = cfgs
   254  	db.Once.Do(func() {
   255  		for conn, cfg := range cfgs {
   256  
   257  			sqlDB, err := sql.Open("sqlserver", cfg.GetDSN())
   258  
   259  			if sqlDB == nil {
   260  				panic("invalid connection")
   261  			}
   262  
   263  			if err != nil {
   264  				_ = sqlDB.Close()
   265  				panic(err.Error())
   266  			}
   267  
   268  			sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
   269  			sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
   270  			sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
   271  			sqlDB.SetConnMaxIdleTime(cfg.ConnMaxIdleTime)
   272  
   273  			db.DbList[conn] = sqlDB
   274  
   275  			if err := sqlDB.Ping(); err != nil {
   276  				panic(err)
   277  			}
   278  		}
   279  	})
   280  	return db
   281  }
   282  
   283  // BeginTxWithReadUncommitted starts a transaction with level LevelReadUncommitted.
   284  func (db *Mssql) BeginTxWithReadUncommitted() *sql.Tx {
   285  	return CommonBeginTxWithLevel(db.DbList["default"], sql.LevelReadUncommitted)
   286  }
   287  
   288  // BeginTxWithReadCommitted starts a transaction with level LevelReadCommitted.
   289  func (db *Mssql) BeginTxWithReadCommitted() *sql.Tx {
   290  	return CommonBeginTxWithLevel(db.DbList["default"], sql.LevelReadCommitted)
   291  }
   292  
   293  // BeginTxWithRepeatableRead starts a transaction with level LevelRepeatableRead.
   294  func (db *Mssql) BeginTxWithRepeatableRead() *sql.Tx {
   295  	return CommonBeginTxWithLevel(db.DbList["default"], sql.LevelRepeatableRead)
   296  }
   297  
   298  // BeginTx starts a transaction with level LevelDefault.
   299  func (db *Mssql) BeginTx() *sql.Tx {
   300  	return CommonBeginTxWithLevel(db.DbList["default"], sql.LevelDefault)
   301  }
   302  
   303  // BeginTxWithLevel starts a transaction with given transaction isolation level.
   304  func (db *Mssql) BeginTxWithLevel(level sql.IsolationLevel) *sql.Tx {
   305  	return CommonBeginTxWithLevel(db.DbList["default"], level)
   306  }
   307  
   308  // BeginTxWithReadUncommittedAndConnection starts a transaction with level LevelReadUncommitted and connection.
   309  func (db *Mssql) BeginTxWithReadUncommittedAndConnection(conn string) *sql.Tx {
   310  	return CommonBeginTxWithLevel(db.DbList[conn], sql.LevelReadUncommitted)
   311  }
   312  
   313  // BeginTxWithReadCommittedAndConnection starts a transaction with level LevelReadCommitted and connection.
   314  func (db *Mssql) BeginTxWithReadCommittedAndConnection(conn string) *sql.Tx {
   315  	return CommonBeginTxWithLevel(db.DbList[conn], sql.LevelReadCommitted)
   316  }
   317  
   318  // BeginTxWithRepeatableReadAndConnection starts a transaction with level LevelRepeatableRead and connection.
   319  func (db *Mssql) BeginTxWithRepeatableReadAndConnection(conn string) *sql.Tx {
   320  	return CommonBeginTxWithLevel(db.DbList[conn], sql.LevelRepeatableRead)
   321  }
   322  
   323  // BeginTxAndConnection starts a transaction with level LevelDefault and connection.
   324  func (db *Mssql) BeginTxAndConnection(conn string) *sql.Tx {
   325  	return CommonBeginTxWithLevel(db.DbList[conn], sql.LevelDefault)
   326  }
   327  
   328  // BeginTxWithLevelAndConnection starts a transaction with given transaction isolation level and connection.
   329  func (db *Mssql) BeginTxWithLevelAndConnection(conn string, level sql.IsolationLevel) *sql.Tx {
   330  	return CommonBeginTxWithLevel(db.DbList[conn], level)
   331  }
   332  
   333  // QueryWithTx is query method within the transaction.
   334  func (db *Mssql) QueryWithTx(tx *sql.Tx, query string, args ...interface{}) ([]map[string]interface{}, error) {
   335  	query = db.handleSqlBeforeExec(query)
   336  	return CommonQueryWithTx(tx, query, args...)
   337  }
   338  
   339  // ExecWithTx is exec method within the transaction.
   340  func (db *Mssql) ExecWithTx(tx *sql.Tx, query string, args ...interface{}) (sql.Result, error) {
   341  	query = db.handleSqlBeforeExec(query)
   342  	return CommonExecWithTx(tx, query, args...)
   343  }