github.com/dolthub/go-mysql-server@v0.18.0/driver/value.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package driver
    16  
    17  import (
    18  	"database/sql/driver"
    19  	"errors"
    20  	"fmt"
    21  	"strconv"
    22  	"time"
    23  
    24  	"github.com/dolthub/vitess/go/sqltypes"
    25  	querypb "github.com/dolthub/vitess/go/vt/proto/query"
    26  
    27  	"github.com/dolthub/go-mysql-server/sql"
    28  	"github.com/dolthub/go-mysql-server/sql/expression"
    29  	"github.com/dolthub/go-mysql-server/sql/types"
    30  )
    31  
    32  // ErrUnsupportedType is returned when a query argument of an unsupported type is passed to a statement. valueToExpr wraps it with %w, so match it with errors.Is rather than ==.
    33  var ErrUnsupportedType = errors.New("unsupported type")
    34  
    35  func valueToExpr(v driver.Value) (sql.Expression, error) {
    36  	if v == nil {
    37  		return expression.NewLiteral(nil, types.Null), nil
    38  	}
    39  
    40  	var typ sql.Type
    41  	var err error
    42  	switch v := v.(type) {
    43  	case int64:
    44  		typ = types.Int64
    45  	case float64:
    46  		typ = types.Float64
    47  	case bool:
    48  		typ = types.Boolean
    49  	case []byte:
    50  		typ, err = types.CreateBinary(sqltypes.Blob, int64(len(v)))
    51  	case string:
    52  		typ, err = types.CreateStringWithDefaults(sqltypes.Text, int64(len(v)))
    53  	case time.Time:
    54  		typ = types.Datetime
    55  	default:
    56  		return nil, fmt.Errorf("%w: %T", ErrUnsupportedType, v)
    57  	}
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	c, _, err := typ.Convert(v)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	return expression.NewLiteral(c, typ), nil
    67  }
    68  
    69  func valuesToBindings(v []driver.Value) (map[string]*querypb.BindVariable, error) {
    70  	if len(v) == 0 {
    71  		return nil, nil
    72  	}
    73  
    74  	b := map[string]*querypb.BindVariable{}
    75  
    76  	var err error
    77  	for i, v := range v {
    78  		b[strconv.FormatInt(int64(i), 10)], err = sqltypes.BuildBindVariable(v)
    79  		if err != nil {
    80  			return nil, err
    81  		}
    82  	}
    83  
    84  	return b, nil
    85  }
    86  
    87  func namedValuesToBindings(v []driver.NamedValue) (map[string]*querypb.BindVariable, error) {
    88  	if len(v) == 0 {
    89  		return nil, nil
    90  	}
    91  
    92  	b := map[string]*querypb.BindVariable{}
    93  
    94  	var err error
    95  	for _, v := range v {
    96  		name := v.Name
    97  		if name == "" {
    98  			name = "v" + strconv.FormatInt(int64(v.Ordinal), 10)
    99  		}
   100  
   101  		val := v.Value
   102  		if t, ok := val.(time.Time); ok {
   103  			val = t.Format(time.RFC3339Nano)
   104  		}
   105  		b[name], err = sqltypes.BuildBindVariable(val)
   106  		if err != nil {
   107  			return nil, err
   108  		}
   109  	}
   110  
   111  	return b, nil
   112  }