github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/bind/positional_args.go (about)

     1  package bind
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"unicode/utf8"
     7  
     8  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
     9  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring"
    10  	"github.com/ydb-platform/ydb-go-sdk/v3/table"
    11  )
    12  
    13  type PositionalArgs struct{}
    14  
    15  func (m PositionalArgs) blockID() blockID {
    16  	return blockYQL
    17  }
    18  
    19  func (m PositionalArgs) RewriteQuery(sql string, args ...interface{}) (
    20  	yql string, newArgs []interface{}, err error,
    21  ) {
    22  	l := &sqlLexer{
    23  		src:        sql,
    24  		stateFn:    positionalArgsStateFn,
    25  		rawStateFn: positionalArgsStateFn,
    26  	}
    27  
    28  	for l.stateFn != nil {
    29  		l.stateFn = l.stateFn(l)
    30  	}
    31  
    32  	var (
    33  		buffer   = xstring.Buffer()
    34  		position = 0
    35  		param    table.ParameterOption
    36  	)
    37  	defer buffer.Free()
    38  
    39  	for _, p := range l.parts {
    40  		switch p := p.(type) {
    41  		case string:
    42  			buffer.WriteString(p)
    43  		case positionalArg:
    44  			if position > len(args)-1 {
    45  				return "", nil, xerrors.WithStackTrace(
    46  					fmt.Errorf("%w: position %d, len(args) = %d", ErrInconsistentArgs, position, len(args)),
    47  				)
    48  			}
    49  			paramName := "$p" + strconv.Itoa(position)
    50  			param, err = toYdbParam(paramName, args[position])
    51  			if err != nil {
    52  				return "", nil, xerrors.WithStackTrace(err)
    53  			}
    54  			newArgs = append(newArgs, param)
    55  			buffer.WriteString(paramName)
    56  			position++
    57  		}
    58  	}
    59  
    60  	if len(args) != position {
    61  		return "", nil, xerrors.WithStackTrace(
    62  			fmt.Errorf("%w: (positional args %d, query args %d)", ErrInconsistentArgs, position, len(args)),
    63  		)
    64  	}
    65  
    66  	if position > 0 {
    67  		const prefix = "-- origin query with positional args replacement\n"
    68  
    69  		return prefix + buffer.String(), newArgs, nil
    70  	}
    71  
    72  	return buffer.String(), newArgs, nil
    73  }
    74  
    75  func positionalArgsStateFn(l *sqlLexer) stateFn {
    76  	for {
    77  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    78  		l.pos += width
    79  
    80  		switch r {
    81  		case '`':
    82  			return backtickState
    83  		case '\'':
    84  			return singleQuoteState
    85  		case '"':
    86  			return doubleQuoteState
    87  		case '?':
    88  			l.parts = append(l.parts, l.src[l.start:l.pos-1], positionalArg{})
    89  			l.start = l.pos
    90  		case '-':
    91  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    92  			if nextRune == '-' {
    93  				l.pos += width
    94  
    95  				return oneLineCommentState
    96  			}
    97  		case '/':
    98  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    99  			if nextRune == '*' {
   100  				l.pos += width
   101  
   102  				return multilineCommentState
   103  			}
   104  		case utf8.RuneError:
   105  			if l.pos-l.start > 0 {
   106  				l.parts = append(l.parts, l.src[l.start:l.pos])
   107  				l.start = l.pos
   108  			}
   109  
   110  			return nil
   111  		}
   112  	}
   113  }