github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/bind/numeric_args.go (about) 1 package bind 2 3 import ( 4 "fmt" 5 "strconv" 6 "unicode/utf8" 7 8 "github.com/ydb-platform/ydb-go-sdk/v3/internal/params" 9 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 10 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" 11 "github.com/ydb-platform/ydb-go-sdk/v3/table" 12 ) 13 14 type NumericArgs struct{} 15 16 func (m NumericArgs) blockID() blockID { 17 return blockYQL 18 } 19 20 func (m NumericArgs) RewriteQuery(sql string, args ...interface{}) (yql string, newArgs []interface{}, err error) { 21 l := &sqlLexer{ 22 src: sql, 23 stateFn: numericArgsStateFn, 24 rawStateFn: numericArgsStateFn, 25 } 26 27 for l.stateFn != nil { 28 l.stateFn = l.stateFn(l) 29 } 30 31 buffer := xstring.Buffer() 32 defer buffer.Free() 33 34 if len(args) > 0 { 35 parameters, err := parsePositionalParameters(args) 36 if err != nil { 37 return "", nil, err 38 } 39 newArgs = make([]interface{}, len(parameters)) 40 for i, param := range parameters { 41 newArgs[i] = param 42 } 43 } 44 45 for _, p := range l.parts { 46 switch p := p.(type) { 47 case string: 48 buffer.WriteString(p) 49 case numericArg: 50 if p == 0 { 51 return "", nil, xerrors.WithStackTrace(ErrUnexpectedNumericArgZero) 52 } 53 if int(p) > len(args) { 54 return "", nil, xerrors.WithStackTrace( 55 fmt.Errorf("%w: $%d, len(args) = %d", ErrInconsistentArgs, p, len(args)), 56 ) 57 } 58 paramIndex := int(p - 1) 59 val, ok := newArgs[paramIndex].(table.ParameterOption) 60 if !ok { 61 panic(fmt.Sprintf("unsupported type conversion from %T to table.ParameterOption", val)) 62 } 63 buffer.WriteString(val.Name()) 64 } 65 } 66 67 yql = buffer.String() 68 if len(newArgs) > 0 { 69 yql = "-- origin query with numeric args replacement\n" + yql 70 } 71 72 return yql, newArgs, nil 73 } 74 75 func numericArgsStateFn(l *sqlLexer) stateFn { 76 for { 77 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 78 l.pos += width 79 80 switch r { 81 case '`': 82 return backtickState 83 case '\'': 84 return singleQuoteState 85 case '"': 86 return doubleQuoteState 87 case '$': 88 nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) 89 if isNumber(nextRune) { 90 if l.pos-l.start > 0 { 91 l.parts = append(l.parts, l.src[l.start:l.pos-width]) 92 } 93 l.start = l.pos 94 95 return numericArgState 96 } 97 case '-': 98 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 99 if nextRune == '-' { 100 l.pos += width 101 102 return oneLineCommentState 103 } 104 case '/': 105 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 106 if nextRune == '*' { 107 l.pos += width 108 109 return multilineCommentState 110 } 111 case utf8.RuneError: 112 if l.pos-l.start > 0 { 113 l.parts = append(l.parts, l.src[l.start:l.pos]) 114 l.start = l.pos 115 } 116 117 return nil 118 } 119 } 120 } 121 122 func parsePositionalParameters(args []interface{}) ([]*params.Parameter, error) { 123 newArgs := make([]*params.Parameter, len(args)) 124 for i, arg := range args { 125 paramName := fmt.Sprintf("$p%d", i) 126 param, err := toYdbParam(paramName, arg) 127 if err != nil { 128 return nil, err 129 } 130 newArgs[i] = param 131 } 132 133 return newArgs, nil 134 } 135 136 func numericArgState(l *sqlLexer) stateFn { 137 numbers := "" 138 defer func() { 139 if len(numbers) > 0 { 140 i, err := strconv.Atoi(numbers) 141 if err != nil { 142 panic(err) 143 } 144 l.parts = append(l.parts, numericArg(i)) 145 l.start = l.pos 146 } else { 147 l.parts = append(l.parts, l.src[l.start-1:l.pos]) 148 l.start = l.pos 149 } 150 }() 151 for { 152 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 153 l.pos += width 154 155 switch { 156 case isNumber(r): 157 numbers += string(r) 158 case isLetter(r): 159 numbers = "" 160 161 return l.rawStateFn 162 default: 163 l.pos -= width 164 165 return l.rawStateFn 166 } 167 } 168 }