github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/bind/positional_args.go (about) 1 package bind 2 3 import ( 4 "fmt" 5 "strconv" 6 "unicode/utf8" 7 8 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 9 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" 10 "github.com/ydb-platform/ydb-go-sdk/v3/table" 11 ) 12 13 type PositionalArgs struct{} 14 15 func (m PositionalArgs) blockID() blockID { 16 return blockYQL 17 } 18 19 func (m PositionalArgs) RewriteQuery(sql string, args ...interface{}) ( 20 yql string, newArgs []interface{}, err error, 21 ) { 22 l := &sqlLexer{ 23 src: sql, 24 stateFn: positionalArgsStateFn, 25 rawStateFn: positionalArgsStateFn, 26 } 27 28 for l.stateFn != nil { 29 l.stateFn = l.stateFn(l) 30 } 31 32 var ( 33 buffer = xstring.Buffer() 34 position = 0 35 param table.ParameterOption 36 ) 37 defer buffer.Free() 38 39 for _, p := range l.parts { 40 switch p := p.(type) { 41 case string: 42 buffer.WriteString(p) 43 case positionalArg: 44 if position > len(args)-1 { 45 return "", nil, xerrors.WithStackTrace( 46 fmt.Errorf("%w: position %d, len(args) = %d", ErrInconsistentArgs, position, len(args)), 47 ) 48 } 49 paramName := "$p" + strconv.Itoa(position) 50 param, err = toYdbParam(paramName, args[position]) 51 if err != nil { 52 return "", nil, xerrors.WithStackTrace(err) 53 } 54 newArgs = append(newArgs, param) 55 buffer.WriteString(paramName) 56 position++ 57 } 58 } 59 60 if len(args) != position { 61 return "", nil, xerrors.WithStackTrace( 62 fmt.Errorf("%w: (positional args %d, query args %d)", ErrInconsistentArgs, position, len(args)), 63 ) 64 } 65 66 if position > 0 { 67 const prefix = "-- origin query with positional args replacement\n" 68 69 return prefix + buffer.String(), newArgs, nil 70 } 71 72 return buffer.String(), newArgs, nil 73 } 74 75 func positionalArgsStateFn(l *sqlLexer) stateFn { 76 for { 77 r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 78 l.pos += width 79 80 switch r { 81 case '`': 82 return backtickState 83 case '\'': 84 return singleQuoteState 85 case '"': 86 return doubleQuoteState 87 case '?': 88 l.parts = append(l.parts, l.src[l.start:l.pos-1], positionalArg{}) 89 l.start = l.pos 90 case '-': 91 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 92 if nextRune == '-' { 93 l.pos += width 94 95 return oneLineCommentState 96 } 97 case '/': 98 nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 99 if nextRune == '*' { 100 l.pos += width 101 102 return multilineCommentState 103 } 104 case utf8.RuneError: 105 if l.pos-l.start > 0 { 106 l.parts = append(l.parts, l.src[l.start:l.pos]) 107 l.start = l.pos 108 } 109 110 return nil 111 } 112 } 113 }