github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/bind/params.go (about)

     1  package bind
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"errors"
     7  	"fmt"
     8  	"net/url"
     9  	"sort"
    10  	"time"
    11  
    12  	"github.com/google/uuid"
    13  
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/params"
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/value"
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    17  	"github.com/ydb-platform/ydb-go-sdk/v3/table/types"
    18  )
    19  
// Sentinel errors reported while binding database/sql query arguments:
//   - errUnsupportedType is wrapped by toValue when an argument has a Go type
//     that toValue has no conversion for;
//   - errUnnamedParam is returned by toYdbParam and Params when a parameter
//     ends up with an empty name;
//   - errMultipleQueryParameters is returned by Params and paramHandleNamedValue
//     when a *params.Parameters set is passed together with other arguments.
var (
	errUnsupportedType         = errors.New("unsupported type")
	errUnnamedParam            = errors.New("unnamed param")
	errMultipleQueryParameters = errors.New("only one query arg *table.QueryParameters allowed")
)
    25  
    26  //nolint:gocyclo,funlen
    27  func toValue(v interface{}) (_ types.Value, err error) {
    28  	if valuer, ok := v.(driver.Valuer); ok {
    29  		v, err = valuer.Value()
    30  		if err != nil {
    31  			return nil, fmt.Errorf("ydb: driver.Valuer error: %w", err)
    32  		}
    33  	}
    34  
    35  	switch x := v.(type) {
    36  	case nil:
    37  		return types.VoidValue(), nil
    38  	case value.Value:
    39  		return x, nil
    40  	case bool:
    41  		return types.BoolValue(x), nil
    42  	case *bool:
    43  		return types.NullableBoolValue(x), nil
    44  	case int:
    45  		return types.Int32Value(int32(x)), nil
    46  	case *int:
    47  		if x == nil {
    48  			return types.NullValue(types.TypeInt32), nil
    49  		}
    50  		xx := int32(*x)
    51  
    52  		return types.NullableInt32Value(&xx), nil
    53  	case uint:
    54  		return types.Uint32Value(uint32(x)), nil
    55  	case *uint:
    56  		if x == nil {
    57  			return types.NullValue(types.TypeUint32), nil
    58  		}
    59  		xx := uint32(*x)
    60  
    61  		return types.NullableUint32Value(&xx), nil
    62  	case int8:
    63  		return types.Int8Value(x), nil
    64  	case *int8:
    65  		return types.NullableInt8Value(x), nil
    66  	case uint8:
    67  		return types.Uint8Value(x), nil
    68  	case *uint8:
    69  		return types.NullableUint8Value(x), nil
    70  	case int16:
    71  		return types.Int16Value(x), nil
    72  	case *int16:
    73  		return types.NullableInt16Value(x), nil
    74  	case uint16:
    75  		return types.Uint16Value(x), nil
    76  	case *uint16:
    77  		return types.NullableUint16Value(x), nil
    78  	case int32:
    79  		return types.Int32Value(x), nil
    80  	case *int32:
    81  		return types.NullableInt32Value(x), nil
    82  	case uint32:
    83  		return types.Uint32Value(x), nil
    84  	case *uint32:
    85  		return types.NullableUint32Value(x), nil
    86  	case int64:
    87  		return types.Int64Value(x), nil
    88  	case *int64:
    89  		return types.NullableInt64Value(x), nil
    90  	case uint64:
    91  		return types.Uint64Value(x), nil
    92  	case *uint64:
    93  		return types.NullableUint64Value(x), nil
    94  	case float32:
    95  		return types.FloatValue(x), nil
    96  	case *float32:
    97  		return types.NullableFloatValue(x), nil
    98  	case float64:
    99  		return types.DoubleValue(x), nil
   100  	case *float64:
   101  		return types.NullableDoubleValue(x), nil
   102  	case []byte:
   103  		return types.BytesValue(x), nil
   104  	case *[]byte:
   105  		return types.NullableBytesValue(x), nil
   106  	case string:
   107  		return types.TextValue(x), nil
   108  	case *string:
   109  		return types.NullableTextValue(x), nil
   110  	case []string:
   111  		items := make([]types.Value, len(x))
   112  		for i := range x {
   113  			items[i] = types.TextValue(x[i])
   114  		}
   115  
   116  		return types.ListValue(items...), nil
   117  	case [16]byte:
   118  		return nil, xerrors.Wrap(value.ErrIssue1501BadUUID)
   119  	case *[16]byte:
   120  		return nil, xerrors.Wrap(value.ErrIssue1501BadUUID)
   121  	case types.UUIDBytesWithIssue1501Type:
   122  		return types.UUIDWithIssue1501Value(x.AsBytesArray()), nil
   123  	case *types.UUIDBytesWithIssue1501Type:
   124  		if x == nil {
   125  			return types.NullableUUIDValueWithIssue1501(nil), nil
   126  		}
   127  		val := x.AsBytesArray()
   128  
   129  		return types.NullableUUIDValueWithIssue1501(&val), nil
   130  	case uuid.UUID:
   131  		return types.UuidValue(x), nil
   132  	case *uuid.UUID:
   133  		return types.NullableUUIDTypedValue(x), nil
   134  	case time.Time:
   135  		return types.TimestampValueFromTime(x), nil
   136  	case *time.Time:
   137  		return types.NullableTimestampValueFromTime(x), nil
   138  	case time.Duration:
   139  		return types.IntervalValueFromDuration(x), nil
   140  	case *time.Duration:
   141  		return types.NullableIntervalValueFromDuration(x), nil
   142  	default:
   143  		return nil, xerrors.WithStackTrace(
   144  			fmt.Errorf("%T: %w. Create issue for support new type %s",
   145  				x, errUnsupportedType, supportNewTypeLink(x),
   146  			),
   147  		)
   148  	}
   149  }
   150  
   151  func supportNewTypeLink(x interface{}) string {
   152  	v := url.Values{}
   153  	v.Add("labels", "enhancement,database/sql")
   154  	v.Add("template", "02_FEATURE_REQUEST.md")
   155  	v.Add("title", fmt.Sprintf("feat: Support new type `%T` in `database/sql` query args", x))
   156  
   157  	return "https://github.com/ydb-platform/ydb-go-sdk/issues/new?" + v.Encode()
   158  }
   159  
   160  func toYdbParam(name string, value interface{}) (*params.Parameter, error) {
   161  	if na, ok := value.(driver.NamedValue); ok {
   162  		n, v := na.Name, na.Value
   163  		if n != "" {
   164  			name = n
   165  		}
   166  		value = v
   167  	}
   168  	if na, ok := value.(sql.NamedArg); ok {
   169  		n, v := na.Name, na.Value
   170  		if n != "" {
   171  			name = n
   172  		}
   173  		value = v
   174  	}
   175  	if v, ok := value.(*params.Parameter); ok {
   176  		return v, nil
   177  	}
   178  	v, err := toValue(value)
   179  	if err != nil {
   180  		return nil, xerrors.WithStackTrace(err)
   181  	}
   182  	if name == "" {
   183  		return nil, xerrors.WithStackTrace(errUnnamedParam)
   184  	}
   185  	if name[0] != '$' {
   186  		name = "$" + name
   187  	}
   188  
   189  	return params.Named(name, v), nil
   190  }
   191  
   192  func Params(args ...interface{}) ([]*params.Parameter, error) {
   193  	parameters := make([]*params.Parameter, 0, len(args))
   194  	for i, arg := range args {
   195  		var newParam *params.Parameter
   196  		var newParams []*params.Parameter
   197  		var err error
   198  		switch x := arg.(type) {
   199  		case driver.NamedValue:
   200  			newParams, err = paramHandleNamedValue(x, i, len(args))
   201  		case sql.NamedArg:
   202  			if x.Name == "" {
   203  				return nil, xerrors.WithStackTrace(errUnnamedParam)
   204  			}
   205  			newParam, err = toYdbParam(x.Name, x.Value)
   206  			newParams = append(newParams, newParam)
   207  		case *params.Parameters:
   208  			if len(args) > 1 {
   209  				return nil, xerrors.WithStackTrace(errMultipleQueryParameters)
   210  			}
   211  			parameters = *x
   212  		case *params.Parameter:
   213  			newParams = append(newParams, x)
   214  		default:
   215  			newParam, err = toYdbParam(fmt.Sprintf("$p%d", i), x)
   216  			newParams = append(newParams, newParam)
   217  		}
   218  		if err != nil {
   219  			return nil, xerrors.WithStackTrace(err)
   220  		}
   221  		parameters = append(parameters, newParams...)
   222  	}
   223  	sort.Slice(parameters, func(i, j int) bool {
   224  		return parameters[i].Name() < parameters[j].Name()
   225  	})
   226  
   227  	return parameters, nil
   228  }
   229  
   230  func paramHandleNamedValue(arg driver.NamedValue, paramNumber, argsLen int) ([]*params.Parameter, error) {
   231  	if arg.Name == "" {
   232  		switch x := arg.Value.(type) {
   233  		case *params.Parameters:
   234  			if argsLen > 1 {
   235  				return nil, xerrors.WithStackTrace(errMultipleQueryParameters)
   236  			}
   237  
   238  			return *x, nil
   239  		case *params.Parameter:
   240  			return []*params.Parameter{x}, nil
   241  		default:
   242  			arg.Name = fmt.Sprintf("$p%d", paramNumber)
   243  			param, err := toYdbParam(arg.Name, arg.Value)
   244  			if err != nil {
   245  				return nil, xerrors.WithStackTrace(err)
   246  			}
   247  
   248  			return []*params.Parameter{param}, nil
   249  		}
   250  	} else {
   251  		param, err := toYdbParam(arg.Name, arg.Value)
   252  		if err != nil {
   253  			return nil, xerrors.WithStackTrace(err)
   254  		}
   255  
   256  		return []*params.Parameter{param}, nil
   257  	}
   258  }