github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/bind/params.go (about) 1 package bind 2 3 import ( 4 "database/sql" 5 "database/sql/driver" 6 "errors" 7 "fmt" 8 "net/url" 9 "sort" 10 "time" 11 12 "github.com/google/uuid" 13 14 "github.com/ydb-platform/ydb-go-sdk/v3/internal/params" 15 "github.com/ydb-platform/ydb-go-sdk/v3/internal/value" 16 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 17 "github.com/ydb-platform/ydb-go-sdk/v3/table/types" 18 ) 19 20 var ( 21 errUnsupportedType = errors.New("unsupported type") 22 errUnnamedParam = errors.New("unnamed param") 23 errMultipleQueryParameters = errors.New("only one query arg *table.QueryParameters allowed") 24 ) 25 26 //nolint:gocyclo,funlen 27 func toValue(v interface{}) (_ types.Value, err error) { 28 if valuer, ok := v.(driver.Valuer); ok { 29 v, err = valuer.Value() 30 if err != nil { 31 return nil, fmt.Errorf("ydb: driver.Valuer error: %w", err) 32 } 33 } 34 35 switch x := v.(type) { 36 case nil: 37 return types.VoidValue(), nil 38 case value.Value: 39 return x, nil 40 case bool: 41 return types.BoolValue(x), nil 42 case *bool: 43 return types.NullableBoolValue(x), nil 44 case int: 45 return types.Int32Value(int32(x)), nil 46 case *int: 47 if x == nil { 48 return types.NullValue(types.TypeInt32), nil 49 } 50 xx := int32(*x) 51 52 return types.NullableInt32Value(&xx), nil 53 case uint: 54 return types.Uint32Value(uint32(x)), nil 55 case *uint: 56 if x == nil { 57 return types.NullValue(types.TypeUint32), nil 58 } 59 xx := uint32(*x) 60 61 return types.NullableUint32Value(&xx), nil 62 case int8: 63 return types.Int8Value(x), nil 64 case *int8: 65 return types.NullableInt8Value(x), nil 66 case uint8: 67 return types.Uint8Value(x), nil 68 case *uint8: 69 return types.NullableUint8Value(x), nil 70 case int16: 71 return types.Int16Value(x), nil 72 case *int16: 73 return types.NullableInt16Value(x), nil 74 case uint16: 75 return types.Uint16Value(x), nil 76 case *uint16: 77 return types.NullableUint16Value(x), nil 78 case int32: 79 return types.Int32Value(x), nil 80 case *int32: 81 return types.NullableInt32Value(x), nil 82 case uint32: 83 return types.Uint32Value(x), nil 84 case *uint32: 85 return types.NullableUint32Value(x), nil 86 case int64: 87 return types.Int64Value(x), nil 88 case *int64: 89 return types.NullableInt64Value(x), nil 90 case uint64: 91 return types.Uint64Value(x), nil 92 case *uint64: 93 return types.NullableUint64Value(x), nil 94 case float32: 95 return types.FloatValue(x), nil 96 case *float32: 97 return types.NullableFloatValue(x), nil 98 case float64: 99 return types.DoubleValue(x), nil 100 case *float64: 101 return types.NullableDoubleValue(x), nil 102 case []byte: 103 return types.BytesValue(x), nil 104 case *[]byte: 105 return types.NullableBytesValue(x), nil 106 case string: 107 return types.TextValue(x), nil 108 case *string: 109 return types.NullableTextValue(x), nil 110 case []string: 111 items := make([]types.Value, len(x)) 112 for i := range x { 113 items[i] = types.TextValue(x[i]) 114 } 115 116 return types.ListValue(items...), nil 117 case [16]byte: 118 return nil, xerrors.Wrap(value.ErrIssue1501BadUUID) 119 case *[16]byte: 120 return nil, xerrors.Wrap(value.ErrIssue1501BadUUID) 121 case types.UUIDBytesWithIssue1501Type: 122 return types.UUIDWithIssue1501Value(x.AsBytesArray()), nil 123 case *types.UUIDBytesWithIssue1501Type: 124 if x == nil { 125 return types.NullableUUIDValueWithIssue1501(nil), nil 126 } 127 val := x.AsBytesArray() 128 129 return types.NullableUUIDValueWithIssue1501(&val), nil 130 case uuid.UUID: 131 return types.UuidValue(x), nil 132 case *uuid.UUID: 133 return types.NullableUUIDTypedValue(x), nil 134 case time.Time: 135 return types.TimestampValueFromTime(x), nil 136 case *time.Time: 137 return types.NullableTimestampValueFromTime(x), nil 138 case time.Duration: 139 return types.IntervalValueFromDuration(x), nil 140 case *time.Duration: 141 return types.NullableIntervalValueFromDuration(x), nil 142 default: 143 return nil, xerrors.WithStackTrace( 144 fmt.Errorf("%T: %w. Create issue for support new type %s", 145 x, errUnsupportedType, supportNewTypeLink(x), 146 ), 147 ) 148 } 149 } 150 151 func supportNewTypeLink(x interface{}) string { 152 v := url.Values{} 153 v.Add("labels", "enhancement,database/sql") 154 v.Add("template", "02_FEATURE_REQUEST.md") 155 v.Add("title", fmt.Sprintf("feat: Support new type `%T` in `database/sql` query args", x)) 156 157 return "https://github.com/ydb-platform/ydb-go-sdk/issues/new?" + v.Encode() 158 } 159 160 func toYdbParam(name string, value interface{}) (*params.Parameter, error) { 161 if na, ok := value.(driver.NamedValue); ok { 162 n, v := na.Name, na.Value 163 if n != "" { 164 name = n 165 } 166 value = v 167 } 168 if na, ok := value.(sql.NamedArg); ok { 169 n, v := na.Name, na.Value 170 if n != "" { 171 name = n 172 } 173 value = v 174 } 175 if v, ok := value.(*params.Parameter); ok { 176 return v, nil 177 } 178 v, err := toValue(value) 179 if err != nil { 180 return nil, xerrors.WithStackTrace(err) 181 } 182 if name == "" { 183 return nil, xerrors.WithStackTrace(errUnnamedParam) 184 } 185 if name[0] != '$' { 186 name = "$" + name 187 } 188 189 return params.Named(name, v), nil 190 } 191 192 func Params(args ...interface{}) ([]*params.Parameter, error) { 193 parameters := make([]*params.Parameter, 0, len(args)) 194 for i, arg := range args { 195 var newParam *params.Parameter 196 var newParams []*params.Parameter 197 var err error 198 switch x := arg.(type) { 199 case driver.NamedValue: 200 newParams, err = paramHandleNamedValue(x, i, len(args)) 201 case sql.NamedArg: 202 if x.Name == "" { 203 return nil, xerrors.WithStackTrace(errUnnamedParam) 204 } 205 newParam, err = toYdbParam(x.Name, x.Value) 206 newParams = append(newParams, newParam) 207 case *params.Parameters: 208 if len(args) > 1 { 209 return nil, xerrors.WithStackTrace(errMultipleQueryParameters) 210 } 211 parameters = *x 212 case *params.Parameter: 213 newParams = append(newParams, x) 214 default: 215 newParam, err = toYdbParam(fmt.Sprintf("$p%d", i), x) 216 newParams = append(newParams, newParam) 217 } 218 if err != nil { 219 return nil, xerrors.WithStackTrace(err) 220 } 221 parameters = append(parameters, newParams...) 222 } 223 sort.Slice(parameters, func(i, j int) bool { 224 return parameters[i].Name() < parameters[j].Name() 225 }) 226 227 return parameters, nil 228 } 229 230 func paramHandleNamedValue(arg driver.NamedValue, paramNumber, argsLen int) ([]*params.Parameter, error) { 231 if arg.Name == "" { 232 switch x := arg.Value.(type) { 233 case *params.Parameters: 234 if argsLen > 1 { 235 return nil, xerrors.WithStackTrace(errMultipleQueryParameters) 236 } 237 238 return *x, nil 239 case *params.Parameter: 240 return []*params.Parameter{x}, nil 241 default: 242 arg.Name = fmt.Sprintf("$p%d", paramNumber) 243 param, err := toYdbParam(arg.Name, arg.Value) 244 if err != nil { 245 return nil, xerrors.WithStackTrace(err) 246 } 247 248 return []*params.Parameter{param}, nil 249 } 250 } else { 251 param, err := toYdbParam(arg.Name, arg.Value) 252 if err != nil { 253 return nil, xerrors.WithStackTrace(err) 254 } 255 256 return []*params.Parameter{param}, nil 257 } 258 }