github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/bind/numeric_args.go (about) 1 package bind 2 3 import ( 4 "fmt" 5 "strconv" 6 "unicode/utf8" 7 8 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 9 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" 10 "github.com/ydb-platform/ydb-go-sdk/v3/table" 11 ) 12 13 type NumericArgs struct{} 14 15 func (m NumericArgs) blockID() blockID { 16 return blockYQL 17 } 18 19 func (m NumericArgs) RewriteQuery(sql string, args ...interface{}) ( 20 yql string, newArgs []interface{}, err error, 21 ) { 22 l := &sqlLexer{ 23 src: sql, 24 stateFn: numericArgsStateFn, 25 rawStateFn: numericArgsStateFn, 26 } 27 28 for l.stateFn != nil { 29 l.stateFn = l.stateFn(l) 30 } 31 32 var ( 33 buffer = xstring.Buffer() 34 param table.ParameterOption 35 ) 36 defer buffer.Free() 37 38 if len(args) > 0 { 39 newArgs = make([]interface{}, len(args)) 40 } 41 42 for _, p := range l.parts { 43 switch p := p.(type) { 44 case string: 45 buffer.WriteString(p) 46 case numericArg: 47 if p == 0 { 48 return "", nil, xerrors.WithStackTrace(ErrUnexpectedNumericArgZero) 49 } 50 if int(p) > len(args) { 51 return "", nil, xerrors.WithStackTrace( 52 fmt.Errorf("%w: $p%d, len(args) = %d", ErrInconsistentArgs, p, len(args)), 53 ) 54 } 55 paramName := "$p" + strconv.Itoa(int(p-1)) //nolint:goconst 56 if newArgs[p-1] == nil { 57 param, err = toYdbParam(paramName, args[p-1]) 58 if err != nil { 59 return "", nil, xerrors.WithStackTrace(err) 60 } 61 newArgs[p-1] = param 62 buffer.WriteString(param.Name()) 63 } else { 64 buffer.WriteString(newArgs[p-1].(table.ParameterOption).Name()) 65 } 66 } 67 } 68 69 for i, p := range newArgs { 70 if p == nil { 71 return "", nil, xerrors.WithStackTrace( 72 fmt.Errorf("%w: $p%d, len(args) = %d", ErrInconsistentArgs, i+1, len(args)), 73 ) 74 } 75 } 76 77 if len(newArgs) > 0 { 78 const prefix = "-- origin query with numeric args replacement\n" 79 80 return prefix + buffer.String(), newArgs, nil 81 } 82 83 return buffer.String(), newArgs, nil 84 } 85 86 func numericArgsStateFn(l *sqlLexer) stateFn { 87 for { 88 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 89 l.pos += width 90 91 switch r { 92 case '`': 93 return backtickState 94 case '\'': 95 return singleQuoteState 96 case '"': 97 return doubleQuoteState 98 case '$': 99 nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) 100 if isNumber(nextRune) { 101 if l.pos-l.start > 0 { 102 l.parts = append(l.parts, l.src[l.start:l.pos-width]) 103 } 104 l.start = l.pos 105 106 return numericArgState 107 } 108 case '-': 109 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 110 if nextRune == '-' { 111 l.pos += width 112 113 return oneLineCommentState 114 } 115 case '/': 116 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 117 if nextRune == '*' { 118 l.pos += width 119 120 return multilineCommentState 121 } 122 case utf8.RuneError: 123 if l.pos-l.start > 0 { 124 l.parts = append(l.parts, l.src[l.start:l.pos]) 125 l.start = l.pos 126 } 127 128 return nil 129 } 130 } 131 } 132 133 func numericArgState(l *sqlLexer) stateFn { 134 numbers := "" 135 defer func() { 136 if len(numbers) > 0 { 137 i, err := strconv.Atoi(numbers) 138 if err != nil { 139 panic(err) 140 } 141 l.parts = append(l.parts, numericArg(i)) 142 l.start = l.pos 143 } else { 144 l.parts = append(l.parts, l.src[l.start-1:l.pos]) 145 l.start = l.pos 146 } 147 }() 148 for { 149 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 150 l.pos += width 151 152 switch { 153 case isNumber(r): 154 numbers += string(r) 155 case isLetter(r): 156 numbers = "" 157 158 return l.rawStateFn 159 default: 160 l.pos -= width 161 162 return l.rawStateFn 163 } 164 } 165 }