github.com/mattn/go@v0.0.0-20171011075504-07f7db3ea99f/src/database/sql/convert_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"database/sql/driver"
     9  	"fmt"
    10  	"reflect"
    11  	"runtime"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  )
    17  
    18  var someTime = time.Unix(123, 0)
    19  var answer int64 = 42
    20  
    21  type (
    22  	userDefined       float64
    23  	userDefinedSlice  []int
    24  	userDefinedString string
    25  )
    26  
    27  type conversionTest struct {
    28  	s, d interface{} // source and destination
    29  
    30  	// following are used if they're non-zero
    31  	wantint    int64
    32  	wantuint   uint64
    33  	wantstr    string
    34  	wantbytes  []byte
    35  	wantraw    RawBytes
    36  	wantf32    float32
    37  	wantf64    float64
    38  	wanttime   time.Time
    39  	wantbool   bool // used if d is of type *bool
    40  	wanterr    string
    41  	wantiface  interface{}
    42  	wantptr    *int64 // if non-nil, *d's pointed value must be equal to *wantptr
    43  	wantnil    bool   // if true, *d must be *int64(nil)
    44  	wantusrdef userDefined
    45  	wantusrstr userDefinedString
    46  }
    47  
    48  // Target variables for scanning into.
    49  var (
    50  	scanstr    string
    51  	scanbytes  []byte
    52  	scanraw    RawBytes
    53  	scanint    int
    54  	scanint8   int8
    55  	scanint16  int16
    56  	scanint32  int32
    57  	scanuint8  uint8
    58  	scanuint16 uint16
    59  	scanbool   bool
    60  	scanf32    float32
    61  	scanf64    float64
    62  	scantime   time.Time
    63  	scanptr    *int64
    64  	scaniface  interface{}
    65  )
    66  
// conversionTests is the table consumed by TestConversions. Each entry runs
// convertAssign(d, s). wanterr must equal the returned error text exactly
// ("" means no error is expected).
//
// Most destinations are the shared package-level scan* variables. A value
// written by an earlier entry therefore stays visible if a conversion does
// not write to d. Entry order matters: failures are reported by index.
var conversionTests = []conversionTest{
	// Exact conversions (destination pointer type matches source type)
	{s: "foo", d: &scanstr, wantstr: "foo"},
	{s: 123, d: &scanint, wantint: 123},
	{s: someTime, d: &scantime, wanttime: someTime},

	// To strings
	{s: "string", d: &scanstr, wantstr: "string"},
	{s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
	{s: 123, d: &scanstr, wantstr: "123"},
	{s: int8(123), d: &scanstr, wantstr: "123"},
	{s: int64(123), d: &scanstr, wantstr: "123"},
	{s: uint8(123), d: &scanstr, wantstr: "123"},
	{s: uint16(123), d: &scanstr, wantstr: "123"},
	{s: uint32(123), d: &scanstr, wantstr: "123"},
	{s: uint64(123), d: &scanstr, wantstr: "123"},
	{s: 1.5, d: &scanstr, wantstr: "1.5"},

	// From time.Time:
	// The expected text is RFC 3339 style, with fractional seconds only when
	// non-zero and the zone offset preserved.
	{s: time.Unix(1, 0).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01Z"},
	{s: time.Unix(1453874597, 0).In(time.FixedZone("here", -3600*8)), d: &scanstr, wantstr: "2016-01-26T22:03:17-08:00"},
	{s: time.Unix(1, 2).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01.000000002Z"},
	{s: time.Time{}, d: &scanstr, wantstr: "0001-01-01T00:00:00Z"},
	{s: time.Unix(1, 2).UTC(), d: &scanbytes, wantbytes: []byte("1970-01-01T00:00:01.000000002Z")},
	{s: time.Unix(1, 2).UTC(), d: &scaniface, wantiface: time.Unix(1, 2).UTC()},

	// To []byte
	// wantbytes: nil is skipped by TestConversions' non-nil guard, so the nil
	// entry only checks that no error is returned.
	{s: nil, d: &scanbytes, wantbytes: nil},
	{s: "string", d: &scanbytes, wantbytes: []byte("string")},
	{s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")},
	{s: 123, d: &scanbytes, wantbytes: []byte("123")},
	{s: int8(123), d: &scanbytes, wantbytes: []byte("123")},
	{s: int64(123), d: &scanbytes, wantbytes: []byte("123")},
	{s: uint8(123), d: &scanbytes, wantbytes: []byte("123")},
	{s: uint16(123), d: &scanbytes, wantbytes: []byte("123")},
	{s: uint32(123), d: &scanbytes, wantbytes: []byte("123")},
	{s: uint64(123), d: &scanbytes, wantbytes: []byte("123")},
	{s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")},

	// To RawBytes
	// As with []byte, wantraw: nil is not checked; the nil entry only checks
	// that no error is returned.
	{s: nil, d: &scanraw, wantraw: nil},
	{s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")},
	{s: "string", d: &scanraw, wantraw: RawBytes("string")},
	{s: 123, d: &scanraw, wantraw: RawBytes("123")},
	{s: int8(123), d: &scanraw, wantraw: RawBytes("123")},
	{s: int64(123), d: &scanraw, wantraw: RawBytes("123")},
	{s: uint8(123), d: &scanraw, wantraw: RawBytes("123")},
	{s: uint16(123), d: &scanraw, wantraw: RawBytes("123")},
	{s: uint32(123), d: &scanraw, wantraw: RawBytes("123")},
	{s: uint64(123), d: &scanraw, wantraw: RawBytes("123")},
	{s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")},
	// time.Time has been placed here to check that the RawBytes slice gets
	// correctly reset when calling time.Time.AppendFormat.
	{s: time.Unix(2, 5).UTC(), d: &scanraw, wantraw: RawBytes("1970-01-01T00:00:02.000000005Z")},

	// Strings to integers
	// The wanterr texts must match convertAssign's messages byte for byte.
	{s: "255", d: &scanuint8, wantuint: 255},
	{s: "256", d: &scanuint8, wanterr: "converting driver.Value type string (\"256\") to a uint8: value out of range"},
	{s: "256", d: &scanuint16, wantuint: 256},
	{s: "-1", d: &scanint, wantint: -1},
	{s: "foo", d: &scanint, wanterr: "converting driver.Value type string (\"foo\") to a int: invalid syntax"},

	// int64 to smaller integers
	{s: int64(5), d: &scanuint8, wantuint: 5},
	{s: int64(256), d: &scanuint8, wanterr: "converting driver.Value type int64 (\"256\") to a uint8: value out of range"},
	{s: int64(256), d: &scanuint16, wantuint: 256},
	{s: int64(65536), d: &scanuint16, wanterr: "converting driver.Value type int64 (\"65536\") to a uint16: value out of range"},

	// True bools
	{s: true, d: &scanbool, wantbool: true},
	{s: "True", d: &scanbool, wantbool: true},
	{s: "TRUE", d: &scanbool, wantbool: true},
	{s: "1", d: &scanbool, wantbool: true},
	{s: 1, d: &scanbool, wantbool: true},
	{s: int64(1), d: &scanbool, wantbool: true},
	{s: uint16(1), d: &scanbool, wantbool: true},

	// False bools
	// wantbool is compared even when false, because d is a *bool.
	{s: false, d: &scanbool, wantbool: false},
	{s: "false", d: &scanbool, wantbool: false},
	{s: "FALSE", d: &scanbool, wantbool: false},
	{s: "0", d: &scanbool, wantbool: false},
	{s: 0, d: &scanbool, wantbool: false},
	{s: int64(0), d: &scanbool, wantbool: false},
	{s: uint16(0), d: &scanbool, wantbool: false},

	// Not bools
	{s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`},
	{s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`},

	// Floats
	{s: float64(1.5), d: &scanf64, wantf64: float64(1.5)},
	{s: int64(1), d: &scanf64, wantf64: float64(1)},
	{s: float64(1.5), d: &scanf32, wantf32: float32(1.5)},
	{s: "1.5", d: &scanf32, wantf32: float32(1.5)},
	{s: "1.5", d: &scanf64, wantf64: float64(1.5)},

	// Pointers
	// A nil source must leave *scanptr nil. An int64 source must yield a
	// pointer to a value equal to answer (42).
	{s: interface{}(nil), d: &scanptr, wantnil: true},
	{s: int64(42), d: &scanptr, wantptr: &answer},

	// To interface{}
	// Compared with reflect.DeepEqual, so wantiface: nil (omitted below) is
	// still checked.
	{s: float64(1.5), d: &scaniface, wantiface: float64(1.5)},
	{s: int64(1), d: &scaniface, wantiface: int64(1)},
	{s: "str", d: &scaniface, wantiface: "str"},
	{s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")},
	{s: true, d: &scaniface, wantiface: true},
	{s: nil, d: &scaniface},
	{s: []byte(nil), d: &scaniface, wantiface: []byte(nil)},

	// To a user-defined type
	{s: 1.5, d: new(userDefined), wantusrdef: 1.5},
	{s: int64(123), d: new(userDefined), wantusrdef: 123},
	{s: "1.5", d: new(userDefined), wantusrdef: 1.5},
	{s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *sql.userDefinedSlice`},
	{s: "str", d: new(userDefinedString), wantusrstr: "str"},

	// Other errors
	{s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`},
}
   187  
   188  func intPtrValue(intptr interface{}) interface{} {
   189  	return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int()
   190  }
   191  
   192  func intValue(intptr interface{}) int64 {
   193  	return reflect.Indirect(reflect.ValueOf(intptr)).Int()
   194  }
   195  
   196  func uintValue(intptr interface{}) uint64 {
   197  	return reflect.Indirect(reflect.ValueOf(intptr)).Uint()
   198  }
   199  
   200  func float64Value(ptr interface{}) float64 {
   201  	return *(ptr.(*float64))
   202  }
   203  
   204  func float32Value(ptr interface{}) float32 {
   205  	return *(ptr.(*float32))
   206  }
   207  
   208  func timeValue(ptr interface{}) time.Time {
   209  	return *(ptr.(*time.Time))
   210  }
   211  
   212  func TestConversions(t *testing.T) {
   213  	for n, ct := range conversionTests {
   214  		err := convertAssign(ct.d, ct.s)
   215  		errstr := ""
   216  		if err != nil {
   217  			errstr = err.Error()
   218  		}
   219  		errf := func(format string, args ...interface{}) {
   220  			base := fmt.Sprintf("convertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d)
   221  			t.Errorf(base+format, args...)
   222  		}
   223  		if errstr != ct.wanterr {
   224  			errf("got error %q, want error %q", errstr, ct.wanterr)
   225  		}
   226  		if ct.wantstr != "" && ct.wantstr != scanstr {
   227  			errf("want string %q, got %q", ct.wantstr, scanstr)
   228  		}
   229  		if ct.wantbytes != nil && string(ct.wantbytes) != string(scanbytes) {
   230  			errf("want byte %q, got %q", ct.wantbytes, scanbytes)
   231  		}
   232  		if ct.wantraw != nil && string(ct.wantraw) != string(scanraw) {
   233  			errf("want RawBytes %q, got %q", ct.wantraw, scanraw)
   234  		}
   235  		if ct.wantint != 0 && ct.wantint != intValue(ct.d) {
   236  			errf("want int %d, got %d", ct.wantint, intValue(ct.d))
   237  		}
   238  		if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) {
   239  			errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d))
   240  		}
   241  		if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) {
   242  			errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d))
   243  		}
   244  		if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) {
   245  			errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d))
   246  		}
   247  		if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
   248  			errf("want bool %v, got %v", ct.wantbool, *bp)
   249  		}
   250  		if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
   251  			errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
   252  		}
   253  		if ct.wantnil && *ct.d.(**int64) != nil {
   254  			errf("want nil, got %v", intPtrValue(ct.d))
   255  		}
   256  		if ct.wantptr != nil {
   257  			if *ct.d.(**int64) == nil {
   258  				errf("want pointer to %v, got nil", *ct.wantptr)
   259  			} else if *ct.wantptr != intPtrValue(ct.d) {
   260  				errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d))
   261  			}
   262  		}
   263  		if ifptr, ok := ct.d.(*interface{}); ok {
   264  			if !reflect.DeepEqual(ct.wantiface, scaniface) {
   265  				errf("want interface %#v, got %#v", ct.wantiface, scaniface)
   266  				continue
   267  			}
   268  			if srcBytes, ok := ct.s.([]byte); ok {
   269  				dstBytes := (*ifptr).([]byte)
   270  				if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] {
   271  					errf("copy into interface{} didn't copy []byte data")
   272  				}
   273  			}
   274  		}
   275  		if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) {
   276  			errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined))
   277  		}
   278  		if len(ct.wantusrstr) != 0 && ct.wantusrstr != *ct.d.(*userDefinedString) {
   279  			errf("want userDefined %q, got %q", ct.wantusrstr, *ct.d.(*userDefinedString))
   280  		}
   281  	}
   282  }
   283  
   284  func TestNullString(t *testing.T) {
   285  	var ns NullString
   286  	convertAssign(&ns, []byte("foo"))
   287  	if !ns.Valid {
   288  		t.Errorf("expecting not null")
   289  	}
   290  	if ns.String != "foo" {
   291  		t.Errorf("expecting foo; got %q", ns.String)
   292  	}
   293  	convertAssign(&ns, nil)
   294  	if ns.Valid {
   295  		t.Errorf("expecting null on nil")
   296  	}
   297  	if ns.String != "" {
   298  		t.Errorf("expecting blank on nil; got %q", ns.String)
   299  	}
   300  }
   301  
   302  type valueConverterTest struct {
   303  	c       driver.ValueConverter
   304  	in, out interface{}
   305  	err     string
   306  }
   307  
   308  var valueConverterTests = []valueConverterTest{
   309  	{driver.DefaultParameterConverter, NullString{"hi", true}, "hi", ""},
   310  	{driver.DefaultParameterConverter, NullString{"", false}, nil, ""},
   311  }
   312  
   313  func TestValueConverters(t *testing.T) {
   314  	for i, tt := range valueConverterTests {
   315  		out, err := tt.c.ConvertValue(tt.in)
   316  		goterr := ""
   317  		if err != nil {
   318  			goterr = err.Error()
   319  		}
   320  		if goterr != tt.err {
   321  			t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q",
   322  				i, tt.c, tt.in, tt.in, goterr, tt.err)
   323  		}
   324  		if tt.err != "" {
   325  			continue
   326  		}
   327  		if !reflect.DeepEqual(out, tt.out) {
   328  			t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)",
   329  				i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out)
   330  		}
   331  	}
   332  }
   333  
// Tests that assigning to RawBytes doesn't allocate (and also works).
//
// Every case is assigned into the same RawBytes variable and its text is
// checked. On amd64 with the gc compiler, the allocation count is also
// enforced: zero for the table below and at most one for a string source.
func TestRawBytesAllocs(t *testing.T) {
	var tests = []struct {
		name string
		in   interface{}
		want string
	}{
		{"uint64", uint64(12345678), "12345678"},
		{"uint32", uint32(1234), "1234"},
		{"uint16", uint16(12), "12"},
		{"uint8", uint8(1), "1"},
		{"uint", uint(123), "123"},
		{"int", int(123), "123"},
		{"int8", int8(1), "1"},
		{"int16", int16(12), "12"},
		{"int32", int32(1234), "1234"},
		{"int64", int64(12345678), "12345678"},
		{"float32", float32(1.5), "1.5"},
		{"float64", float64(64), "64"},
		{"bool", false, "false"},
		{"time", time.Unix(2, 5).UTC(), "1970-01-01T00:00:02.000000005Z"},
	}

	// buf is shared by every convertAssign call below; test always passes &buf.
	// NOTE(review): buf starts at len 10, shorter than the "time" want.
	// Any growth is presumably absorbed by AllocsPerRun's warm-up call before
	// measurement starts. Confirm if the thresholds below ever flake.
	buf := make(RawBytes, 10)
	// test assigns in to buf and fails the test unless buf holds exactly want.
	test := func(name string, in interface{}, want string) {
		if err := convertAssign(&buf, in); err != nil {
			t.Fatalf("%s: convertAssign = %v", name, err)
		}
		// Compare byte by byte instead of converting buf to a string,
		// presumably so the check itself cannot add to the allocation count.
		match := len(buf) == len(want)
		if match {
			for i, b := range buf {
				if want[i] != b {
					match = false
					break
				}
			}
		}
		if !match {
			t.Fatalf("%s: got %q (len %d); want %q (len %d)", name, buf, len(buf), want, len(want))
		}
	}

	// n is the average number of allocations per pass over the whole table.
	n := testing.AllocsPerRun(100, func() {
		for _, tt := range tests {
			test(tt.name, tt.in, tt.want)
		}
	})

	// The numbers below are only valid for 64-bit interface word sizes,
	// and gc. With 32-bit words there are more convT2E allocs, and
	// with gccgo, only pointers currently go in interface data.
	// So only care on amd64 gc for now.
	measureAllocs := runtime.GOARCH == "amd64" && runtime.Compiler == "gc"

	// The 0.5 threshold tolerates averaging noise while still requiring zero allocations.
	if n > 0.5 && measureAllocs {
		t.Fatalf("allocs = %v; want 0", n)
	}

	// This one involves a convT2E allocation, string -> interface{}
	n = testing.AllocsPerRun(100, func() {
		test("string", "foo", "foo")
	})
	if n > 1.5 && measureAllocs {
		t.Fatalf("allocs = %v; want max 1", n)
	}
}
   400  
   401  // https://github.com/golang/go/issues/13905
   402  func TestUserDefinedBytes(t *testing.T) {
   403  	type userDefinedBytes []byte
   404  	var u userDefinedBytes
   405  	v := []byte("foo")
   406  
   407  	convertAssign(&u, v)
   408  	if &u[0] == &v[0] {
   409  		t.Fatal("userDefinedBytes got potentially dirty driver memory")
   410  	}
   411  }
   412  
   413  type Valuer_V string
   414  
   415  func (v Valuer_V) Value() (driver.Value, error) {
   416  	return strings.ToUpper(string(v)), nil
   417  }
   418  
   419  type Valuer_P string
   420  
   421  func (p *Valuer_P) Value() (driver.Value, error) {
   422  	if p == nil {
   423  		return "nil-to-str", nil
   424  	}
   425  	return strings.ToUpper(string(*p)), nil
   426  }
   427  
   428  func TestDriverArgs(t *testing.T) {
   429  	var nilValuerVPtr *Valuer_V
   430  	var nilValuerPPtr *Valuer_P
   431  	var nilStrPtr *string
   432  	tests := []struct {
   433  		args []interface{}
   434  		want []driver.NamedValue
   435  	}{
   436  		0: {
   437  			args: []interface{}{Valuer_V("foo")},
   438  			want: []driver.NamedValue{
   439  				driver.NamedValue{
   440  					Ordinal: 1,
   441  					Value:   "FOO",
   442  				},
   443  			},
   444  		},
   445  		1: {
   446  			args: []interface{}{nilValuerVPtr},
   447  			want: []driver.NamedValue{
   448  				driver.NamedValue{
   449  					Ordinal: 1,
   450  					Value:   nil,
   451  				},
   452  			},
   453  		},
   454  		2: {
   455  			args: []interface{}{nilValuerPPtr},
   456  			want: []driver.NamedValue{
   457  				driver.NamedValue{
   458  					Ordinal: 1,
   459  					Value:   "nil-to-str",
   460  				},
   461  			},
   462  		},
   463  		3: {
   464  			args: []interface{}{"plain-str"},
   465  			want: []driver.NamedValue{
   466  				driver.NamedValue{
   467  					Ordinal: 1,
   468  					Value:   "plain-str",
   469  				},
   470  			},
   471  		},
   472  		4: {
   473  			args: []interface{}{nilStrPtr},
   474  			want: []driver.NamedValue{
   475  				driver.NamedValue{
   476  					Ordinal: 1,
   477  					Value:   nil,
   478  				},
   479  			},
   480  		},
   481  	}
   482  	for i, tt := range tests {
   483  		ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
   484  		got, err := driverArgs(nil, ds, tt.args)
   485  		if err != nil {
   486  			t.Errorf("test[%d]: %v", i, err)
   487  			continue
   488  		}
   489  		if !reflect.DeepEqual(got, tt.want) {
   490  			t.Errorf("test[%d]: got %v, want %v", i, got, tt.want)
   491  		}
   492  	}
   493  }