github.com/zlyuancn/zstr@v0.0.0-20230412074414-14d6b645962f/sql_template.go (about)

     1  /*
     2  -------------------------------------------------
     3     Author :       Zhang Fan
     4     date:         2020/7/18
     5     Description :
     6  -------------------------------------------------
     7  */
     8  
     9  package zstr
    10  
    11  import (
    12  	"bytes"
    13  	"fmt"
    14  	"reflect"
    15  	"strconv"
    16  	"strings"
    17  )
    18  
// defaultSqlCompareFlag is the comparison flag applied when a template
// segment does not name one (see sqlTemplateSyntaxParse).
const defaultSqlCompareFlag = "="
    20  
// Lookup tables used to recognise and validate the pieces of a template segment.
var (
	// Operators: the single character that opens a template segment.
	sqlTemplateOperationMapp = map[int32]struct{}{
		'&': {},
		'|': {},
		'#': {},
		'@': {},
	}
	// Flags: the comparison keywords accepted after the variable name.
	sqlTemplateFlagMapp = map[string]struct{}{
		">":          {},
		">=":         {},
		"<":          {},
		"<=":         {},
		"!=":         {},
		"<>":         {},
		"=":          {},
		"in":         {},
		"notin":      {},
		"not_in":     {},
		"like":       {},
		"likestart":  {},
		"like_start": {},
		"likeend":    {},
		"like_end":   {},
	}
	// Options: single-letter switches that may follow the flag.
	sqlTemplateOptsMapp = map[int32]struct{}{
		'a': {}, // attention: do not drop the segment when the value is the zero value of its type
		'd': {}, // direct: write the value straight into the SQL text
		'm': {}, // must: a value is required
	}
)
    54  
// sqlTemplate holds the state of one template expansion.
//
// data is the lookup table built from the caller's values. names and values
// are parallel slices filled by addValue with the lookup key and the bound
// value of every "?" placeholder emitted so far. keyCounter supplies the
// index used for "name[N]" lookups and sub the index used for "*[N]" lookups.
type sqlTemplate struct {
	data       map[string]interface{}
	names      []string
	values     []interface{}
	keyCounter *counter // per-key counter
	sub        int      // positional index counter, incremented once per translated segment
}
    62  
    63  func newSqlTemplate(values []interface{}) *sqlTemplate {
    64  	return &sqlTemplate{
    65  		data:       makeMapOfValues(values),
    66  		keyCounter: newCounter(-1),
    67  	}
    68  }
    69  
    70  func (m *sqlTemplate) calculateTemplate(ss []rune, start int) (int, int, bool, bool) {
    71  	var crust, has, ok bool
    72  	// 查找开头
    73  	for i := start; i < len(ss); i++ {
    74  		if ss[i] == '{' {
    75  			start, crust, has = i, true, true
    76  			break
    77  		}
    78  		if _, ok = sqlTemplateOperationMapp[ss[i]]; ok {
    79  			start, crust, has = i, false, true
    80  			break
    81  		}
    82  	}
    83  	if !has {
    84  		return 0, 0, false, false
    85  	}
    86  
    87  	// 预检
    88  	if crust && (len(ss)-start < 4) || (len(ss)-start < 2) {
    89  		return 0, 0, false, false
    90  	}
    91  
    92  	if !crust {
    93  		for i := start + 1; i < len(ss); i++ {
    94  			_, ok = templateVariableNameMap[ss[i]]
    95  			if !ok { // 表示查找变量结束了
    96  				if i-start < 2 || ss[i-1] == '.' { // 操作符占一个位置, 变量长度不可能为0
    97  					return m.calculateTemplate(ss, i)
    98  				}
    99  				return start, i, false, true // 中间的数据就是需要的变量
   100  			}
   101  		}
   102  		// 可能整个字符串都是需要的数据
   103  		return start, len(ss), false, len(ss)-start >= 2 && ss[len(ss)-1] != '.'
   104  	}
   105  
   106  	// 以下包含{
   107  	for i := start + 1; i < len(ss); i++ {
   108  		if ss[i] != '}' {
   109  			continue
   110  		}
   111  		return start, i + 1, true, true
   112  	}
   113  	return 0, 0, false, false
   114  }
   115  
   116  func (m *sqlTemplate) replaceAllFunc(s string, fn func(s string, crust bool) string) string {
   117  	ss := []rune(s)
   118  	var buff bytes.Buffer
   119  	for offset := 0; offset < len(ss); {
   120  		start, end, crust, has := m.calculateTemplate(ss, offset)
   121  		if !has {
   122  			buff.WriteString(string(ss[offset:]))
   123  			break
   124  		}
   125  
   126  		buff.WriteString(string(ss[offset:start]))
   127  		buff.WriteString(fn(string(ss[start:end]), crust))
   128  		offset = end
   129  	}
   130  	return buff.String()
   131  }
   132  
   133  func (m *sqlTemplate) addValue(name string, value interface{}) {
   134  	m.names = append(m.names, name)
   135  	m.values = append(m.values, value)
   136  }
   137  
   138  func (m *sqlTemplate) Parse(sql_template string) (sql_str string, names []string, args []interface{}) {
   139  	sql_str = m.replaceAllFunc(sql_template, func(s string, crust bool) string {
   140  		if crust {
   141  			s = s[1 : len(s)-1]
   142  		}
   143  
   144  		operation, name, flag, opts, err := m.sqlTemplateSyntaxParse(s)
   145  		if err != nil {
   146  			panic(err)
   147  		}
   148  		return m.translate(operation, name, flag, opts)
   149  	})
   150  	return m.repairSql(sql_str), m.names, m.values
   151  }
   152  
   153  func (m *sqlTemplate) translate(operation, name, flag string, opts string) string {
   154  	// 选项检查
   155  	var attention_opt, direct_opt, must_opt bool
   156  	for _, o := range opts {
   157  		switch o {
   158  		case 'a':
   159  			attention_opt = true
   160  		case 'd':
   161  			direct_opt = true
   162  		case 'm':
   163  			must_opt = true
   164  		default:
   165  			panic(fmt.Sprintf(`syntax error, non-supported option "%s"`, string(o)))
   166  		}
   167  	}
   168  	switch operation {
   169  	case "#":
   170  		attention_opt = true
   171  	case "@":
   172  		attention_opt = false
   173  		direct_opt = true
   174  	}
   175  
   176  	vName := name + "[" + strconv.Itoa(m.keyCounter.Incr(name)) + "]"
   177  	value, has := m.data[vName]
   178  	if !has {
   179  		vName = name
   180  		value, has = m.data[name]
   181  	}
   182  	if !has {
   183  		vName = "*[" + strconv.Itoa(m.sub) + "]"
   184  		value, has = m.data[vName]
   185  	}
   186  	m.sub++ // 每次一定+1
   187  
   188  	// 无值返回空sql语句
   189  	if !has {
   190  		if must_opt {
   191  			panic(fmt.Sprintf(`"%s" must have a value`, name))
   192  		}
   193  		return ""
   194  	}
   195  
   196  	// 非注意模式且值为零值返回空sql语句
   197  	if !attention_opt && IsZero(value) {
   198  		return ""
   199  	}
   200  
   201  	// 操作检查
   202  	switch operation {
   203  	case "&":
   204  		operation = "and"
   205  	case "|":
   206  		operation = "or"
   207  	case "#":
   208  		// nil改为null
   209  		if value == nil {
   210  			return "null"
   211  		}
   212  		if direct_opt {
   213  			return anyToSqlString(value, true)
   214  		}
   215  		m.addValue(vName, value)
   216  		return "?"
   217  	case "@": // !attention_opt + direct
   218  		return anyToSqlString(value, false)
   219  	default:
   220  		panic(fmt.Errorf(`syntax error, non-supported operation "%s"`, operation))
   221  	}
   222  
   223  	// nil 修改语句
   224  	if value == nil {
   225  		switch flag {
   226  		case "!=", "<>", "notin", "not_in", ">", "<":
   227  			return fmt.Sprintf(`%s %s is not null`, operation, name)
   228  		case "=", "like", "likestart", "like_start", "likeend", "like_end":
   229  			return fmt.Sprintf(`%s %s is null`, operation, name)
   230  		case "in", ">=", "<=":
   231  			return ""
   232  		}
   233  	}
   234  
   235  	var makeSqlStr func() string
   236  	var directWrite func() string
   237  	// 标记
   238  	switch flag {
   239  	case ">", ">=", "<", "<=", "!=", "<>", "=":
   240  		makeSqlStr = func() string {
   241  			m.addValue(vName, value)
   242  			return fmt.Sprintf(`%s %s %s ?`, operation, name, flag)
   243  		}
   244  		directWrite = func() string {
   245  			return fmt.Sprintf(`%s %s %s %s`, operation, name, flag, anyToSqlString(value, true))
   246  		}
   247  	case "in":
   248  		values := m.parseToSlice(value)
   249  		if len(values) == 0 {
   250  			return ""
   251  		}
   252  		makeSqlStr = func() string {
   253  			if len(values) == 1 {
   254  				m.addValue(vName, values[0])
   255  				return fmt.Sprintf(`%s %s = ?`, operation, name)
   256  			}
   257  			fs := make([]string, len(values))
   258  			for i, s := range values {
   259  				m.addValue(fmt.Sprintf("%s.in(%d)", vName, i), s)
   260  				fs[i] = "?"
   261  			}
   262  			return fmt.Sprintf(`%s %s in (%s)`, operation, name, strings.Join(fs, ","))
   263  		}
   264  		directWrite = func() string {
   265  			if len(values) == 1 {
   266  				return fmt.Sprintf(`%s %s = %s`, operation, name, anyToSqlString(values[0], true))
   267  			}
   268  			return fmt.Sprintf(`%s %s in %s`, operation, name, anyToSqlString(value, true))
   269  		}
   270  	case "notin", "not_in":
   271  		values := m.parseToSlice(value)
   272  		if len(values) == 0 {
   273  			return ""
   274  		}
   275  		makeSqlStr = func() string {
   276  			if len(values) == 1 {
   277  				m.addValue(vName, values[0])
   278  				return fmt.Sprintf(`%s %s != ?`, operation, name)
   279  			}
   280  			fs := make([]string, len(values))
   281  			for i, s := range values {
   282  				m.addValue(fmt.Sprintf("%s.not_in(%d)", vName, i), s)
   283  				fs[i] = "?"
   284  			}
   285  			return fmt.Sprintf(`%s %s not in (%s)`, operation, name, strings.Join(fs, ","))
   286  		}
   287  		directWrite = func() string {
   288  			if len(values) == 1 {
   289  				return fmt.Sprintf(`%s %s != %s`, operation, name, anyToSqlString(values[0], true))
   290  			}
   291  			return fmt.Sprintf(`%s %s not in %s`, operation, name, anyToSqlString(value, true))
   292  		}
   293  	case "like": // 包含xx
   294  		makeSqlStr = func() string {
   295  			m.addValue(vName, "%"+anyToSqlString(value, false)+"%")
   296  			return fmt.Sprintf(`%s %s like ?`, operation, name)
   297  		}
   298  		directWrite = func() string {
   299  			return fmt.Sprintf(`%s %s like '%%%s%%'`, operation, name, anyToSqlString(value, false))
   300  		}
   301  	case "likestart", "like_start": // 以xx开始
   302  		makeSqlStr = func() string {
   303  			m.addValue(vName, anyToSqlString(value, false)+"%")
   304  			return fmt.Sprintf(`%s %s like ?`, operation, name)
   305  		}
   306  		directWrite = func() string {
   307  			return fmt.Sprintf(`%s %s like '%s%%'`, operation, name, anyToSqlString(value, false))
   308  		}
   309  	case "likeend", "like_end": // 以xx结束
   310  		makeSqlStr = func() string {
   311  			m.addValue(vName, "%"+anyToSqlString(value, false))
   312  			return fmt.Sprintf(`%s %s like ?`, operation, name)
   313  		}
   314  		directWrite = func() string {
   315  			return fmt.Sprintf(`%s %s like '%%%s'`, operation, name, anyToSqlString(value, false))
   316  		}
   317  	default:
   318  		panic(fmt.Errorf(`syntax error, non-supported flag "%s"`, flag))
   319  	}
   320  
   321  	// 直接模式, 将值写入sql语句
   322  	if direct_opt {
   323  		return directWrite()
   324  	}
   325  	return makeSqlStr()
   326  }
   327  
   328  func (m *sqlTemplate) Render(sql_template string) string {
   329  	result := m.replaceAllFunc(sql_template, func(s string, crust bool) string {
   330  		if crust {
   331  			s = s[1 : len(s)-1]
   332  		}
   333  
   334  		operation, name, flag, opts, err := m.sqlTemplateSyntaxParse(s)
   335  		if err != nil {
   336  			panic(err)
   337  		}
   338  		return m.sqlTranslate(operation, name, flag, opts)
   339  	})
   340  	return m.repairSql(result)
   341  }
   342  
   343  func (m *sqlTemplate) sqlTranslate(operation, name, flag string, opts string) string {
   344  	// 选项检查
   345  	var attention_opt, must_opt bool
   346  	for _, o := range opts {
   347  		switch o {
   348  		case 'a':
   349  			attention_opt = true
   350  		case 'd':
   351  		case 'm':
   352  			must_opt = true
   353  		default:
   354  			panic(fmt.Sprintf(`syntax error, non-supported option "%s"`, string(o)))
   355  		}
   356  	}
   357  	switch operation {
   358  	case "#":
   359  		attention_opt = true
   360  	case "@":
   361  		attention_opt = false
   362  	}
   363  
   364  	value, has := m.data[name+"["+strconv.Itoa(m.keyCounter.Incr(name))+"]"]
   365  	if !has {
   366  		value, has = m.data[name]
   367  	}
   368  	if !has {
   369  		value, has = m.data["*["+strconv.Itoa(m.sub)+"]"]
   370  	}
   371  	m.sub++ // 每次一定+1
   372  
   373  	// 无值返回空sql语句
   374  	if !has {
   375  		if must_opt {
   376  			panic(fmt.Sprintf(`"%s" must have a value`, name))
   377  		}
   378  		return ""
   379  	}
   380  
   381  	// 非注意模式, 零值返回空sql语句
   382  	if !attention_opt && IsZero(value) {
   383  		return ""
   384  	}
   385  
   386  	switch operation {
   387  	case "&":
   388  		operation = "and"
   389  	case "|":
   390  		operation = "or"
   391  	case "#":
   392  		// nil改为null
   393  		if value == nil {
   394  			return "null"
   395  		}
   396  		return anyToSqlString(value, true)
   397  	case "@":
   398  		return anyToSqlString(value, false)
   399  	default:
   400  		panic(fmt.Errorf(`syntax error, non-supported operation "%s"`, operation))
   401  	}
   402  
   403  	// nil 修改语句
   404  	if value == nil {
   405  		switch flag {
   406  		case "!=", "<>", "notin", "not_in", ">", "<":
   407  			return fmt.Sprintf(`%s %s is not null`, operation, name)
   408  		case "=", "like", "likestart", "like_start", "likeend", "like_end":
   409  			return fmt.Sprintf(`%s %s is null`, operation, name)
   410  		case "in", ">=", "<=":
   411  			return ""
   412  		}
   413  	}
   414  
   415  	var sql_str string
   416  	switch flag {
   417  	case ">", ">=", "<", "<=", "!=", "<>", "=":
   418  		sql_str = fmt.Sprintf(`%s %s %s %s`, operation, name, flag, anyToSqlString(value, true))
   419  	case "in":
   420  		values := m.parseToSlice(value)
   421  		if len(values) == 0 {
   422  			return ""
   423  		}
   424  		if len(values) == 1 {
   425  			return fmt.Sprintf(`%s %s = %s`, operation, name, anyToSqlString(values[0], true))
   426  		}
   427  		sql_str = fmt.Sprintf(`%s %s in %s`, operation, name, anyToSqlString(value, true))
   428  	case "notin", "not_in":
   429  		values := m.parseToSlice(value)
   430  		if len(values) == 0 {
   431  			return ""
   432  		}
   433  		if len(values) == 1 {
   434  			return fmt.Sprintf(`%s %s != %s`, operation, name, anyToSqlString(values[0], true))
   435  		}
   436  		sql_str = fmt.Sprintf(`%s %s not in %s`, operation, name, anyToSqlString(value, true))
   437  	case "like": // 包含xx
   438  		sql_str = fmt.Sprintf(`%s %s like '%%%s%%'`, operation, name, anyToSqlString(value, false))
   439  	case "likestart", "like_start": // 以xx开始
   440  		sql_str = fmt.Sprintf(`%s %s like '%s%%'`, operation, name, anyToSqlString(value, false))
   441  	case "likeend", "like_end": // 以xx结束
   442  		sql_str = fmt.Sprintf(`%s %s like '%%%s'`, operation, name, anyToSqlString(value, false))
   443  	default:
   444  		panic(fmt.Errorf(`syntax error, non-supported flag "%s"`, flag))
   445  	}
   446  
   447  	return sql_str
   448  }
   449  
   450  // 将数据解析为切片
   451  func (m *sqlTemplate) parseToSlice(a interface{}) []interface{} {
   452  	switch v := a.(type) {
   453  
   454  	case nil:
   455  		return []interface{}{"null"}
   456  
   457  	case string, []byte, bool,
   458  		int, int8, int16, int32, int64,
   459  		uint, uint8, uint16, uint32, uint64,
   460  		float32, float64:
   461  		return []interface{}{v}
   462  	}
   463  
   464  	r_v := reflect.Indirect(reflect.ValueOf(a))
   465  	if r_v.Kind() != reflect.Slice && r_v.Kind() != reflect.Array {
   466  		return []interface{}{fmt.Sprint(a)}
   467  	}
   468  
   469  	l := r_v.Len()
   470  	out := make([]interface{}, 0, l)
   471  	for i := 0; i < l; i++ {
   472  		v := reflect.Indirect(r_v.Index(i)).Interface()
   473  		out = append(out, m.parseToSlice(v)...)
   474  	}
   475  	return out
   476  }
   477  
   478  // sql模板语法解析
   479  //
   480  // 语法格式:   (操作符)(name)
   481  // 语法格式:   {(操作符)(name)}
   482  // 语法格式:   {(操作符)(name) (标志)}
   483  // 语法格式:   {(操作符)(name) (标志) (选项)}
   484  // 语法格式:   {(操作符)(name) (选项)}
   485  //
   486  // 操作符:
   487  //
   488  //	&: 转为 and name flag value
   489  //	|: 转为 or name flag value
   490  //	#: 转为 value, 自带 attention 选项, 仅支持以下格式
   491  //	     (操作符)(name)
   492  //	     {(操作符)(name)}
   493  //	     {(操作符)(name) (选项)}
   494  //	@: attention 选项无效且自带 direct 选项, 且不会为字符串加上引号, 仅支持以下格式, 一般用于写入一条语句
   495  //	     (操作符)(name)
   496  //	     {(操作符)(name)}
   497  //	     {(操作符)(name) (选项)}
   498  //
   499  // name:   示例:    a   a2   a_2   a_2.b   a_2.b_2
   500  //
   501  // 标志:   >   >=   <   <=   !=   <>   =   in   notin   not_in   like   likestart    like_start   likeend   like_end
   502  //
   503  // 选项:
   504  //
   505  //	a:   attention, 不会忽略参数值为该类型的零值
   506  //	d:   direct, 直接将值写入sql语句中
   507  //	m:   must, 必须传值, 值可以为零值
   508  //
   509  // 输入的values必须为:map[string]string, map[string]interface{},或按顺序传入值
   510  //
   511  // 寻值优先级:
   512  //
   513  //	匹配名下标 > 匹配名 > *下标
   514  //	如:  a[0] > a > *[0]
   515  //
   516  // 注意:
   517  //
   518  //	一般情况下如果name没有传参或为该类型的零值, 则替换为空字符串
   519  //	如果name的值为nil, 不同的标志会转为不同的语句
   520  //	我们不会去检查name是否完全符合变量名标志, 因为这是无意义且消耗资源的
   521  //	    变量名首位可以为数字, 变量中间可以连续出现多个小数点, 如 0..a 是合法的
   522  //
   523  // 示例:
   524  //
   525  //	   s := SqlRender("select * from t where &a {&b} {&c !=} {&d in} {|e} limit 1", map[string]interface{}{
   526  //			"a": 1,
   527  //			"b[0]": "2",
   528  //			"*[2]": 3.3,
   529  //			"d": []string{"4"},
   530  //			"e": nil,
   531  //		  })
   532  func (m *sqlTemplate) sqlTemplateSyntaxParse(text string) (operation, name, flag, opts string, err error) {
   533  	// 去头去尾
   534  	temp := strings.TrimSpace(text)
   535  	// 空数据
   536  	if temp == "" {
   537  		err = fmt.Errorf("syntax error, {%s}, empty data", text)
   538  		return
   539  	}
   540  
   541  	// 分离操作符
   542  	operation, temp = temp[:1], temp[1:]
   543  
   544  	// 缩进空格
   545  	temp = m.retractAllSpace(temp)
   546  
   547  	// 分离数据
   548  	texts := strings.SplitN(temp, " ", 4) // 4为考虑尾部可能有空格
   549  	if len(texts) >= 1 {
   550  		name = texts[0]
   551  	}
   552  	if len(texts) >= 2 {
   553  		flag = texts[1]
   554  	} else {
   555  		flag = defaultSqlCompareFlag
   556  	}
   557  	if len(texts) >= 3 {
   558  		opts = texts[2]
   559  	}
   560  	if len(texts) >= 4 && texts[3] != " " {
   561  		err = fmt.Errorf("syntax error, {%s}, redundant data", text)
   562  		return
   563  	}
   564  
   565  	// 检查操作符
   566  	if _, ok := sqlTemplateOperationMapp[int32(operation[0])]; !ok {
   567  		err = fmt.Errorf(`syntax error, {%s}, non-supported operation "%s"`, text, operation)
   568  		return
   569  	}
   570  
   571  	// 检查变量名
   572  	if name == "" {
   573  		err = fmt.Errorf("syntax error, {%s}, no variable name", text)
   574  		return
   575  	}
   576  
   577  	if name[0] == '.' || name[len(name)-1] == '.' {
   578  		err = fmt.Errorf("syntax error, {%s}, Invalid variable name", text)
   579  		return
   580  	}
   581  	for _, v := range []rune(name) {
   582  		if _, ok := templateVariableNameMap[v]; !ok {
   583  			err = fmt.Errorf("syntax error, {%s}, Invalid variable name", text)
   584  			return
   585  		}
   586  	}
   587  
   588  	// 检查标记
   589  	if _, ok := sqlTemplateFlagMapp[flag]; !ok {
   590  		if opts != "" {
   591  			err = fmt.Errorf(`syntax error, {%s}, non-supported flag "%s"`, text, flag)
   592  			return
   593  		}
   594  		flag, opts = defaultSqlCompareFlag, flag
   595  	}
   596  
   597  	// 检查选项
   598  	os := make(map[int32]struct{})
   599  	for _, o := range opts {
   600  		if _, ok := sqlTemplateOptsMapp[o]; !ok {
   601  			err = fmt.Errorf(`syntax error, {%s}, non-supported option "%s"`, text, string(o))
   602  			return
   603  		}
   604  		// 重复选项
   605  		if _, ok := os[o]; ok {
   606  			err = fmt.Errorf(`syntax error, {%s}, repetitive option "%s"`, text, string(o))
   607  			return
   608  		}
   609  		os[o] = struct{}{}
   610  	}
   611  
   612  	return
   613  }
   614  
   615  // sql模板解析, 和 SqlParse 一样, 只是加长了函数名
   616  func SqlTemplateParse(sqlTemplate string, values ...interface{}) (sql_str string, names []string, args []interface{}) {
   617  	return newSqlTemplate(values).Parse(sqlTemplate)
   618  }
   619  
   620  // sql模板解析
   621  func SqlParse(sqlTemplate string, values ...interface{}) (sql_str string, names []string, args []interface{}) {
   622  	return newSqlTemplate(values).Parse(sqlTemplate)
   623  }
   624  
   625  // sql模板渲染, 和 SqlRender 一样, 只是加长了函数名
   626  func SqlTemplateRender(sqlTemplate string, values ...interface{}) string {
   627  	return newSqlTemplate(values).Render(sqlTemplate)
   628  }
   629  
   630  // sql模板渲染(不推荐)
   631  //
   632  // 值会直接写入sql语句中, 不支持sql注入检查
   633  func SqlRender(sqlTemplate string, values ...interface{}) string {
   634  	return newSqlTemplate(values).Render(sqlTemplate)
   635  }