github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/bind/numeric_args.go (about)

     1  package bind
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"unicode/utf8"
     7  
     8  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/params"
     9  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    10  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring"
    11  	"github.com/ydb-platform/ydb-go-sdk/v3/table"
    12  )
    13  
    14  type NumericArgs struct{}
    15  
    16  func (m NumericArgs) blockID() blockID {
    17  	return blockYQL
    18  }
    19  
    20  func (m NumericArgs) RewriteQuery(sql string, args ...interface{}) (yql string, newArgs []interface{}, err error) {
    21  	l := &sqlLexer{
    22  		src:        sql,
    23  		stateFn:    numericArgsStateFn,
    24  		rawStateFn: numericArgsStateFn,
    25  	}
    26  
    27  	for l.stateFn != nil {
    28  		l.stateFn = l.stateFn(l)
    29  	}
    30  
    31  	buffer := xstring.Buffer()
    32  	defer buffer.Free()
    33  
    34  	if len(args) > 0 {
    35  		parameters, err := parsePositionalParameters(args)
    36  		if err != nil {
    37  			return "", nil, err
    38  		}
    39  		newArgs = make([]interface{}, len(parameters))
    40  		for i, param := range parameters {
    41  			newArgs[i] = param
    42  		}
    43  	}
    44  
    45  	for _, p := range l.parts {
    46  		switch p := p.(type) {
    47  		case string:
    48  			buffer.WriteString(p)
    49  		case numericArg:
    50  			if p == 0 {
    51  				return "", nil, xerrors.WithStackTrace(ErrUnexpectedNumericArgZero)
    52  			}
    53  			if int(p) > len(args) {
    54  				return "", nil, xerrors.WithStackTrace(
    55  					fmt.Errorf("%w: $%d, len(args) = %d", ErrInconsistentArgs, p, len(args)),
    56  				)
    57  			}
    58  			paramIndex := int(p - 1)
    59  			val, ok := newArgs[paramIndex].(table.ParameterOption)
    60  			if !ok {
    61  				panic(fmt.Sprintf("unsupported type conversion from %T to table.ParameterOption", val))
    62  			}
    63  			buffer.WriteString(val.Name())
    64  		}
    65  	}
    66  
    67  	yql = buffer.String()
    68  	if len(newArgs) > 0 {
    69  		yql = "-- origin query with numeric args replacement\n" + yql
    70  	}
    71  
    72  	return yql, newArgs, nil
    73  }
    74  
    75  func numericArgsStateFn(l *sqlLexer) stateFn {
    76  	for {
    77  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
    78  		l.pos += width
    79  
    80  		switch r {
    81  		case '`':
    82  			return backtickState
    83  		case '\'':
    84  			return singleQuoteState
    85  		case '"':
    86  			return doubleQuoteState
    87  		case '$':
    88  			nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
    89  			if isNumber(nextRune) {
    90  				if l.pos-l.start > 0 {
    91  					l.parts = append(l.parts, l.src[l.start:l.pos-width])
    92  				}
    93  				l.start = l.pos
    94  
    95  				return numericArgState
    96  			}
    97  		case '-':
    98  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
    99  			if nextRune == '-' {
   100  				l.pos += width
   101  
   102  				return oneLineCommentState
   103  			}
   104  		case '/':
   105  			nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
   106  			if nextRune == '*' {
   107  				l.pos += width
   108  
   109  				return multilineCommentState
   110  			}
   111  		case utf8.RuneError:
   112  			if l.pos-l.start > 0 {
   113  				l.parts = append(l.parts, l.src[l.start:l.pos])
   114  				l.start = l.pos
   115  			}
   116  
   117  			return nil
   118  		}
   119  	}
   120  }
   121  
   122  func parsePositionalParameters(args []interface{}) ([]*params.Parameter, error) {
   123  	newArgs := make([]*params.Parameter, len(args))
   124  	for i, arg := range args {
   125  		paramName := fmt.Sprintf("$p%d", i)
   126  		param, err := toYdbParam(paramName, arg)
   127  		if err != nil {
   128  			return nil, err
   129  		}
   130  		newArgs[i] = param
   131  	}
   132  
   133  	return newArgs, nil
   134  }
   135  
   136  func numericArgState(l *sqlLexer) stateFn {
   137  	numbers := ""
   138  	defer func() {
   139  		if len(numbers) > 0 {
   140  			i, err := strconv.Atoi(numbers)
   141  			if err != nil {
   142  				panic(err)
   143  			}
   144  			l.parts = append(l.parts, numericArg(i))
   145  			l.start = l.pos
   146  		} else {
   147  			l.parts = append(l.parts, l.src[l.start-1:l.pos])
   148  			l.start = l.pos
   149  		}
   150  	}()
   151  	for {
   152  		r, width := utf8.DecodeRuneInString(l.src[l.pos:])
   153  		l.pos += width
   154  
   155  		switch {
   156  		case isNumber(r):
   157  			numbers += string(r)
   158  		case isLetter(r):
   159  			numbers = ""
   160  
   161  			return l.rawStateFn
   162  		default:
   163  			l.pos -= width
   164  
   165  			return l.rawStateFn
   166  		}
   167  	}
   168  }