gitee.com/runner.mei/dm@v0.0.0-20220207044607-a9ba0dc20bf7/zzk.go (about)

     1  /*
     2   * Copyright (c) 2000-2018, 达梦数据库有限公司.
     3   * All rights reserved.
     4   */
     5  package dm
     6  
     7  import (
     8  	"bytes"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"gitee.com/runner.mei/dm/parser"
    13  	"gitee.com/runner.mei/dm/util"
    14  )
    15  
    16  func (dc *DmConnection) lex(sql string) ([]*parser.LVal, error) {
    17  	if dc.lexer == nil {
    18  		dc.lexer = parser.NewLexer(strings.NewReader(sql), false)
    19  	} else {
    20  		dc.lexer.Reset(strings.NewReader(sql))
    21  	}
    22  
    23  	lexer := dc.lexer
    24  	var lval *parser.LVal
    25  	var err error
    26  	lvalList := make([]*parser.LVal, 0, 64)
    27  	lval, err = lexer.Yylex()
    28  	if err != nil {
    29  		return nil, err
    30  	}
    31  
    32  	for lval != nil {
    33  		lvalList = append(lvalList, lval)
    34  		lval.Position = len(lvalList)
    35  		lval, err = lexer.Yylex()
    36  		if err != nil {
    37  			return nil, err
    38  		}
    39  	}
    40  
    41  	return lvalList, nil
    42  }
    43  
    44  func lexSkipWhitespace(sql string, n int) ([]*parser.LVal, error) {
    45  	lexer := parser.NewLexer(strings.NewReader(sql), false)
    46  
    47  	var lval *parser.LVal
    48  	var err error
    49  	lvalList := make([]*parser.LVal, 0, 64)
    50  	lval, err = lexer.Yylex()
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	for lval != nil && n > 0 {
    56  		lval.Position = len(lvalList)
    57  		if lval.Tp == parser.WHITESPACE_OR_COMMENT {
    58  			continue
    59  		}
    60  
    61  		lvalList = append(lvalList, lval)
    62  		n--
    63  		lval, err = lexer.Yylex()
    64  		if err != nil {
    65  			return nil, err
    66  		}
    67  
    68  	}
    69  
    70  	return lvalList, nil
    71  }
    72  
    73  func (dc *DmConnection) escape(sql string, keywords []string) (string, error) {
    74  
    75  	if (keywords == nil || len(keywords) == 0) && strings.Index(sql, "{") == -1 {
    76  		return sql, nil
    77  	}
    78  	var keywordMap map[string]interface{}
    79  	if keywords != nil && len(keywords) > 0 {
    80  		keywordMap = make(map[string]interface{}, len(keywords))
    81  		for _, keyword := range keywords {
    82  			keywordMap[strings.ToUpper(keyword)] = nil
    83  		}
    84  	}
    85  	nsql := bytes.NewBufferString("")
    86  	stack := make([]bool, 0, 64)
    87  	lvalList, err := dc.lex(sql)
    88  	if err != nil {
    89  		return "", err
    90  	}
    91  
    92  	for i := 0; i < len(lvalList); i++ {
    93  		lval0 := lvalList[i]
    94  		if lval0.Tp == parser.NORMAL {
    95  			if lval0.Value == "{" {
    96  				lval1 := next(lvalList, i+1)
    97  				if lval1 == nil || lval1.Tp != parser.NORMAL {
    98  					stack = append(stack, false)
    99  					nsql.WriteString(lval0.Value)
   100  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "escape") || util.StringUtil.EqualsIgnoreCase(lval1.Value, "call") {
   101  					stack = append(stack, true)
   102  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "oj") {
   103  					stack = append(stack, true)
   104  					lval1.Value = ""
   105  					lval1.Tp = parser.WHITESPACE_OR_COMMENT
   106  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "d") {
   107  					stack = append(stack, true)
   108  					lval1.Value = "date"
   109  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "t") {
   110  					stack = append(stack, true)
   111  					lval1.Value = "time"
   112  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "ts") {
   113  					stack = append(stack, true)
   114  					lval1.Value = "datetime"
   115  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "fn") {
   116  					stack = append(stack, true)
   117  					lval1.Value = ""
   118  					lval1.Tp = parser.WHITESPACE_OR_COMMENT
   119  					lval2 := next(lvalList, lval1.Position+1)
   120  					if lval2 != nil && lval2.Tp == parser.NORMAL && util.StringUtil.EqualsIgnoreCase(lval2.Value, "database") {
   121  						lval2.Value = "cur_database"
   122  					}
   123  				} else if util.StringUtil.Equals(lval1.Value, "?") {
   124  					lval2 := next(lvalList, lval1.Position+1)
   125  					if lval2 != nil && lval2.Tp == parser.NORMAL && util.StringUtil.EqualsIgnoreCase(lval2.Value, "=") {
   126  						lval3 := next(lvalList, lval2.Position+1)
   127  						if lval3 != nil && lval3.Tp == parser.NORMAL && util.StringUtil.EqualsIgnoreCase(lval3.Value, "call") {
   128  							stack = append(stack, true)
   129  							lval3.Value = ""
   130  							lval3.Tp = parser.WHITESPACE_OR_COMMENT
   131  						} else {
   132  							stack = append(stack, false)
   133  							nsql.WriteString(lval0.Value)
   134  						}
   135  					} else {
   136  						stack = append(stack, false)
   137  						nsql.WriteString(lval0.Value)
   138  					}
   139  				} else {
   140  					stack = append(stack, false)
   141  					nsql.WriteString(lval0.Value)
   142  				}
   143  			} else if util.StringUtil.Equals(lval0.Value, "}") {
   144  				if len(stack) != 0 && stack[len(stack)-1] {
   145  
   146  				} else {
   147  					nsql.WriteString(lval0.Value)
   148  				}
   149  				stack = stack[:len(stack)-1]
   150  			} else {
   151  				if keywordMap != nil {
   152  					_, ok := keywordMap[strings.ToUpper(lval0.Value)]
   153  					if ok {
   154  						nsql.WriteString("\"" + util.StringUtil.ProcessDoubleQuoteOfName(strings.ToUpper(lval0.Value)) + "\"")
   155  					} else {
   156  						nsql.WriteString(lval0.Value)
   157  					}
   158  				} else {
   159  					nsql.WriteString(lval0.Value)
   160  				}
   161  			}
   162  		} else if lval0.Tp == parser.STRING {
   163  			nsql.WriteString("'" + util.StringUtil.ProcessSingleQuoteOfName(lval0.Value) + "'")
   164  		} else {
   165  			nsql.WriteString(lval0.Value)
   166  		}
   167  	}
   168  
   169  	return nsql.String(), nil
   170  }
   171  
   172  func next(lvalList []*parser.LVal, start int) *parser.LVal {
   173  	var lval *parser.LVal
   174  
   175  	size := len(lvalList)
   176  	for i := start; i < size; i++ {
   177  		lval = lvalList[i]
   178  		if lval.Tp != parser.WHITESPACE_OR_COMMENT {
   179  			break
   180  		}
   181  	}
   182  	return lval
   183  }
   184  
   185  func (dc *DmConnection) execOpt(sql string, optParamList []OptParameter, serverEncoding string) (string, []OptParameter, error) {
   186  	nsql := bytes.NewBufferString("")
   187  
   188  	lvalList, err := dc.lex(sql)
   189  	if err != nil {
   190  		return "", optParamList, err
   191  	}
   192  
   193  	if nil == lvalList || len(lvalList) == 0 {
   194  		return sql, optParamList, nil
   195  	}
   196  
   197  	firstWord := lvalList[0].Value
   198  	if !(util.StringUtil.EqualsIgnoreCase(firstWord, "INSERT") || util.StringUtil.EqualsIgnoreCase(firstWord, "SELECT") ||
   199  		util.StringUtil.EqualsIgnoreCase(firstWord, "UPDATE") || util.StringUtil.EqualsIgnoreCase(firstWord, "DELETE")) {
   200  		return sql, optParamList, nil
   201  	}
   202  
   203  	breakIndex := 0
   204  	for i := 0; i < len(lvalList); i++ {
   205  		lval := lvalList[i]
   206  		switch lval.Tp {
   207  		case parser.NULL:
   208  			{
   209  				nsql.WriteString("?")
   210  				optParamList = append(optParamList, newOptParameter(nil, NULL, NULL_PREC))
   211  			}
   212  		case parser.INT:
   213  			{
   214  				nsql.WriteString("?")
   215  				value, err := strconv.Atoi(lval.Value)
   216  				if err != nil {
   217  					return "", optParamList, err
   218  				}
   219  
   220  				if value <= int(INT32_MAX) && value >= int(INT32_MIN) {
   221  					optParamList = append(optParamList, newOptParameter(G2DB.toInt32(int32(value)), INT, INT_PREC))
   222  
   223  				} else {
   224  					optParamList = append(optParamList, newOptParameter(G2DB.toInt64(int64(value)), BIGINT, BIGINT_PREC))
   225  				}
   226  			}
   227  		case parser.DOUBLE:
   228  			{
   229  				nsql.WriteString("?")
   230  				f, err := strconv.ParseFloat(lval.Value, 64)
   231  				if err != nil {
   232  					return "", optParamList, err
   233  				}
   234  
   235  				optParamList = append(optParamList, newOptParameter(G2DB.toFloat64(f), DOUBLE, DOUBLE_PREC))
   236  			}
   237  		case parser.DECIMAL:
   238  			{
   239  				nsql.WriteString("?")
   240  				bytes, err := G2DB.toDecimal(lval.Value, 0, 0)
   241  				if err != nil {
   242  					return "", optParamList, err
   243  				}
   244  				optParamList = append(optParamList, newOptParameter(bytes, DECIMAL, 0))
   245  			}
   246  		case parser.STRING:
   247  			{
   248  
   249  				if len(lval.Value) > int(INT16_MAX) {
   250  
   251  					nsql.WriteString("'" + util.StringUtil.ProcessSingleQuoteOfName(lval.Value) + "'")
   252  				} else {
   253  					nsql.WriteString("?")
   254  					optParamList = append(optParamList, newOptParameter(Dm_build_623.Dm_build_836(lval.Value, serverEncoding, dc), VARCHAR, VARCHAR_PREC))
   255  				}
   256  			}
   257  		case parser.HEX_INT:
   258  
   259  			nsql.WriteString(lval.Value)
   260  		default:
   261  
   262  			nsql.WriteString(lval.Value)
   263  		}
   264  
   265  		if breakIndex > 0 {
   266  			break
   267  		}
   268  	}
   269  
   270  	if breakIndex > 0 {
   271  		for i := breakIndex + 1; i < len(lvalList); i++ {
   272  			nsql.WriteString(lvalList[i].Value)
   273  		}
   274  	}
   275  
   276  	return nsql.String(), optParamList, nil
   277  }
   278  
   279  func (dc *DmConnection) hasConst(sql string) (bool, error) {
   280  	lvalList, err := dc.lex(sql)
   281  	if err != nil {
   282  		return false, err
   283  	}
   284  
   285  	if nil == lvalList || len(lvalList) == 0 {
   286  		return false, nil
   287  	}
   288  
   289  	for i := 0; i < len(lvalList); i++ {
   290  		switch lvalList[i].Tp {
   291  		case parser.NULL, parser.INT, parser.DOUBLE, parser.DECIMAL, parser.STRING, parser.HEX_INT:
   292  			return true, nil
   293  		}
   294  	}
   295  	return false, nil
   296  }
   297  
   298  type OptParameter struct {
   299  	bytes  []byte
   300  	ioType byte
   301  	tp     int
   302  	prec   int
   303  	scale  int
   304  }
   305  
   306  func newOptParameter(bytes []byte, tp int, prec int) OptParameter {
   307  	o := new(OptParameter)
   308  	o.bytes = bytes
   309  	o.tp = tp
   310  	o.prec = prec
   311  	return *o
   312  }
   313  
   314  func (parameter *OptParameter) String() string {
   315  	if parameter.bytes == nil {
   316  		return ""
   317  	}
   318  	return string(parameter.bytes)
   319  }