github.com/dolthub/go-mysql-server@v0.18.0/driver/value.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package driver 16 17 import ( 18 "database/sql/driver" 19 "errors" 20 "fmt" 21 "strconv" 22 "time" 23 24 "github.com/dolthub/vitess/go/sqltypes" 25 querypb "github.com/dolthub/vitess/go/vt/proto/query" 26 27 "github.com/dolthub/go-mysql-server/sql" 28 "github.com/dolthub/go-mysql-server/sql/expression" 29 "github.com/dolthub/go-mysql-server/sql/types" 30 ) 31 32 // ErrUnsupportedType is returned when a query argument of an unsupported type is passed to a statement 33 var ErrUnsupportedType = errors.New("unsupported type") 34 35 func valueToExpr(v driver.Value) (sql.Expression, error) { 36 if v == nil { 37 return expression.NewLiteral(nil, types.Null), nil 38 } 39 40 var typ sql.Type 41 var err error 42 switch v := v.(type) { 43 case int64: 44 typ = types.Int64 45 case float64: 46 typ = types.Float64 47 case bool: 48 typ = types.Boolean 49 case []byte: 50 typ, err = types.CreateBinary(sqltypes.Blob, int64(len(v))) 51 case string: 52 typ, err = types.CreateStringWithDefaults(sqltypes.Text, int64(len(v))) 53 case time.Time: 54 typ = types.Datetime 55 default: 56 return nil, fmt.Errorf("%w: %T", ErrUnsupportedType, v) 57 } 58 if err != nil { 59 return nil, err 60 } 61 62 c, _, err := typ.Convert(v) 63 if err != nil { 64 return nil, err 65 } 66 return expression.NewLiteral(c, typ), nil 67 } 68 69 func valuesToBindings(v []driver.Value) (map[string]*querypb.BindVariable, error) { 70 if len(v) == 0 { 71 return nil, nil 72 } 73 74 b := map[string]*querypb.BindVariable{} 75 76 var err error 77 for i, v := range v { 78 b[strconv.FormatInt(int64(i), 10)], err = sqltypes.BuildBindVariable(v) 79 if err != nil { 80 return nil, err 81 } 82 } 83 84 return b, nil 85 } 86 87 func namedValuesToBindings(v []driver.NamedValue) (map[string]*querypb.BindVariable, error) { 88 if len(v) == 0 { 89 return nil, nil 90 } 91 92 b := map[string]*querypb.BindVariable{} 93 94 var err error 95 for _, v := range v { 96 name := v.Name 97 if name == "" { 98 name = "v" + strconv.FormatInt(int64(v.Ordinal), 10) 99 } 100 101 val := v.Value 102 if t, ok := val.(time.Time); ok { 103 val = t.Format(time.RFC3339Nano) 104 } 105 b[name], err = sqltypes.BuildBindVariable(val) 106 if err != nil { 107 return nil, err 108 } 109 } 110 111 return b, nil 112 }