gitee.com/curryzheng/dm@v0.0.1/zzl.go (about)

     1  /*
     2   * Copyright (c) 2000-2018, 达梦数据库有限公司.
     3   * All rights reserved.
     4   */
     5  package dm
     6  
     7  import (
     8  	"bytes"
     9  	"gitee.com/curryzheng/dm/parser"
    10  	"gitee.com/curryzheng/dm/util"
    11  	"strconv"
    12  	"strings"
    13  )
    14  
    15  func (dc *DmConnection) lex(sql string) ([]*parser.LVal, error) {
    16  	if dc.lexer == nil {
    17  		dc.lexer = parser.NewLexer(strings.NewReader(sql), false)
    18  	} else {
    19  		dc.lexer.Reset(strings.NewReader(sql))
    20  	}
    21  
    22  	lexer := dc.lexer
    23  	var lval *parser.LVal
    24  	var err error
    25  	lvalList := make([]*parser.LVal, 0, 64)
    26  	lval, err = lexer.Yylex()
    27  	if err != nil {
    28  		return nil, err
    29  	}
    30  
    31  	for lval != nil {
    32  		lvalList = append(lvalList, lval)
    33  		lval.Position = len(lvalList)
    34  		lval, err = lexer.Yylex()
    35  		if err != nil {
    36  			return nil, err
    37  		}
    38  	}
    39  
    40  	return lvalList, nil
    41  }
    42  
    43  func lexSkipWhitespace(sql string, n int) ([]*parser.LVal, error) {
    44  	lexer := parser.NewLexer(strings.NewReader(sql), false)
    45  
    46  	var lval *parser.LVal
    47  	var err error
    48  	lvalList := make([]*parser.LVal, 0, 64)
    49  	lval, err = lexer.Yylex()
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	for lval != nil && n > 0 {
    55  		lval.Position = len(lvalList)
    56  		if lval.Tp == parser.WHITESPACE_OR_COMMENT {
    57  			continue
    58  		}
    59  
    60  		lvalList = append(lvalList, lval)
    61  		n--
    62  		lval, err = lexer.Yylex()
    63  		if err != nil {
    64  			return nil, err
    65  		}
    66  
    67  	}
    68  
    69  	return lvalList, nil
    70  }
    71  
    72  func (dc *DmConnection) escape(sql string, keywords []string) (string, error) {
    73  
    74  	if (keywords == nil || len(keywords) == 0) && strings.Index(sql, "{") == -1 {
    75  		return sql, nil
    76  	}
    77  	var keywordMap map[string]interface{}
    78  	if keywords != nil && len(keywords) > 0 {
    79  		keywordMap = make(map[string]interface{}, len(keywords))
    80  		for _, keyword := range keywords {
    81  			keywordMap[strings.ToUpper(keyword)] = nil
    82  		}
    83  	}
    84  	nsql := bytes.NewBufferString("")
    85  	stack := make([]bool, 0, 64)
    86  	lvalList, err := dc.lex(sql)
    87  	if err != nil {
    88  		return "", err
    89  	}
    90  
    91  	for i := 0; i < len(lvalList); i++ {
    92  		lval0 := lvalList[i]
    93  		if lval0.Tp == parser.NORMAL {
    94  			if lval0.Value == "{" {
    95  				lval1 := next(lvalList, i+1)
    96  				if lval1 == nil || lval1.Tp != parser.NORMAL {
    97  					stack = append(stack, false)
    98  					nsql.WriteString(lval0.Value)
    99  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "escape") || util.StringUtil.EqualsIgnoreCase(lval1.Value, "call") {
   100  					stack = append(stack, true)
   101  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "oj") {
   102  					stack = append(stack, true)
   103  					lval1.Value = ""
   104  					lval1.Tp = parser.WHITESPACE_OR_COMMENT
   105  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "d") {
   106  					stack = append(stack, true)
   107  					lval1.Value = "date"
   108  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "t") {
   109  					stack = append(stack, true)
   110  					lval1.Value = "time"
   111  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "ts") {
   112  					stack = append(stack, true)
   113  					lval1.Value = "datetime"
   114  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "fn") {
   115  					stack = append(stack, true)
   116  					lval1.Value = ""
   117  					lval1.Tp = parser.WHITESPACE_OR_COMMENT
   118  					lval2 := next(lvalList, lval1.Position+1)
   119  					if lval2 != nil && lval2.Tp == parser.NORMAL && util.StringUtil.EqualsIgnoreCase(lval2.Value, "database") {
   120  						lval2.Value = "cur_database"
   121  					}
   122  				} else if util.StringUtil.Equals(lval1.Value, "?") {
   123  					lval2 := next(lvalList, lval1.Position+1)
   124  					if lval2 != nil && lval2.Tp == parser.NORMAL && util.StringUtil.EqualsIgnoreCase(lval2.Value, "=") {
   125  						lval3 := next(lvalList, lval2.Position+1)
   126  						if lval3 != nil && lval3.Tp == parser.NORMAL && util.StringUtil.EqualsIgnoreCase(lval3.Value, "call") {
   127  							stack = append(stack, true)
   128  							lval3.Value = ""
   129  							lval3.Tp = parser.WHITESPACE_OR_COMMENT
   130  						} else {
   131  							stack = append(stack, false)
   132  							nsql.WriteString(lval0.Value)
   133  						}
   134  					} else {
   135  						stack = append(stack, false)
   136  						nsql.WriteString(lval0.Value)
   137  					}
   138  				} else {
   139  					stack = append(stack, false)
   140  					nsql.WriteString(lval0.Value)
   141  				}
   142  			} else if util.StringUtil.Equals(lval0.Value, "}") {
   143  				if len(stack) != 0 && stack[len(stack)-1] {
   144  
   145  				} else {
   146  					nsql.WriteString(lval0.Value)
   147  				}
   148  				stack = stack[:len(stack)-1]
   149  			} else {
   150  				if keywordMap != nil {
   151  					_, ok := keywordMap[strings.ToUpper(lval0.Value)]
   152  					if ok {
   153  						nsql.WriteString("\"" + util.StringUtil.ProcessDoubleQuoteOfName(strings.ToUpper(lval0.Value)) + "\"")
   154  					} else {
   155  						nsql.WriteString(lval0.Value)
   156  					}
   157  				} else {
   158  					nsql.WriteString(lval0.Value)
   159  				}
   160  			}
   161  		} else if lval0.Tp == parser.STRING {
   162  			nsql.WriteString("'" + util.StringUtil.ProcessSingleQuoteOfName(lval0.Value) + "'")
   163  		} else {
   164  			nsql.WriteString(lval0.Value)
   165  		}
   166  	}
   167  
   168  	return nsql.String(), nil
   169  }
   170  
   171  func next(lvalList []*parser.LVal, start int) *parser.LVal {
   172  	var lval *parser.LVal
   173  
   174  	size := len(lvalList)
   175  	for i := start; i < size; i++ {
   176  		lval = lvalList[i]
   177  		if lval.Tp != parser.WHITESPACE_OR_COMMENT {
   178  			break
   179  		}
   180  	}
   181  	return lval
   182  }
   183  
   184  func (dc *DmConnection) execOpt(sql string, optParamList []OptParameter, serverEncoding string) (string, []OptParameter, error) {
   185  	nsql := bytes.NewBufferString("")
   186  
   187  	lvalList, err := dc.lex(sql)
   188  	if err != nil {
   189  		return "", optParamList, err
   190  	}
   191  
   192  	if nil == lvalList || len(lvalList) == 0 {
   193  		return sql, optParamList, nil
   194  	}
   195  
   196  	firstWord := lvalList[0].Value
   197  	if !(util.StringUtil.EqualsIgnoreCase(firstWord, "INSERT") || util.StringUtil.EqualsIgnoreCase(firstWord, "SELECT") ||
   198  		util.StringUtil.EqualsIgnoreCase(firstWord, "UPDATE") || util.StringUtil.EqualsIgnoreCase(firstWord, "DELETE")) {
   199  		return sql, optParamList, nil
   200  	}
   201  
   202  	breakIndex := 0
   203  	for i := 0; i < len(lvalList); i++ {
   204  		lval := lvalList[i]
   205  		switch lval.Tp {
   206  		case parser.NULL:
   207  			{
   208  				nsql.WriteString("?")
   209  				optParamList = append(optParamList, newOptParameter(nil, NULL, NULL_PREC))
   210  			}
   211  		case parser.INT:
   212  			{
   213  				nsql.WriteString("?")
   214  				value, err := strconv.Atoi(lval.Value)
   215  				if err != nil {
   216  					return "", optParamList, err
   217  				}
   218  
   219  				if value <= int(INT32_MAX) && value >= int(INT32_MIN) {
   220  					optParamList = append(optParamList, newOptParameter(G2DB.toInt32(int32(value)), INT, INT_PREC))
   221  
   222  				} else {
   223  					optParamList = append(optParamList, newOptParameter(G2DB.toInt64(int64(value)), BIGINT, BIGINT_PREC))
   224  				}
   225  			}
   226  		case parser.DOUBLE:
   227  			{
   228  				nsql.WriteString("?")
   229  				f, err := strconv.ParseFloat(lval.Value, 64)
   230  				if err != nil {
   231  					return "", optParamList, err
   232  				}
   233  
   234  				optParamList = append(optParamList, newOptParameter(G2DB.toFloat64(f), DOUBLE, DOUBLE_PREC))
   235  			}
   236  		case parser.DECIMAL:
   237  			{
   238  				nsql.WriteString("?")
   239  				bytes, err := G2DB.toDecimal(lval.Value, 0, 0)
   240  				if err != nil {
   241  					return "", optParamList, err
   242  				}
   243  				optParamList = append(optParamList, newOptParameter(bytes, DECIMAL, 0))
   244  			}
   245  		case parser.STRING:
   246  			{
   247  
   248  				if len(lval.Value) > int(INT16_MAX) {
   249  
   250  					nsql.WriteString("'" + util.StringUtil.ProcessSingleQuoteOfName(lval.Value) + "'")
   251  				} else {
   252  					nsql.WriteString("?")
   253  					optParamList = append(optParamList, newOptParameter(Dm_build_1.Dm_build_214(lval.Value, serverEncoding, dc), VARCHAR, VARCHAR_PREC))
   254  				}
   255  			}
   256  		case parser.HEX_INT:
   257  
   258  			nsql.WriteString(lval.Value)
   259  		default:
   260  
   261  			nsql.WriteString(lval.Value)
   262  		}
   263  
   264  		if breakIndex > 0 {
   265  			break
   266  		}
   267  	}
   268  
   269  	if breakIndex > 0 {
   270  		for i := breakIndex + 1; i < len(lvalList); i++ {
   271  			nsql.WriteString(lvalList[i].Value)
   272  		}
   273  	}
   274  
   275  	return nsql.String(), optParamList, nil
   276  }
   277  
   278  func (dc *DmConnection) hasConst(sql string) (bool, error) {
   279  	lvalList, err := dc.lex(sql)
   280  	if err != nil {
   281  		return false, err
   282  	}
   283  
   284  	if nil == lvalList || len(lvalList) == 0 {
   285  		return false, nil
   286  	}
   287  
   288  	for i := 0; i < len(lvalList); i++ {
   289  		switch lvalList[i].Tp {
   290  		case parser.NULL, parser.INT, parser.DOUBLE, parser.DECIMAL, parser.STRING, parser.HEX_INT:
   291  			return true, nil
   292  		}
   293  	}
   294  	return false, nil
   295  }
   296  
   297  type OptParameter struct {
   298  	bytes  []byte
   299  	ioType byte
   300  	tp     int
   301  	prec   int
   302  	scale  int
   303  }
   304  
   305  func newOptParameter(bytes []byte, tp int, prec int) OptParameter {
   306  	o := new(OptParameter)
   307  	o.bytes = bytes
   308  	o.tp = tp
   309  	o.prec = prec
   310  	return *o
   311  }
   312  
   313  func (parameter *OptParameter) String() string {
   314  	if parameter.bytes == nil {
   315  		return ""
   316  	}
   317  	return string(parameter.bytes)
   318  }