github.com/wanlay/gorm-dm8@v1.0.5/dmr/zzk.go (about)

     1  /*
     2   * Copyright (c) 2000-2018, 达梦数据库有限公司.
     3   * All rights reserved.
     4   */
     5  package dmr
     6  
import (
	"bytes"
	"errors"
	"strconv"
	"strings"

	"github.com/wanlay/gorm-dm8/dmr/parser"
	"github.com/wanlay/gorm-dm8/dmr/util"
)
    16  
    17  func (dc *DmConnection) lex(sql string) ([]*parser.LVal, error) {
    18  	if dc.lexer == nil {
    19  		dc.lexer = parser.NewLexer(strings.NewReader(sql), false)
    20  	} else {
    21  		dc.lexer.Reset(strings.NewReader(sql))
    22  	}
    23  
    24  	lexer := dc.lexer
    25  	var lval *parser.LVal
    26  	var err error
    27  	lvalList := make([]*parser.LVal, 0, 64)
    28  	lval, err = lexer.Yylex()
    29  	if err != nil {
    30  		return nil, err
    31  	}
    32  
    33  	for lval != nil {
    34  		lvalList = append(lvalList, lval)
    35  		lval.Position = len(lvalList)
    36  		lval, err = lexer.Yylex()
    37  		if err != nil {
    38  			return nil, err
    39  		}
    40  	}
    41  
    42  	return lvalList, nil
    43  }
    44  
    45  func lexSkipWhitespace(sql string, n int) ([]*parser.LVal, error) {
    46  	lexer := parser.NewLexer(strings.NewReader(sql), false)
    47  
    48  	var lval *parser.LVal
    49  	var err error
    50  	lvalList := make([]*parser.LVal, 0, 64)
    51  	lval, err = lexer.Yylex()
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	for lval != nil && n > 0 {
    57  		lval.Position = len(lvalList)
    58  		if lval.Tp == parser.WHITESPACE_OR_COMMENT {
    59  			continue
    60  		}
    61  
    62  		lvalList = append(lvalList, lval)
    63  		n--
    64  		lval, err = lexer.Yylex()
    65  		if err != nil {
    66  			return nil, err
    67  		}
    68  
    69  	}
    70  
    71  	return lvalList, nil
    72  }
    73  
    74  func (dc *DmConnection) escape(sql string, keywords []string) (string, error) {
    75  
    76  	if (keywords == nil || len(keywords) == 0) && strings.Index(sql, "{") == -1 {
    77  		return sql, nil
    78  	}
    79  	var keywordMap map[string]interface{}
    80  	if keywords != nil && len(keywords) > 0 {
    81  		keywordMap = make(map[string]interface{}, len(keywords))
    82  		for _, keyword := range keywords {
    83  			keywordMap[strings.ToUpper(keyword)] = nil
    84  		}
    85  	}
    86  	nsql := bytes.NewBufferString("")
    87  	stack := make([]bool, 0, 64)
    88  	lvalList, err := dc.lex(sql)
    89  	if err != nil {
    90  		return "", err
    91  	}
    92  
    93  	for i := 0; i < len(lvalList); i++ {
    94  		lval0 := lvalList[i]
    95  		if lval0.Tp == parser.NORMAL {
    96  			if lval0.Value == "{" {
    97  				lval1 := next(lvalList, i+1)
    98  				if lval1 == nil || lval1.Tp != parser.NORMAL {
    99  					stack = append(stack, false)
   100  					nsql.WriteString(lval0.Value)
   101  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "escape") || util.StringUtil.EqualsIgnoreCase(lval1.Value, "call") {
   102  					stack = append(stack, true)
   103  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "oj") {
   104  					stack = append(stack, true)
   105  					lval1.Value = ""
   106  					lval1.Tp = parser.WHITESPACE_OR_COMMENT
   107  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "d") {
   108  					stack = append(stack, true)
   109  					lval1.Value = "date"
   110  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "t") {
   111  					stack = append(stack, true)
   112  					lval1.Value = "time"
   113  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "ts") {
   114  					stack = append(stack, true)
   115  					lval1.Value = "datetime"
   116  				} else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "fn") {
   117  					stack = append(stack, true)
   118  					lval1.Value = ""
   119  					lval1.Tp = parser.WHITESPACE_OR_COMMENT
   120  					lval2 := next(lvalList, lval1.Position+1)
   121  					if lval2 != nil && lval2.Tp == parser.NORMAL && util.StringUtil.EqualsIgnoreCase(lval2.Value, "database") {
   122  						lval2.Value = "cur_database"
   123  					}
   124  				} else if util.StringUtil.Equals(lval1.Value, "?") {
   125  					lval2 := next(lvalList, lval1.Position+1)
   126  					if lval2 != nil && lval2.Tp == parser.NORMAL && util.StringUtil.EqualsIgnoreCase(lval2.Value, "=") {
   127  						lval3 := next(lvalList, lval2.Position+1)
   128  						if lval3 != nil && lval3.Tp == parser.NORMAL && util.StringUtil.EqualsIgnoreCase(lval3.Value, "call") {
   129  							stack = append(stack, true)
   130  							lval3.Value = ""
   131  							lval3.Tp = parser.WHITESPACE_OR_COMMENT
   132  						} else {
   133  							stack = append(stack, false)
   134  							nsql.WriteString(lval0.Value)
   135  						}
   136  					} else {
   137  						stack = append(stack, false)
   138  						nsql.WriteString(lval0.Value)
   139  					}
   140  				} else {
   141  					stack = append(stack, false)
   142  					nsql.WriteString(lval0.Value)
   143  				}
   144  			} else if util.StringUtil.Equals(lval0.Value, "}") {
   145  				if len(stack) != 0 && stack[len(stack)-1] {
   146  
   147  				} else {
   148  					nsql.WriteString(lval0.Value)
   149  				}
   150  				stack = stack[:len(stack)-1]
   151  			} else {
   152  				if keywordMap != nil {
   153  					_, ok := keywordMap[strings.ToUpper(lval0.Value)]
   154  					if ok {
   155  						nsql.WriteString("\"" + util.StringUtil.ProcessDoubleQuoteOfName(strings.ToUpper(lval0.Value)) + "\"")
   156  					} else {
   157  						nsql.WriteString(lval0.Value)
   158  					}
   159  				} else {
   160  					nsql.WriteString(lval0.Value)
   161  				}
   162  			}
   163  		} else if lval0.Tp == parser.STRING {
   164  			nsql.WriteString("'" + util.StringUtil.ProcessSingleQuoteOfName(lval0.Value) + "'")
   165  		} else {
   166  			nsql.WriteString(lval0.Value)
   167  		}
   168  	}
   169  
   170  	return nsql.String(), nil
   171  }
   172  
   173  func next(lvalList []*parser.LVal, start int) *parser.LVal {
   174  	var lval *parser.LVal
   175  
   176  	size := len(lvalList)
   177  	for i := start; i < size; i++ {
   178  		lval = lvalList[i]
   179  		if lval.Tp != parser.WHITESPACE_OR_COMMENT {
   180  			break
   181  		}
   182  	}
   183  	return lval
   184  }
   185  
   186  func (dc *DmConnection) execOpt(sql string, optParamList []OptParameter, serverEncoding string) (string, []OptParameter, error) {
   187  	nsql := bytes.NewBufferString("")
   188  
   189  	lvalList, err := dc.lex(sql)
   190  	if err != nil {
   191  		return "", optParamList, err
   192  	}
   193  
   194  	if nil == lvalList || len(lvalList) == 0 {
   195  		return sql, optParamList, nil
   196  	}
   197  
   198  	firstWord := lvalList[0].Value
   199  	if !(util.StringUtil.EqualsIgnoreCase(firstWord, "INSERT") || util.StringUtil.EqualsIgnoreCase(firstWord, "SELECT") ||
   200  		util.StringUtil.EqualsIgnoreCase(firstWord, "UPDATE") || util.StringUtil.EqualsIgnoreCase(firstWord, "DELETE")) {
   201  		return sql, optParamList, nil
   202  	}
   203  
   204  	breakIndex := 0
   205  	for i := 0; i < len(lvalList); i++ {
   206  		lval := lvalList[i]
   207  		switch lval.Tp {
   208  		case parser.NULL:
   209  			{
   210  				nsql.WriteString("?")
   211  				optParamList = append(optParamList, newOptParameter(nil, NULL, NULL_PREC))
   212  			}
   213  		case parser.INT:
   214  			{
   215  				nsql.WriteString("?")
   216  				value, err := strconv.Atoi(lval.Value)
   217  				if err != nil {
   218  					return "", optParamList, err
   219  				}
   220  
   221  				if value <= int(INT32_MAX) && value >= int(INT32_MIN) {
   222  					optParamList = append(optParamList, newOptParameter(G2DB.toInt32(int32(value)), INT, INT_PREC))
   223  
   224  				} else {
   225  					optParamList = append(optParamList, newOptParameter(G2DB.toInt64(int64(value)), BIGINT, BIGINT_PREC))
   226  				}
   227  			}
   228  		case parser.DOUBLE:
   229  			{
   230  				nsql.WriteString("?")
   231  				f, err := strconv.ParseFloat(lval.Value, 64)
   232  				if err != nil {
   233  					return "", optParamList, err
   234  				}
   235  
   236  				optParamList = append(optParamList, newOptParameter(G2DB.toFloat64(f), DOUBLE, DOUBLE_PREC))
   237  			}
   238  		case parser.DECIMAL:
   239  			{
   240  				nsql.WriteString("?")
   241  				bytes, err := G2DB.toDecimal(lval.Value, 0, 0)
   242  				if err != nil {
   243  					return "", optParamList, err
   244  				}
   245  				optParamList = append(optParamList, newOptParameter(bytes, DECIMAL, 0))
   246  			}
   247  		case parser.STRING:
   248  			{
   249  
   250  				if len(lval.Value) > int(INT16_MAX) {
   251  
   252  					nsql.WriteString("'" + util.StringUtil.ProcessSingleQuoteOfName(lval.Value) + "'")
   253  				} else {
   254  					nsql.WriteString("?")
   255  					optParamList = append(optParamList, newOptParameter(Dm_build_1220.Dm_build_1433(lval.Value, serverEncoding, dc), VARCHAR, VARCHAR_PREC))
   256  				}
   257  			}
   258  		case parser.HEX_INT:
   259  
   260  			nsql.WriteString(lval.Value)
   261  		default:
   262  
   263  			nsql.WriteString(lval.Value)
   264  		}
   265  
   266  		if breakIndex > 0 {
   267  			break
   268  		}
   269  	}
   270  
   271  	if breakIndex > 0 {
   272  		for i := breakIndex + 1; i < len(lvalList); i++ {
   273  			nsql.WriteString(lvalList[i].Value)
   274  		}
   275  	}
   276  
   277  	return nsql.String(), optParamList, nil
   278  }
   279  
   280  func (dc *DmConnection) hasConst(sql string) (bool, error) {
   281  	lvalList, err := dc.lex(sql)
   282  	if err != nil {
   283  		return false, err
   284  	}
   285  
   286  	if nil == lvalList || len(lvalList) == 0 {
   287  		return false, nil
   288  	}
   289  
   290  	for i := 0; i < len(lvalList); i++ {
   291  		switch lvalList[i].Tp {
   292  		case parser.NULL, parser.INT, parser.DOUBLE, parser.DECIMAL, parser.STRING, parser.HEX_INT:
   293  			return true, nil
   294  		}
   295  	}
   296  	return false, nil
   297  }
   298  
   299  type OptParameter struct {
   300  	bytes  []byte
   301  	ioType byte
   302  	tp     int
   303  	prec   int
   304  	scale  int
   305  }
   306  
   307  func newOptParameter(bytes []byte, tp int, prec int) OptParameter {
   308  	o := new(OptParameter)
   309  	o.bytes = bytes
   310  	o.tp = tp
   311  	o.prec = prec
   312  	return *o
   313  }
   314  
   315  func (parameter *OptParameter) String() string {
   316  	if parameter.bytes == nil {
   317  		return ""
   318  	}
   319  	return string(parameter.bytes)
   320  }