github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/bind/params.go (about)

     1  package bind
     2  
import (
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"math"
	"net/url"
	"reflect"
	"sort"
	"time"

	"github.com/ydb-platform/ydb-go-sdk/v3/internal/params"
	"github.com/ydb-platform/ydb-go-sdk/v3/internal/value"
	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
	"github.com/ydb-platform/ydb-go-sdk/v3/table/types"
)
    17  
    18  var (
    19  	errUnsupportedType         = errors.New("unsupported type")
    20  	errUnnamedParam            = errors.New("unnamed param")
    21  	errMultipleQueryParameters = errors.New("only one query arg *table.QueryParameters allowed")
    22  )
    23  
    24  //nolint:gocyclo
    25  func toValue(v interface{}) (_ types.Value, err error) {
    26  	if valuer, ok := v.(driver.Valuer); ok {
    27  		v, err = valuer.Value()
    28  		if err != nil {
    29  			return nil, fmt.Errorf("ydb: driver.Valuer error: %w", err)
    30  		}
    31  	}
    32  
    33  	switch x := v.(type) {
    34  	case nil:
    35  		return types.VoidValue(), nil
    36  	case value.Value:
    37  		return x, nil
    38  	case bool:
    39  		return types.BoolValue(x), nil
    40  	case *bool:
    41  		return types.NullableBoolValue(x), nil
    42  	case int:
    43  		return types.Int32Value(int32(x)), nil
    44  	case *int:
    45  		if x == nil {
    46  			return types.NullValue(types.TypeInt32), nil
    47  		}
    48  		xx := int32(*x)
    49  
    50  		return types.NullableInt32Value(&xx), nil
    51  	case uint:
    52  		return types.Uint32Value(uint32(x)), nil
    53  	case *uint:
    54  		if x == nil {
    55  			return types.NullValue(types.TypeUint32), nil
    56  		}
    57  		xx := uint32(*x)
    58  
    59  		return types.NullableUint32Value(&xx), nil
    60  	case int8:
    61  		return types.Int8Value(x), nil
    62  	case *int8:
    63  		return types.NullableInt8Value(x), nil
    64  	case uint8:
    65  		return types.Uint8Value(x), nil
    66  	case *uint8:
    67  		return types.NullableUint8Value(x), nil
    68  	case int16:
    69  		return types.Int16Value(x), nil
    70  	case *int16:
    71  		return types.NullableInt16Value(x), nil
    72  	case uint16:
    73  		return types.Uint16Value(x), nil
    74  	case *uint16:
    75  		return types.NullableUint16Value(x), nil
    76  	case int32:
    77  		return types.Int32Value(x), nil
    78  	case *int32:
    79  		return types.NullableInt32Value(x), nil
    80  	case uint32:
    81  		return types.Uint32Value(x), nil
    82  	case *uint32:
    83  		return types.NullableUint32Value(x), nil
    84  	case int64:
    85  		return types.Int64Value(x), nil
    86  	case *int64:
    87  		return types.NullableInt64Value(x), nil
    88  	case uint64:
    89  		return types.Uint64Value(x), nil
    90  	case *uint64:
    91  		return types.NullableUint64Value(x), nil
    92  	case float32:
    93  		return types.FloatValue(x), nil
    94  	case *float32:
    95  		return types.NullableFloatValue(x), nil
    96  	case float64:
    97  		return types.DoubleValue(x), nil
    98  	case *float64:
    99  		return types.NullableDoubleValue(x), nil
   100  	case []byte:
   101  		return types.BytesValue(x), nil
   102  	case *[]byte:
   103  		return types.NullableBytesValue(x), nil
   104  	case string:
   105  		return types.TextValue(x), nil
   106  	case *string:
   107  		return types.NullableTextValue(x), nil
   108  	case []string:
   109  		items := make([]types.Value, len(x))
   110  		for i := range x {
   111  			items[i] = types.TextValue(x[i])
   112  		}
   113  
   114  		return types.ListValue(items...), nil
   115  	case [16]byte:
   116  		return types.UUIDValue(x), nil
   117  	case *[16]byte:
   118  		return types.NullableUUIDValue(x), nil
   119  	case time.Time:
   120  		return types.TimestampValueFromTime(x), nil
   121  	case *time.Time:
   122  		return types.NullableTimestampValueFromTime(x), nil
   123  	case time.Duration:
   124  		return types.IntervalValueFromDuration(x), nil
   125  	case *time.Duration:
   126  		return types.NullableIntervalValueFromDuration(x), nil
   127  	default:
   128  		return nil, xerrors.WithStackTrace(
   129  			fmt.Errorf("%T: %w. Create issue for support new type %s",
   130  				x, errUnsupportedType, supportNewTypeLink(x),
   131  			),
   132  		)
   133  	}
   134  }
   135  
   136  func supportNewTypeLink(x interface{}) string {
   137  	v := url.Values{}
   138  	v.Add("labels", "enhancement,database/sql")
   139  	v.Add("template", "02_FEATURE_REQUEST.md")
   140  	v.Add("title", fmt.Sprintf("feat: Support new type `%T` in `database/sql` query args", x))
   141  
   142  	return "https://github.com/ydb-platform/ydb-go-sdk/issues/new?" + v.Encode()
   143  }
   144  
   145  func toYdbParam(name string, value interface{}) (*params.Parameter, error) {
   146  	if na, ok := value.(driver.NamedValue); ok {
   147  		n, v := na.Name, na.Value
   148  		if n != "" {
   149  			name = n
   150  		}
   151  		value = v
   152  	}
   153  	if na, ok := value.(sql.NamedArg); ok {
   154  		n, v := na.Name, na.Value
   155  		if n != "" {
   156  			name = n
   157  		}
   158  		value = v
   159  	}
   160  	if v, ok := value.(*params.Parameter); ok {
   161  		return v, nil
   162  	}
   163  	v, err := toValue(value)
   164  	if err != nil {
   165  		return nil, xerrors.WithStackTrace(err)
   166  	}
   167  	if name == "" {
   168  		return nil, xerrors.WithStackTrace(errUnnamedParam)
   169  	}
   170  	if name[0] != '$' {
   171  		name = "$" + name
   172  	}
   173  
   174  	return params.Named(name, v), nil
   175  }
   176  
   177  func Params(args ...interface{}) (parameters []*params.Parameter, _ error) {
   178  	parameters = make([]*params.Parameter, 0, len(args))
   179  	for i, arg := range args {
   180  		switch x := arg.(type) {
   181  		case driver.NamedValue:
   182  			if x.Name == "" {
   183  				switch xx := x.Value.(type) {
   184  				case *params.Parameters:
   185  					if len(args) > 1 {
   186  						return nil, xerrors.WithStackTrace(errMultipleQueryParameters)
   187  					}
   188  					parameters = *xx
   189  				case *params.Parameter:
   190  					parameters = append(parameters, xx)
   191  				default:
   192  					x.Name = fmt.Sprintf("$p%d", i)
   193  					param, err := toYdbParam(x.Name, x.Value)
   194  					if err != nil {
   195  						return nil, xerrors.WithStackTrace(err)
   196  					}
   197  					parameters = append(parameters, param)
   198  				}
   199  			} else {
   200  				param, err := toYdbParam(x.Name, x.Value)
   201  				if err != nil {
   202  					return nil, xerrors.WithStackTrace(err)
   203  				}
   204  				parameters = append(parameters, param)
   205  			}
   206  		case sql.NamedArg:
   207  			if x.Name == "" {
   208  				return nil, xerrors.WithStackTrace(errUnnamedParam)
   209  			}
   210  			param, err := toYdbParam(x.Name, x.Value)
   211  			if err != nil {
   212  				return nil, xerrors.WithStackTrace(err)
   213  			}
   214  			parameters = append(parameters, param)
   215  		case *params.Parameters:
   216  			if len(args) > 1 {
   217  				return nil, xerrors.WithStackTrace(errMultipleQueryParameters)
   218  			}
   219  			parameters = *x
   220  		case *params.Parameter:
   221  			parameters = append(parameters, x)
   222  		default:
   223  			param, err := toYdbParam(fmt.Sprintf("$p%d", i), x)
   224  			if err != nil {
   225  				return nil, xerrors.WithStackTrace(err)
   226  			}
   227  			parameters = append(parameters, param)
   228  		}
   229  	}
   230  	sort.Slice(parameters, func(i, j int) bool {
   231  		return parameters[i].Name() < parameters[j].Name()
   232  	})
   233  
   234  	return parameters, nil
   235  }