gitee.com/curryzheng/dm@v0.0.1/zzl.go (about) 1 /* 2 * Copyright (c) 2000-2018, 达梦数据库有限公司. 3 * All rights reserved. 4 */ 5 package dm 6 7 import ( 8 "bytes" 9 "gitee.com/curryzheng/dm/parser" 10 "gitee.com/curryzheng/dm/util" 11 "strconv" 12 "strings" 13 ) 14 15 func (dc *DmConnection) lex(sql string) ([]*parser.LVal, error) { 16 if dc.lexer == nil { 17 dc.lexer = parser.NewLexer(strings.NewReader(sql), false) 18 } else { 19 dc.lexer.Reset(strings.NewReader(sql)) 20 } 21 22 lexer := dc.lexer 23 var lval *parser.LVal 24 var err error 25 lvalList := make([]*parser.LVal, 0, 64) 26 lval, err = lexer.Yylex() 27 if err != nil { 28 return nil, err 29 } 30 31 for lval != nil { 32 lvalList = append(lvalList, lval) 33 lval.Position = len(lvalList) 34 lval, err = lexer.Yylex() 35 if err != nil { 36 return nil, err 37 } 38 } 39 40 return lvalList, nil 41 } 42 43 func lexSkipWhitespace(sql string, n int) ([]*parser.LVal, error) { 44 lexer := parser.NewLexer(strings.NewReader(sql), false) 45 46 var lval *parser.LVal 47 var err error 48 lvalList := make([]*parser.LVal, 0, 64) 49 lval, err = lexer.Yylex() 50 if err != nil { 51 return nil, err 52 } 53 54 for lval != nil && n > 0 { 55 lval.Position = len(lvalList) 56 if lval.Tp == parser.WHITESPACE_OR_COMMENT { 57 continue 58 } 59 60 lvalList = append(lvalList, lval) 61 n-- 62 lval, err = lexer.Yylex() 63 if err != nil { 64 return nil, err 65 } 66 67 } 68 69 return lvalList, nil 70 } 71 72 func (dc *DmConnection) escape(sql string, keywords []string) (string, error) { 73 74 if (keywords == nil || len(keywords) == 0) && strings.Index(sql, "{") == -1 { 75 return sql, nil 76 } 77 var keywordMap map[string]interface{} 78 if keywords != nil && len(keywords) > 0 { 79 keywordMap = make(map[string]interface{}, len(keywords)) 80 for _, keyword := range keywords { 81 keywordMap[strings.ToUpper(keyword)] = nil 82 } 83 } 84 nsql := bytes.NewBufferString("") 85 stack := make([]bool, 0, 64) 86 lvalList, err := dc.lex(sql) 87 if err != nil { 88 return "", err 89 } 90 91 for i := 0; i < len(lvalList); i++ { 92 lval0 := lvalList[i] 93 if lval0.Tp == parser.NORMAL { 94 if lval0.Value == "{" { 95 lval1 := next(lvalList, i+1) 96 if lval1 == nil || lval1.Tp != parser.NORMAL { 97 stack = append(stack, false) 98 nsql.WriteString(lval0.Value) 99 } else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "escape") || util.StringUtil.EqualsIgnoreCase(lval1.Value, "call") { 100 stack = append(stack, true) 101 } else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "oj") { 102 stack = append(stack, true) 103 lval1.Value = "" 104 lval1.Tp = parser.WHITESPACE_OR_COMMENT 105 } else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "d") { 106 stack = append(stack, true) 107 lval1.Value = "date" 108 } else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "t") { 109 stack = append(stack, true) 110 lval1.Value = "time" 111 } else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "ts") { 112 stack = append(stack, true) 113 lval1.Value = "datetime" 114 } else if util.StringUtil.EqualsIgnoreCase(lval1.Value, "fn") { 115 stack = append(stack, true) 116 lval1.Value = "" 117 lval1.Tp = parser.WHITESPACE_OR_COMMENT 118 lval2 := next(lvalList, lval1.Position+1) 119 if lval2 != nil && lval2.Tp == parser.NORMAL && util.StringUtil.EqualsIgnoreCase(lval2.Value, "database") { 120 lval2.Value = "cur_database" 121 } 122 } else if util.StringUtil.Equals(lval1.Value, "?") { 123 lval2 := next(lvalList, lval1.Position+1) 124 if lval2 != nil && lval2.Tp == parser.NORMAL && util.StringUtil.EqualsIgnoreCase(lval2.Value, "=") { 125 lval3 := next(lvalList, lval2.Position+1) 126 if lval3 != nil && lval3.Tp == parser.NORMAL && util.StringUtil.EqualsIgnoreCase(lval3.Value, "call") { 127 stack = append(stack, true) 128 lval3.Value = "" 129 lval3.Tp = parser.WHITESPACE_OR_COMMENT 130 } else { 131 stack = append(stack, false) 132 nsql.WriteString(lval0.Value) 133 } 134 } else { 135 stack = append(stack, false) 136 nsql.WriteString(lval0.Value) 137 } 138 } else { 139 stack = append(stack, false) 140 nsql.WriteString(lval0.Value) 141 } 142 } else if util.StringUtil.Equals(lval0.Value, "}") { 143 if len(stack) != 0 && stack[len(stack)-1] { 144 145 } else { 146 nsql.WriteString(lval0.Value) 147 } 148 stack = stack[:len(stack)-1] 149 } else { 150 if keywordMap != nil { 151 _, ok := keywordMap[strings.ToUpper(lval0.Value)] 152 if ok { 153 nsql.WriteString("\"" + util.StringUtil.ProcessDoubleQuoteOfName(strings.ToUpper(lval0.Value)) + "\"") 154 } else { 155 nsql.WriteString(lval0.Value) 156 } 157 } else { 158 nsql.WriteString(lval0.Value) 159 } 160 } 161 } else if lval0.Tp == parser.STRING { 162 nsql.WriteString("'" + util.StringUtil.ProcessSingleQuoteOfName(lval0.Value) + "'") 163 } else { 164 nsql.WriteString(lval0.Value) 165 } 166 } 167 168 return nsql.String(), nil 169 } 170 171 func next(lvalList []*parser.LVal, start int) *parser.LVal { 172 var lval *parser.LVal 173 174 size := len(lvalList) 175 for i := start; i < size; i++ { 176 lval = lvalList[i] 177 if lval.Tp != parser.WHITESPACE_OR_COMMENT { 178 break 179 } 180 } 181 return lval 182 } 183 184 func (dc *DmConnection) execOpt(sql string, optParamList []OptParameter, serverEncoding string) (string, []OptParameter, error) { 185 nsql := bytes.NewBufferString("") 186 187 lvalList, err := dc.lex(sql) 188 if err != nil { 189 return "", optParamList, err 190 } 191 192 if nil == lvalList || len(lvalList) == 0 { 193 return sql, optParamList, nil 194 } 195 196 firstWord := lvalList[0].Value 197 if !(util.StringUtil.EqualsIgnoreCase(firstWord, "INSERT") || util.StringUtil.EqualsIgnoreCase(firstWord, "SELECT") || 198 util.StringUtil.EqualsIgnoreCase(firstWord, "UPDATE") || util.StringUtil.EqualsIgnoreCase(firstWord, "DELETE")) { 199 return sql, optParamList, nil 200 } 201 202 breakIndex := 0 203 for i := 0; i < len(lvalList); i++ { 204 lval := lvalList[i] 205 switch lval.Tp { 206 case parser.NULL: 207 { 208 nsql.WriteString("?") 209 optParamList = append(optParamList, newOptParameter(nil, NULL, NULL_PREC)) 210 } 211 case parser.INT: 212 { 213 nsql.WriteString("?") 214 value, err := strconv.Atoi(lval.Value) 215 if err != nil { 216 return "", optParamList, err 217 } 218 219 if value <= int(INT32_MAX) && value >= int(INT32_MIN) { 220 optParamList = append(optParamList, newOptParameter(G2DB.toInt32(int32(value)), INT, INT_PREC)) 221 222 } else { 223 optParamList = append(optParamList, newOptParameter(G2DB.toInt64(int64(value)), BIGINT, BIGINT_PREC)) 224 } 225 } 226 case parser.DOUBLE: 227 { 228 nsql.WriteString("?") 229 f, err := strconv.ParseFloat(lval.Value, 64) 230 if err != nil { 231 return "", optParamList, err 232 } 233 234 optParamList = append(optParamList, newOptParameter(G2DB.toFloat64(f), DOUBLE, DOUBLE_PREC)) 235 } 236 case parser.DECIMAL: 237 { 238 nsql.WriteString("?") 239 bytes, err := G2DB.toDecimal(lval.Value, 0, 0) 240 if err != nil { 241 return "", optParamList, err 242 } 243 optParamList = append(optParamList, newOptParameter(bytes, DECIMAL, 0)) 244 } 245 case parser.STRING: 246 { 247 248 if len(lval.Value) > int(INT16_MAX) { 249 250 nsql.WriteString("'" + util.StringUtil.ProcessSingleQuoteOfName(lval.Value) + "'") 251 } else { 252 nsql.WriteString("?") 253 optParamList = append(optParamList, newOptParameter(Dm_build_1.Dm_build_214(lval.Value, serverEncoding, dc), VARCHAR, VARCHAR_PREC)) 254 } 255 } 256 case parser.HEX_INT: 257 258 nsql.WriteString(lval.Value) 259 default: 260 261 nsql.WriteString(lval.Value) 262 } 263 264 if breakIndex > 0 { 265 break 266 } 267 } 268 269 if breakIndex > 0 { 270 for i := breakIndex + 1; i < len(lvalList); i++ { 271 nsql.WriteString(lvalList[i].Value) 272 } 273 } 274 275 return nsql.String(), optParamList, nil 276 } 277 278 func (dc *DmConnection) hasConst(sql string) (bool, error) { 279 lvalList, err := dc.lex(sql) 280 if err != nil { 281 return false, err 282 } 283 284 if nil == lvalList || len(lvalList) == 0 { 285 return false, nil 286 } 287 288 for i := 0; i < len(lvalList); i++ { 289 switch lvalList[i].Tp { 290 case parser.NULL, parser.INT, parser.DOUBLE, parser.DECIMAL, parser.STRING, parser.HEX_INT: 291 return true, nil 292 } 293 } 294 return false, nil 295 } 296 297 type OptParameter struct { 298 bytes []byte 299 ioType byte 300 tp int 301 prec int 302 scale int 303 } 304 305 func newOptParameter(bytes []byte, tp int, prec int) OptParameter { 306 o := new(OptParameter) 307 o.bytes = bytes 308 o.tp = tp 309 o.prec = prec 310 return *o 311 } 312 313 func (parameter *OptParameter) String() string { 314 if parameter.bytes == nil { 315 return "" 316 } 317 return string(parameter.bytes) 318 }