github.com/zlyuancn/zstr@v0.0.0-20230412074414-14d6b645962f/sql_template.go (about) 1 /* 2 ------------------------------------------------- 3 Author : Zhang Fan 4 date: 2020/7/18 5 Description : 6 ------------------------------------------------- 7 */ 8 9 package zstr 10 11 import ( 12 "bytes" 13 "fmt" 14 "reflect" 15 "strconv" 16 "strings" 17 ) 18 19 const defaultSqlCompareFlag = "=" 20 21 var ( 22 // 操作符 23 sqlTemplateOperationMapp = map[int32]struct{}{ 24 '&': {}, 25 '|': {}, 26 '#': {}, 27 '@': {}, 28 } 29 // 标记 30 sqlTemplateFlagMapp = map[string]struct{}{ 31 ">": {}, 32 ">=": {}, 33 "<": {}, 34 "<=": {}, 35 "!=": {}, 36 "<>": {}, 37 "=": {}, 38 "in": {}, 39 "notin": {}, 40 "not_in": {}, 41 "like": {}, 42 "likestart": {}, 43 "like_start": {}, 44 "likeend": {}, 45 "like_end": {}, 46 } 47 // 选项 48 sqlTemplateOptsMapp = map[int32]struct{}{ 49 'a': {}, // attention, 不会忽略参数值为该类型的零值 50 'd': {}, // direct, 直接将值写入sql语句 51 'm': {}, // must, 必填 52 } 53 ) 54 55 type sqlTemplate struct { 56 data map[string]interface{} 57 names []string 58 values []interface{} 59 keyCounter *counter // key计数器 60 sub int // 下标计数器 61 } 62 63 func newSqlTemplate(values []interface{}) *sqlTemplate { 64 return &sqlTemplate{ 65 data: makeMapOfValues(values), 66 keyCounter: newCounter(-1), 67 } 68 } 69 70 func (m *sqlTemplate) calculateTemplate(ss []rune, start int) (int, int, bool, bool) { 71 var crust, has, ok bool 72 // 查找开头 73 for i := start; i < len(ss); i++ { 74 if ss[i] == '{' { 75 start, crust, has = i, true, true 76 break 77 } 78 if _, ok = sqlTemplateOperationMapp[ss[i]]; ok { 79 start, crust, has = i, false, true 80 break 81 } 82 } 83 if !has { 84 return 0, 0, false, false 85 } 86 87 // 预检 88 if crust && (len(ss)-start < 4) || (len(ss)-start < 2) { 89 return 0, 0, false, false 90 } 91 92 if !crust { 93 for i := start + 1; i < len(ss); i++ { 94 _, ok = templateVariableNameMap[ss[i]] 95 if !ok { // 表示查找变量结束了 96 if i-start < 2 || ss[i-1] == '.' { // 操作符占一个位置, 变量长度不可能为0 97 return m.calculateTemplate(ss, i) 98 } 99 return start, i, false, true // 中间的数据就是需要的变量 100 } 101 } 102 // 可能整个字符串都是需要的数据 103 return start, len(ss), false, len(ss)-start >= 2 && ss[len(ss)-1] != '.' 104 } 105 106 // 以下包含{ 107 for i := start + 1; i < len(ss); i++ { 108 if ss[i] != '}' { 109 continue 110 } 111 return start, i + 1, true, true 112 } 113 return 0, 0, false, false 114 } 115 116 func (m *sqlTemplate) replaceAllFunc(s string, fn func(s string, crust bool) string) string { 117 ss := []rune(s) 118 var buff bytes.Buffer 119 for offset := 0; offset < len(ss); { 120 start, end, crust, has := m.calculateTemplate(ss, offset) 121 if !has { 122 buff.WriteString(string(ss[offset:])) 123 break 124 } 125 126 buff.WriteString(string(ss[offset:start])) 127 buff.WriteString(fn(string(ss[start:end]), crust)) 128 offset = end 129 } 130 return buff.String() 131 } 132 133 func (m *sqlTemplate) addValue(name string, value interface{}) { 134 m.names = append(m.names, name) 135 m.values = append(m.values, value) 136 } 137 138 func (m *sqlTemplate) Parse(sql_template string) (sql_str string, names []string, args []interface{}) { 139 sql_str = m.replaceAllFunc(sql_template, func(s string, crust bool) string { 140 if crust { 141 s = s[1 : len(s)-1] 142 } 143 144 operation, name, flag, opts, err := m.sqlTemplateSyntaxParse(s) 145 if err != nil { 146 panic(err) 147 } 148 return m.translate(operation, name, flag, opts) 149 }) 150 return m.repairSql(sql_str), m.names, m.values 151 } 152 153 func (m *sqlTemplate) translate(operation, name, flag string, opts string) string { 154 // 选项检查 155 var attention_opt, direct_opt, must_opt bool 156 for _, o := range opts { 157 switch o { 158 case 'a': 159 attention_opt = true 160 case 'd': 161 direct_opt = true 162 case 'm': 163 must_opt = true 164 default: 165 panic(fmt.Sprintf(`syntax error, non-supported option "%s"`, string(o))) 166 } 167 } 168 switch operation { 169 case "#": 170 attention_opt = true 171 case "@": 172 attention_opt = false 173 direct_opt = true 174 } 175 176 vName := name + "[" + strconv.Itoa(m.keyCounter.Incr(name)) + "]" 177 value, has := m.data[vName] 178 if !has { 179 vName = name 180 value, has = m.data[name] 181 } 182 if !has { 183 vName = "*[" + strconv.Itoa(m.sub) + "]" 184 value, has = m.data[vName] 185 } 186 m.sub++ // 每次一定+1 187 188 // 无值返回空sql语句 189 if !has { 190 if must_opt { 191 panic(fmt.Sprintf(`"%s" must have a value`, name)) 192 } 193 return "" 194 } 195 196 // 非注意模式且值为零值返回空sql语句 197 if !attention_opt && IsZero(value) { 198 return "" 199 } 200 201 // 操作检查 202 switch operation { 203 case "&": 204 operation = "and" 205 case "|": 206 operation = "or" 207 case "#": 208 // nil改为null 209 if value == nil { 210 return "null" 211 } 212 if direct_opt { 213 return anyToSqlString(value, true) 214 } 215 m.addValue(vName, value) 216 return "?" 217 case "@": // !attention_opt + direct 218 return anyToSqlString(value, false) 219 default: 220 panic(fmt.Errorf(`syntax error, non-supported operation "%s"`, operation)) 221 } 222 223 // nil 修改语句 224 if value == nil { 225 switch flag { 226 case "!=", "<>", "notin", "not_in", ">", "<": 227 return fmt.Sprintf(`%s %s is not null`, operation, name) 228 case "=", "like", "likestart", "like_start", "likeend", "like_end": 229 return fmt.Sprintf(`%s %s is null`, operation, name) 230 case "in", ">=", "<=": 231 return "" 232 } 233 } 234 235 var makeSqlStr func() string 236 var directWrite func() string 237 // 标记 238 switch flag { 239 case ">", ">=", "<", "<=", "!=", "<>", "=": 240 makeSqlStr = func() string { 241 m.addValue(vName, value) 242 return fmt.Sprintf(`%s %s %s ?`, operation, name, flag) 243 } 244 directWrite = func() string { 245 return fmt.Sprintf(`%s %s %s %s`, operation, name, flag, anyToSqlString(value, true)) 246 } 247 case "in": 248 values := m.parseToSlice(value) 249 if len(values) == 0 { 250 return "" 251 } 252 makeSqlStr = func() string { 253 if len(values) == 1 { 254 m.addValue(vName, values[0]) 255 return fmt.Sprintf(`%s %s = ?`, operation, name) 256 } 257 fs := make([]string, len(values)) 258 for i, s := range values { 259 m.addValue(fmt.Sprintf("%s.in(%d)", vName, i), s) 260 fs[i] = "?" 261 } 262 return fmt.Sprintf(`%s %s in (%s)`, operation, name, strings.Join(fs, ",")) 263 } 264 directWrite = func() string { 265 if len(values) == 1 { 266 return fmt.Sprintf(`%s %s = %s`, operation, name, anyToSqlString(values[0], true)) 267 } 268 return fmt.Sprintf(`%s %s in %s`, operation, name, anyToSqlString(value, true)) 269 } 270 case "notin", "not_in": 271 values := m.parseToSlice(value) 272 if len(values) == 0 { 273 return "" 274 } 275 makeSqlStr = func() string { 276 if len(values) == 1 { 277 m.addValue(vName, values[0]) 278 return fmt.Sprintf(`%s %s != ?`, operation, name) 279 } 280 fs := make([]string, len(values)) 281 for i, s := range values { 282 m.addValue(fmt.Sprintf("%s.not_in(%d)", vName, i), s) 283 fs[i] = "?" 284 } 285 return fmt.Sprintf(`%s %s not in (%s)`, operation, name, strings.Join(fs, ",")) 286 } 287 directWrite = func() string { 288 if len(values) == 1 { 289 return fmt.Sprintf(`%s %s != %s`, operation, name, anyToSqlString(values[0], true)) 290 } 291 return fmt.Sprintf(`%s %s not in %s`, operation, name, anyToSqlString(value, true)) 292 } 293 case "like": // 包含xx 294 makeSqlStr = func() string { 295 m.addValue(vName, "%"+anyToSqlString(value, false)+"%") 296 return fmt.Sprintf(`%s %s like ?`, operation, name) 297 } 298 directWrite = func() string { 299 return fmt.Sprintf(`%s %s like '%%%s%%'`, operation, name, anyToSqlString(value, false)) 300 } 301 case "likestart", "like_start": // 以xx开始 302 makeSqlStr = func() string { 303 m.addValue(vName, anyToSqlString(value, false)+"%") 304 return fmt.Sprintf(`%s %s like ?`, operation, name) 305 } 306 directWrite = func() string { 307 return fmt.Sprintf(`%s %s like '%s%%'`, operation, name, anyToSqlString(value, false)) 308 } 309 case "likeend", "like_end": // 以xx结束 310 makeSqlStr = func() string { 311 m.addValue(vName, "%"+anyToSqlString(value, false)) 312 return fmt.Sprintf(`%s %s like ?`, operation, name) 313 } 314 directWrite = func() string { 315 return fmt.Sprintf(`%s %s like '%%%s'`, operation, name, anyToSqlString(value, false)) 316 } 317 default: 318 panic(fmt.Errorf(`syntax error, non-supported flag "%s"`, flag)) 319 } 320 321 // 直接模式, 将值写入sql语句 322 if direct_opt { 323 return directWrite() 324 } 325 return makeSqlStr() 326 } 327 328 func (m *sqlTemplate) Render(sql_template string) string { 329 result := m.replaceAllFunc(sql_template, func(s string, crust bool) string { 330 if crust { 331 s = s[1 : len(s)-1] 332 } 333 334 operation, name, flag, opts, err := m.sqlTemplateSyntaxParse(s) 335 if err != nil { 336 panic(err) 337 } 338 return m.sqlTranslate(operation, name, flag, opts) 339 }) 340 return m.repairSql(result) 341 } 342 343 func (m *sqlTemplate) sqlTranslate(operation, name, flag string, opts string) string { 344 // 选项检查 345 var attention_opt, must_opt bool 346 for _, o := range opts { 347 switch o { 348 case 'a': 349 attention_opt = true 350 case 'd': 351 case 'm': 352 must_opt = true 353 default: 354 panic(fmt.Sprintf(`syntax error, non-supported option "%s"`, string(o))) 355 } 356 } 357 switch operation { 358 case "#": 359 attention_opt = true 360 case "@": 361 attention_opt = false 362 } 363 364 value, has := m.data[name+"["+strconv.Itoa(m.keyCounter.Incr(name))+"]"] 365 if !has { 366 value, has = m.data[name] 367 } 368 if !has { 369 value, has = m.data["*["+strconv.Itoa(m.sub)+"]"] 370 } 371 m.sub++ // 每次一定+1 372 373 // 无值返回空sql语句 374 if !has { 375 if must_opt { 376 panic(fmt.Sprintf(`"%s" must have a value`, name)) 377 } 378 return "" 379 } 380 381 // 非注意模式, 零值返回空sql语句 382 if !attention_opt && IsZero(value) { 383 return "" 384 } 385 386 switch operation { 387 case "&": 388 operation = "and" 389 case "|": 390 operation = "or" 391 case "#": 392 // nil改为null 393 if value == nil { 394 return "null" 395 } 396 return anyToSqlString(value, true) 397 case "@": 398 return anyToSqlString(value, false) 399 default: 400 panic(fmt.Errorf(`syntax error, non-supported operation "%s"`, operation)) 401 } 402 403 // nil 修改语句 404 if value == nil { 405 switch flag { 406 case "!=", "<>", "notin", "not_in", ">", "<": 407 return fmt.Sprintf(`%s %s is not null`, operation, name) 408 case "=", "like", "likestart", "like_start", "likeend", "like_end": 409 return fmt.Sprintf(`%s %s is null`, operation, name) 410 case "in", ">=", "<=": 411 return "" 412 } 413 } 414 415 var sql_str string 416 switch flag { 417 case ">", ">=", "<", "<=", "!=", "<>", "=": 418 sql_str = fmt.Sprintf(`%s %s %s %s`, operation, name, flag, anyToSqlString(value, true)) 419 case "in": 420 values := m.parseToSlice(value) 421 if len(values) == 0 { 422 return "" 423 } 424 if len(values) == 1 { 425 return fmt.Sprintf(`%s %s = %s`, operation, name, anyToSqlString(values[0], true)) 426 } 427 sql_str = fmt.Sprintf(`%s %s in %s`, operation, name, anyToSqlString(value, true)) 428 case "notin", "not_in": 429 values := m.parseToSlice(value) 430 if len(values) == 0 { 431 return "" 432 } 433 if len(values) == 1 { 434 return fmt.Sprintf(`%s %s != %s`, operation, name, anyToSqlString(values[0], true)) 435 } 436 sql_str = fmt.Sprintf(`%s %s not in %s`, operation, name, anyToSqlString(value, true)) 437 case "like": // 包含xx 438 sql_str = fmt.Sprintf(`%s %s like '%%%s%%'`, operation, name, anyToSqlString(value, false)) 439 case "likestart", "like_start": // 以xx开始 440 sql_str = fmt.Sprintf(`%s %s like '%s%%'`, operation, name, anyToSqlString(value, false)) 441 case "likeend", "like_end": // 以xx结束 442 sql_str = fmt.Sprintf(`%s %s like '%%%s'`, operation, name, anyToSqlString(value, false)) 443 default: 444 panic(fmt.Errorf(`syntax error, non-supported flag "%s"`, flag)) 445 } 446 447 return sql_str 448 } 449 450 // 将数据解析为切片 451 func (m *sqlTemplate) parseToSlice(a interface{}) []interface{} { 452 switch v := a.(type) { 453 454 case nil: 455 return []interface{}{"null"} 456 457 case string, []byte, bool, 458 int, int8, int16, int32, int64, 459 uint, uint8, uint16, uint32, uint64, 460 float32, float64: 461 return []interface{}{v} 462 } 463 464 r_v := reflect.Indirect(reflect.ValueOf(a)) 465 if r_v.Kind() != reflect.Slice && r_v.Kind() != reflect.Array { 466 return []interface{}{fmt.Sprint(a)} 467 } 468 469 l := r_v.Len() 470 out := make([]interface{}, 0, l) 471 for i := 0; i < l; i++ { 472 v := reflect.Indirect(r_v.Index(i)).Interface() 473 out = append(out, m.parseToSlice(v)...) 474 } 475 return out 476 } 477 478 // sql模板语法解析 479 // 480 // 语法格式: (操作符)(name) 481 // 语法格式: {(操作符)(name)} 482 // 语法格式: {(操作符)(name) (标志)} 483 // 语法格式: {(操作符)(name) (标志) (选项)} 484 // 语法格式: {(操作符)(name) (选项)} 485 // 486 // 操作符: 487 // 488 // &: 转为 and name flag value 489 // |: 转为 or name flag value 490 // #: 转为 value, 自带 attention 选项, 仅支持以下格式 491 // (操作符)(name) 492 // {(操作符)(name)} 493 // {(操作符)(name) (选项)} 494 // @: attention 选项无效且自带 direct 选项, 且不会为字符串加上引号, 仅支持以下格式, 一般用于写入一条语句 495 // (操作符)(name) 496 // {(操作符)(name)} 497 // {(操作符)(name) (选项)} 498 // 499 // name: 示例: a a2 a_2 a_2.b a_2.b_2 500 // 501 // 标志: > >= < <= != <> = in notin not_in like likestart like_start likeend like_end 502 // 503 // 选项: 504 // 505 // a: attention, 不会忽略参数值为该类型的零值 506 // d: direct, 直接将值写入sql语句中 507 // m: must, 必须传值, 值可以为零值 508 // 509 // 输入的values必须为:map[string]string, map[string]interface{},或按顺序传入值 510 // 511 // 寻值优先级: 512 // 513 // 匹配名下标 > 匹配名 > *下标 514 // 如: a[0] > a > *[0] 515 // 516 // 注意: 517 // 518 // 一般情况下如果name没有传参或为该类型的零值, 则替换为空字符串 519 // 如果name的值为nil, 不同的标志会转为不同的语句 520 // 我们不会去检查name是否完全符合变量名标志, 因为这是无意义且消耗资源的 521 // 变量名首位可以为数字, 变量中间可以连续出现多个小数点, 如 0..a 是合法的 522 // 523 // 示例: 524 // 525 // s := SqlRender("select * from t where &a {&b} {&c !=} {&d in} {|e} limit 1", map[string]interface{}{ 526 // "a": 1, 527 // "b[0]": "2", 528 // "*[2]": 3.3, 529 // "d": []string{"4"}, 530 // "e": nil, 531 // }) 532 func (m *sqlTemplate) sqlTemplateSyntaxParse(text string) (operation, name, flag, opts string, err error) { 533 // 去头去尾 534 temp := strings.TrimSpace(text) 535 // 空数据 536 if temp == "" { 537 err = fmt.Errorf("syntax error, {%s}, empty data", text) 538 return 539 } 540 541 // 分离操作符 542 operation, temp = temp[:1], temp[1:] 543 544 // 缩进空格 545 temp = m.retractAllSpace(temp) 546 547 // 分离数据 548 texts := strings.SplitN(temp, " ", 4) // 4为考虑尾部可能有空格 549 if len(texts) >= 1 { 550 name = texts[0] 551 } 552 if len(texts) >= 2 { 553 flag = texts[1] 554 } else { 555 flag = defaultSqlCompareFlag 556 } 557 if len(texts) >= 3 { 558 opts = texts[2] 559 } 560 if len(texts) >= 4 && texts[3] != " " { 561 err = fmt.Errorf("syntax error, {%s}, redundant data", text) 562 return 563 } 564 565 // 检查操作符 566 if _, ok := sqlTemplateOperationMapp[int32(operation[0])]; !ok { 567 err = fmt.Errorf(`syntax error, {%s}, non-supported operation "%s"`, text, operation) 568 return 569 } 570 571 // 检查变量名 572 if name == "" { 573 err = fmt.Errorf("syntax error, {%s}, no variable name", text) 574 return 575 } 576 577 if name[0] == '.' || name[len(name)-1] == '.' { 578 err = fmt.Errorf("syntax error, {%s}, Invalid variable name", text) 579 return 580 } 581 for _, v := range []rune(name) { 582 if _, ok := templateVariableNameMap[v]; !ok { 583 err = fmt.Errorf("syntax error, {%s}, Invalid variable name", text) 584 return 585 } 586 } 587 588 // 检查标记 589 if _, ok := sqlTemplateFlagMapp[flag]; !ok { 590 if opts != "" { 591 err = fmt.Errorf(`syntax error, {%s}, non-supported flag "%s"`, text, flag) 592 return 593 } 594 flag, opts = defaultSqlCompareFlag, flag 595 } 596 597 // 检查选项 598 os := make(map[int32]struct{}) 599 for _, o := range opts { 600 if _, ok := sqlTemplateOptsMapp[o]; !ok { 601 err = fmt.Errorf(`syntax error, {%s}, non-supported option "%s"`, text, string(o)) 602 return 603 } 604 // 重复选项 605 if _, ok := os[o]; ok { 606 err = fmt.Errorf(`syntax error, {%s}, repetitive option "%s"`, text, string(o)) 607 return 608 } 609 os[o] = struct{}{} 610 } 611 612 return 613 } 614 615 // sql模板解析, 和 SqlParse 一样, 只是加长了函数名 616 func SqlTemplateParse(sqlTemplate string, values ...interface{}) (sql_str string, names []string, args []interface{}) { 617 return newSqlTemplate(values).Parse(sqlTemplate) 618 } 619 620 // sql模板解析 621 func SqlParse(sqlTemplate string, values ...interface{}) (sql_str string, names []string, args []interface{}) { 622 return newSqlTemplate(values).Parse(sqlTemplate) 623 } 624 625 // sql模板渲染, 和 SqlRender 一样, 只是加长了函数名 626 func SqlTemplateRender(sqlTemplate string, values ...interface{}) string { 627 return newSqlTemplate(values).Render(sqlTemplate) 628 } 629 630 // sql模板渲染(不推荐) 631 // 632 // 值会直接写入sql语句中, 不支持sql注入检查 633 func SqlRender(sqlTemplate string, values ...interface{}) string { 634 return newSqlTemplate(values).Render(sqlTemplate) 635 }