github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/bind/numeric_args.go (about)

     1  package bind
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"unicode/utf8"
     7  
     8  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
     9  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring"
    10  	"github.com/ydb-platform/ydb-go-sdk/v3/table"
    11  )
    12  
    13  type NumericArgs struct{}
    14  
    15  func (m NumericArgs) blockID() blockID {
    16  	return blockYQL
    17  }
    18  
    19  func (m NumericArgs) RewriteQuery(sql string, args ...interface{}) (
    20  	yql string, newArgs []interface{}, err error,
    21  ) {
    22  	l := &sqlLexer{
    23  		src:        sql,
    24  		stateFn:    numericArgsStateFn,
    25  		rawStateFn: numericArgsStateFn,
    26  	}
    27  
    28  	for l.stateFn != nil {
    29  		l.stateFn = l.stateFn(l)
    30  	}
    31  
    32  	var (
    33  		buffer = xstring.Buffer()
    34  		param  table.ParameterOption
    35  	)
    36  	defer buffer.Free()
    37  
    38  	if len(args) > 0 {
    39  		newArgs = make([]interface{}, len(args))
    40  	}
    41  
    42  	for _, p := range l.parts {
    43  		switch p := p.(type) {
    44  		case string:
    45  			buffer.WriteString(p)
    46  		case numericArg:
    47  			if p == 0 {
    48  				return "", nil, xerrors.WithStackTrace(ErrUnexpectedNumericArgZero)
    49  			}
    50  			if int(p) > len(args) {
    51  				return "", nil, xerrors.WithStackTrace(
    52  					fmt.Errorf("%w: $p%d, len(args) = %d", ErrInconsistentArgs, p, len(args)),
    53  				)
    54  			}
    55  			paramName := "$p" + strconv.Itoa(int(p-1)) //nolint:goconst
    56  			if newArgs[p-1] == nil {
    57  				param, err = toYdbParam(paramName, args[p-1])
    58  				if err != nil {
    59  					return "", nil, xerrors.WithStackTrace(err)
    60  				}
    61  				newArgs[p-1] = param
    62  				buffer.WriteString(param.Name())
    63  			} else {
    64  				buffer.WriteString(newArgs[p-1].(table.ParameterOption).Name())
    65  			}
    66  		}
    67  	}
    68  
    69  	for i, p := range newArgs {
    70  		if p == nil {
    71  			return "", nil, xerrors.WithStackTrace(
    72  				fmt.Errorf("%w: $p%d, len(args) = %d", ErrInconsistentArgs, i+1, len(args)),
    73  			)
    74  		}
    75  	}
    76  
    77  	if len(newArgs) > 0 {
    78  		const prefix = "-- origin query with numeric args replacement\n"
    79  
    80  		return prefix + buffer.String(), newArgs, nil
    81  	}
    82  
    83  	return buffer.String(), newArgs, nil
    84  }
    85  
    86  func numericArgsStateFn(l *sqlLexer) stateFn {
    87  	for {
    88  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    89  		l.pos += width
    90  
    91  		switch r {
    92  		case '`':
    93  			return backtickState
    94  		case '\'':
    95  			return singleQuoteState
    96  		case '"':
    97  			return doubleQuoteState
    98  		case '$':
    99  			nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
   100  			if isNumber(nextRune) {
   101  				if l.pos-l.start > 0 {
   102  					l.parts = append(l.parts, l.src[l.start:l.pos-width])
   103  				}
   104  				l.start = l.pos
   105  
   106  				return numericArgState
   107  			}
   108  		case '-':
   109  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   110  			if nextRune == '-' {
   111  				l.pos += width
   112  
   113  				return oneLineCommentState
   114  			}
   115  		case '/':
   116  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   117  			if nextRune == '*' {
   118  				l.pos += width
   119  
   120  				return multilineCommentState
   121  			}
   122  		case utf8.RuneError:
   123  			if l.pos-l.start > 0 {
   124  				l.parts = append(l.parts, l.src[l.start:l.pos])
   125  				l.start = l.pos
   126  			}
   127  
   128  			return nil
   129  		}
   130  	}
   131  }
   132  
   133  func numericArgState(l *sqlLexer) stateFn {
   134  	numbers := ""
   135  	defer func() {
   136  		if len(numbers) > 0 {
   137  			i, err := strconv.Atoi(numbers)
   138  			if err != nil {
   139  				panic(err)
   140  			}
   141  			l.parts = append(l.parts, numericArg(i))
   142  			l.start = l.pos
   143  		} else {
   144  			l.parts = append(l.parts, l.src[l.start-1:l.pos])
   145  			l.start = l.pos
   146  		}
   147  	}()
   148  	for {
   149  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   150  		l.pos += width
   151  
   152  		switch {
   153  		case isNumber(r):
   154  			numbers += string(r)
   155  		case isLetter(r):
   156  			numbers = ""
   157  
   158  			return l.rawStateFn
   159  		default:
   160  			l.pos -= width
   161  
   162  			return l.rawStateFn
   163  		}
   164  	}
   165  }