github.com/kotovmak/go-admin@v1.1.1/modules/db/mssql.go (about) 1 // Copyright 2019 GoAdmin Core Team. All rights reserved. 2 // Use of this source code is governed by a Apache-2.0 style 3 // license that can be found in the LICENSE file. 4 5 package db 6 7 import ( 8 "database/sql" 9 "fmt" 10 "regexp" 11 "strconv" 12 "strings" 13 14 "github.com/kotovmak/go-admin/modules/config" 15 ) 16 17 // Mssql is a Connection of mssql. 18 type Mssql struct { 19 Base 20 } 21 22 // GetMssqlDB return the global mssql connection. 23 func GetMssqlDB() *Mssql { 24 return &Mssql{ 25 Base: Base{ 26 DbList: make(map[string]*sql.DB), 27 }, 28 } 29 } 30 31 // GetDelimiter implements the method Connection.GetDelimiter. 32 func (db *Mssql) GetDelimiter() string { 33 return "[" 34 } 35 36 // GetDelimiter2 implements the method Connection.GetDelimiter2. 37 func (db *Mssql) GetDelimiter2() string { 38 return "]" 39 } 40 41 // GetDelimiters implements the method Connection.GetDelimiters. 42 func (db *Mssql) GetDelimiters() []string { 43 return []string{"[", "]"} 44 } 45 46 // Name implements the method Connection.Name. 47 func (db *Mssql) Name() string { 48 return "mssql" 49 } 50 51 // TODO: 整理优化 52 53 func replaceStringFunc(pattern, src string, rpl func(s string) string) (string, error) { 54 55 r, err := regexp.Compile(pattern) 56 if err != nil { 57 return "", err 58 } 59 60 bytes := r.ReplaceAllFunc([]byte(src), func(bytes []byte) []byte { 61 return []byte(rpl(string(bytes))) 62 }) 63 64 return string(bytes), nil 65 } 66 67 func replace(pattern string, replace, src []byte) ([]byte, error) { 68 69 r, err := regexp.Compile(pattern) 70 if err != nil { 71 return nil, err 72 } 73 74 return r.ReplaceAll(src, replace), nil 75 } 76 77 func replaceString(pattern, rep, src string) (string, error) { 78 r, e := replace(pattern, []byte(rep), []byte(src)) 79 return string(r), e 80 } 81 82 func matchAllString(pattern string, src string) ([][]string, error) { 83 r, err := regexp.Compile(pattern) 84 if err != nil { 85 return nil, err 86 } 87 return r.FindAllStringSubmatch(src, -1), nil 88 } 89 90 func isMatch(pattern string, src []byte) bool { 91 r, err := regexp.Compile(pattern) 92 if err != nil { 93 return false 94 } 95 return r.Match(src) 96 } 97 98 func isMatchString(pattern string, src string) bool { 99 return isMatch(pattern, []byte(src)) 100 } 101 102 func matchString(pattern string, src string) ([]string, error) { 103 r, err := regexp.Compile(pattern) 104 if err != nil { 105 return nil, err 106 } 107 return r.FindStringSubmatch(src), nil 108 } 109 110 // 从Gf框架复制 111 // 在执行sql之前对sql进行进一步处理 112 func (db *Mssql) handleSqlBeforeExec(query string) string { 113 index := 0 114 str, _ := replaceStringFunc("\\?", query, func(s string) string { 115 index++ 116 return fmt.Sprintf("@p%d", index) 117 }) 118 119 str, _ = replaceString("\"", "", str) 120 121 return db.parseSql(str) 122 } 123 124 // 将MYSQL的SQL语法转换为MSSQL的语法 125 // 1.由于mssql不支持limit写法所以需要对mysql中的limit用法做转换 126 func (db *Mssql) parseSql(sql string) string { 127 //下面的正则表达式匹配出SELECT和INSERT的关键字后分别做不同的处理,如有LIMIT则将LIMIT的关键字也匹配出 128 patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))` 129 if !isMatchString(patten, sql) { 130 //fmt.Println("not matched..") 131 return sql 132 } 133 134 res, err := matchAllString(patten, sql) 135 if err != nil { 136 //fmt.Println("MatchString error.", err) 137 return "" 138 } 139 140 index := 0 141 keyword := strings.TrimSpace(res[index][0]) 142 keyword = strings.ToUpper(keyword) 143 144 index++ 145 switch keyword { 146 case "SELECT": 147 //不含LIMIT关键字则不处理 148 if len(res) < 2 || (!strings.HasPrefix(res[index][0], "LIMIT") && !strings.HasPrefix(res[index][0], "limit")) { 149 break 150 } 151 152 //不含LIMIT则不处理 153 if !isMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) { 154 break 155 } 156 157 //判断SQL中是否含有order by 158 selectStr := "" 159 orderbyStr := "" 160 haveOrderby := isMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql) 161 if haveOrderby { 162 //取order by 前面的字符串 163 queryExpr, _ := matchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql) 164 165 if len(queryExpr) != 4 || !strings.EqualFold(queryExpr[1], "SELECT") || !strings.EqualFold(queryExpr[3], "ORDER BY") { 166 break 167 } 168 selectStr = queryExpr[2] 169 170 //取order by表达式的值 171 orderbyExpr, _ := matchString("((?i)ORDER BY)(.+)((?i)LIMIT)", sql) 172 if len(orderbyExpr) != 4 || !strings.EqualFold(orderbyExpr[1], "ORDER BY") || !strings.EqualFold(orderbyExpr[3], "LIMIT") { 173 break 174 } 175 orderbyStr = orderbyExpr[2] 176 } else { 177 queryExpr, _ := matchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) 178 if len(queryExpr) != 4 || !strings.EqualFold(queryExpr[1], "SELECT") || !strings.EqualFold(queryExpr[3], "LIMIT") { 179 break 180 } 181 selectStr = queryExpr[2] 182 } 183 184 //取limit后面的取值范围 185 first, limit := 0, 0 186 for i := 1; i < len(res[index]); i++ { 187 if strings.TrimSpace(res[index][i]) == "" { 188 continue 189 } 190 191 if strings.HasPrefix(res[index][i], "LIMIT") || strings.HasPrefix(res[index][i], "limit") { 192 first, _ = strconv.Atoi(res[index][i+1]) 193 limit, _ = strconv.Atoi(res[index][i+2]) 194 break 195 } 196 } 197 198 if haveOrderby { 199 sql = fmt.Sprintf("SELECT * FROM (SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d", orderbyStr, selectStr, first, limit) 200 } else { 201 if first == 0 { 202 first = limit 203 } else { 204 first = limit - first 205 } 206 sql = fmt.Sprintf("SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ", first, limit, selectStr) 207 } 208 default: 209 } 210 return sql 211 } 212 213 // QueryWithConnection implements the method Connection.QueryWithConnection. 214 func (db *Mssql) QueryWithConnection(con string, query string, args ...interface{}) ([]map[string]interface{}, error) { 215 query = db.handleSqlBeforeExec(query) 216 return CommonQuery(db.DbList[con], query, args...) 217 } 218 219 // ExecWithConnection implements the method Connection.ExecWithConnection. 220 func (db *Mssql) ExecWithConnection(con string, query string, args ...interface{}) (sql.Result, error) { 221 query = db.handleSqlBeforeExec(query) 222 return CommonExec(db.DbList[con], query, args...) 223 } 224 225 // Query implements the method Connection.Query. 226 func (db *Mssql) Query(query string, args ...interface{}) ([]map[string]interface{}, error) { 227 query = db.handleSqlBeforeExec(query) 228 return CommonQuery(db.DbList["default"], query, args...) 229 } 230 231 // Exec implements the method Connection.Exec. 232 func (db *Mssql) Exec(query string, args ...interface{}) (sql.Result, error) { 233 query = db.handleSqlBeforeExec(query) 234 return CommonExec(db.DbList["default"], query, args...) 235 } 236 237 func (db *Mssql) QueryWith(tx *sql.Tx, conn, query string, args ...interface{}) ([]map[string]interface{}, error) { 238 if tx != nil { 239 return db.QueryWithTx(tx, query, args...) 240 } 241 return db.QueryWithConnection(conn, query, args...) 242 } 243 244 func (db *Mssql) ExecWith(tx *sql.Tx, conn, query string, args ...interface{}) (sql.Result, error) { 245 if tx != nil { 246 return db.ExecWithTx(tx, query, args...) 247 } 248 return db.ExecWithConnection(conn, query, args...) 249 } 250 251 // InitDB implements the method Connection.InitDB. 252 func (db *Mssql) InitDB(cfgs map[string]config.Database) Connection { 253 db.Configs = cfgs 254 db.Once.Do(func() { 255 for conn, cfg := range cfgs { 256 257 sqlDB, err := sql.Open("sqlserver", cfg.GetDSN()) 258 259 if sqlDB == nil { 260 panic("invalid connection") 261 } 262 263 if err != nil { 264 _ = sqlDB.Close() 265 panic(err.Error()) 266 } 267 268 sqlDB.SetMaxIdleConns(cfg.MaxIdleConns) 269 sqlDB.SetMaxOpenConns(cfg.MaxOpenConns) 270 sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime) 271 sqlDB.SetConnMaxIdleTime(cfg.ConnMaxIdleTime) 272 273 db.DbList[conn] = sqlDB 274 275 if err := sqlDB.Ping(); err != nil { 276 panic(err) 277 } 278 } 279 }) 280 return db 281 } 282 283 // BeginTxWithReadUncommitted starts a transaction with level LevelReadUncommitted. 284 func (db *Mssql) BeginTxWithReadUncommitted() *sql.Tx { 285 return CommonBeginTxWithLevel(db.DbList["default"], sql.LevelReadUncommitted) 286 } 287 288 // BeginTxWithReadCommitted starts a transaction with level LevelReadCommitted. 289 func (db *Mssql) BeginTxWithReadCommitted() *sql.Tx { 290 return CommonBeginTxWithLevel(db.DbList["default"], sql.LevelReadCommitted) 291 } 292 293 // BeginTxWithRepeatableRead starts a transaction with level LevelRepeatableRead. 294 func (db *Mssql) BeginTxWithRepeatableRead() *sql.Tx { 295 return CommonBeginTxWithLevel(db.DbList["default"], sql.LevelRepeatableRead) 296 } 297 298 // BeginTx starts a transaction with level LevelDefault. 299 func (db *Mssql) BeginTx() *sql.Tx { 300 return CommonBeginTxWithLevel(db.DbList["default"], sql.LevelDefault) 301 } 302 303 // BeginTxWithLevel starts a transaction with given transaction isolation level. 304 func (db *Mssql) BeginTxWithLevel(level sql.IsolationLevel) *sql.Tx { 305 return CommonBeginTxWithLevel(db.DbList["default"], level) 306 } 307 308 // BeginTxWithReadUncommittedAndConnection starts a transaction with level LevelReadUncommitted and connection. 309 func (db *Mssql) BeginTxWithReadUncommittedAndConnection(conn string) *sql.Tx { 310 return CommonBeginTxWithLevel(db.DbList[conn], sql.LevelReadUncommitted) 311 } 312 313 // BeginTxWithReadCommittedAndConnection starts a transaction with level LevelReadCommitted and connection. 314 func (db *Mssql) BeginTxWithReadCommittedAndConnection(conn string) *sql.Tx { 315 return CommonBeginTxWithLevel(db.DbList[conn], sql.LevelReadCommitted) 316 } 317 318 // BeginTxWithRepeatableReadAndConnection starts a transaction with level LevelRepeatableRead and connection. 319 func (db *Mssql) BeginTxWithRepeatableReadAndConnection(conn string) *sql.Tx { 320 return CommonBeginTxWithLevel(db.DbList[conn], sql.LevelRepeatableRead) 321 } 322 323 // BeginTxAndConnection starts a transaction with level LevelDefault and connection. 324 func (db *Mssql) BeginTxAndConnection(conn string) *sql.Tx { 325 return CommonBeginTxWithLevel(db.DbList[conn], sql.LevelDefault) 326 } 327 328 // BeginTxWithLevelAndConnection starts a transaction with given transaction isolation level and connection. 329 func (db *Mssql) BeginTxWithLevelAndConnection(conn string, level sql.IsolationLevel) *sql.Tx { 330 return CommonBeginTxWithLevel(db.DbList[conn], level) 331 } 332 333 // QueryWithTx is query method within the transaction. 334 func (db *Mssql) QueryWithTx(tx *sql.Tx, query string, args ...interface{}) ([]map[string]interface{}, error) { 335 query = db.handleSqlBeforeExec(query) 336 return CommonQueryWithTx(tx, query, args...) 337 } 338 339 // ExecWithTx is exec method within the transaction. 340 func (db *Mssql) ExecWithTx(tx *sql.Tx, query string, args ...interface{}) (sql.Result, error) { 341 query = db.handleSqlBeforeExec(query) 342 return CommonExecWithTx(tx, query, args...) 343 }