github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/bind/params.go (about) 1 package bind 2 3 import ( 4 "database/sql" 5 "database/sql/driver" 6 "errors" 7 "fmt" 8 "net/url" 9 "sort" 10 "time" 11 12 "github.com/ydb-platform/ydb-go-sdk/v3/internal/params" 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/value" 14 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 15 "github.com/ydb-platform/ydb-go-sdk/v3/table/types" 16 ) 17 18 var ( 19 errUnsupportedType = errors.New("unsupported type") 20 errUnnamedParam = errors.New("unnamed param") 21 errMultipleQueryParameters = errors.New("only one query arg *table.QueryParameters allowed") 22 ) 23 24 //nolint:gocyclo 25 func toValue(v interface{}) (_ types.Value, err error) { 26 if valuer, ok := v.(driver.Valuer); ok { 27 v, err = valuer.Value() 28 if err != nil { 29 return nil, fmt.Errorf("ydb: driver.Valuer error: %w", err) 30 } 31 } 32 33 switch x := v.(type) { 34 case nil: 35 return types.VoidValue(), nil 36 case value.Value: 37 return x, nil 38 case bool: 39 return types.BoolValue(x), nil 40 case *bool: 41 return types.NullableBoolValue(x), nil 42 case int: 43 return types.Int32Value(int32(x)), nil 44 case *int: 45 if x == nil { 46 return types.NullValue(types.TypeInt32), nil 47 } 48 xx := int32(*x) 49 50 return types.NullableInt32Value(&xx), nil 51 case uint: 52 return types.Uint32Value(uint32(x)), nil 53 case *uint: 54 if x == nil { 55 return types.NullValue(types.TypeUint32), nil 56 } 57 xx := uint32(*x) 58 59 return types.NullableUint32Value(&xx), nil 60 case int8: 61 return types.Int8Value(x), nil 62 case *int8: 63 return types.NullableInt8Value(x), nil 64 case uint8: 65 return types.Uint8Value(x), nil 66 case *uint8: 67 return types.NullableUint8Value(x), nil 68 case int16: 69 return types.Int16Value(x), nil 70 case *int16: 71 return types.NullableInt16Value(x), nil 72 case uint16: 73 return types.Uint16Value(x), nil 74 case *uint16: 75 return types.NullableUint16Value(x), nil 76 case int32: 77 return types.Int32Value(x), nil 78 case *int32: 79 return types.NullableInt32Value(x), nil 80 case uint32: 81 return types.Uint32Value(x), nil 82 case *uint32: 83 return types.NullableUint32Value(x), nil 84 case int64: 85 return types.Int64Value(x), nil 86 case *int64: 87 return types.NullableInt64Value(x), nil 88 case uint64: 89 return types.Uint64Value(x), nil 90 case *uint64: 91 return types.NullableUint64Value(x), nil 92 case float32: 93 return types.FloatValue(x), nil 94 case *float32: 95 return types.NullableFloatValue(x), nil 96 case float64: 97 return types.DoubleValue(x), nil 98 case *float64: 99 return types.NullableDoubleValue(x), nil 100 case []byte: 101 return types.BytesValue(x), nil 102 case *[]byte: 103 return types.NullableBytesValue(x), nil 104 case string: 105 return types.TextValue(x), nil 106 case *string: 107 return types.NullableTextValue(x), nil 108 case []string: 109 items := make([]types.Value, len(x)) 110 for i := range x { 111 items[i] = types.TextValue(x[i]) 112 } 113 114 return types.ListValue(items...), nil 115 case [16]byte: 116 return types.UUIDValue(x), nil 117 case *[16]byte: 118 return types.NullableUUIDValue(x), nil 119 case time.Time: 120 return types.TimestampValueFromTime(x), nil 121 case *time.Time: 122 return types.NullableTimestampValueFromTime(x), nil 123 case time.Duration: 124 return types.IntervalValueFromDuration(x), nil 125 case *time.Duration: 126 return types.NullableIntervalValueFromDuration(x), nil 127 default: 128 return nil, xerrors.WithStackTrace( 129 fmt.Errorf("%T: %w. Create issue for support new type %s", 130 x, errUnsupportedType, supportNewTypeLink(x), 131 ), 132 ) 133 } 134 } 135 136 func supportNewTypeLink(x interface{}) string { 137 v := url.Values{} 138 v.Add("labels", "enhancement,database/sql") 139 v.Add("template", "02_FEATURE_REQUEST.md") 140 v.Add("title", fmt.Sprintf("feat: Support new type `%T` in `database/sql` query args", x)) 141 142 return "https://github.com/ydb-platform/ydb-go-sdk/issues/new?" + v.Encode() 143 } 144 145 func toYdbParam(name string, value interface{}) (*params.Parameter, error) { 146 if na, ok := value.(driver.NamedValue); ok { 147 n, v := na.Name, na.Value 148 if n != "" { 149 name = n 150 } 151 value = v 152 } 153 if na, ok := value.(sql.NamedArg); ok { 154 n, v := na.Name, na.Value 155 if n != "" { 156 name = n 157 } 158 value = v 159 } 160 if v, ok := value.(*params.Parameter); ok { 161 return v, nil 162 } 163 v, err := toValue(value) 164 if err != nil { 165 return nil, xerrors.WithStackTrace(err) 166 } 167 if name == "" { 168 return nil, xerrors.WithStackTrace(errUnnamedParam) 169 } 170 if name[0] != '$' { 171 name = "$" + name 172 } 173 174 return params.Named(name, v), nil 175 } 176 177 func Params(args ...interface{}) (parameters []*params.Parameter, _ error) { 178 parameters = make([]*params.Parameter, 0, len(args)) 179 for i, arg := range args { 180 switch x := arg.(type) { 181 case driver.NamedValue: 182 if x.Name == "" { 183 switch xx := x.Value.(type) { 184 case *params.Parameters: 185 if len(args) > 1 { 186 return nil, xerrors.WithStackTrace(errMultipleQueryParameters) 187 } 188 parameters = *xx 189 case *params.Parameter: 190 parameters = append(parameters, xx) 191 default: 192 x.Name = fmt.Sprintf("$p%d", i) 193 param, err := toYdbParam(x.Name, x.Value) 194 if err != nil { 195 return nil, xerrors.WithStackTrace(err) 196 } 197 parameters = append(parameters, param) 198 } 199 } else { 200 param, err := toYdbParam(x.Name, x.Value) 201 if err != nil { 202 return nil, xerrors.WithStackTrace(err) 203 } 204 parameters = append(parameters, param) 205 } 206 case sql.NamedArg: 207 if x.Name == "" { 208 return nil, xerrors.WithStackTrace(errUnnamedParam) 209 } 210 param, err := toYdbParam(x.Name, x.Value) 211 if err != nil { 212 return nil, xerrors.WithStackTrace(err) 213 } 214 parameters = append(parameters, param) 215 case *params.Parameters: 216 if len(args) > 1 { 217 return nil, xerrors.WithStackTrace(errMultipleQueryParameters) 218 } 219 parameters = *x 220 case *params.Parameter: 221 parameters = append(parameters, x) 222 default: 223 param, err := toYdbParam(fmt.Sprintf("$p%d", i), x) 224 if err != nil { 225 return nil, xerrors.WithStackTrace(err) 226 } 227 parameters = append(parameters, param) 228 } 229 } 230 sort.Slice(parameters, func(i, j int) bool { 231 return parameters[i].Name() < parameters[j].Name() 232 }) 233 234 return parameters, nil 235 }