github.com/ltltlt/go-source-code@v0.0.0-20190830023027-95be009773aa/database/sql/convert_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"database/sql/driver"
     9  	"fmt"
    10  	"reflect"
    11  	"runtime"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  )
    17  
    18  var someTime = time.Unix(123, 0)
    19  var answer int64 = 42
    20  
    21  type (
    22  	userDefined       float64
    23  	userDefinedSlice  []int
    24  	userDefinedString string
    25  )
    26  
    27  type conversionTest struct {
    28  	s, d interface{} // source and destination
    29  
    30  	// following are used if they're non-zero
    31  	wantint    int64
    32  	wantuint   uint64
    33  	wantstr    string
    34  	wantbytes  []byte
    35  	wantraw    RawBytes
    36  	wantf32    float32
    37  	wantf64    float64
    38  	wanttime   time.Time
    39  	wantbool   bool // used if d is of type *bool
    40  	wanterr    string
    41  	wantiface  interface{}
    42  	wantptr    *int64 // if non-nil, *d's pointed value must be equal to *wantptr
    43  	wantnil    bool   // if true, *d must be *int64(nil)
    44  	wantusrdef userDefined
    45  	wantusrstr userDefinedString
    46  }
    47  
    48  // Target variables for scanning into.
    49  var (
    50  	scanstr    string
    51  	scanbytes  []byte
    52  	scanraw    RawBytes
    53  	scanint    int
    54  	scanint8   int8
    55  	scanint16  int16
    56  	scanint32  int32
    57  	scanuint8  uint8
    58  	scanuint16 uint16
    59  	scanbool   bool
    60  	scanf32    float32
    61  	scanf64    float64
    62  	scantime   time.Time
    63  	scanptr    *int64
    64  	scaniface  interface{}
    65  )
    66  
// conversionTests returns the table of convertAssign cases run by
// TestConversions. For each entry, s is assigned into the destination pointer
// d. wanterr is always compared; the other want* fields describe the expected
// destination value.
//
// Many entries share the same package-level destination (scanstr, scanbytes,
// scanraw, ...), so each case overwrites whatever an earlier case left there.
func conversionTests() []conversionTest {
	// Return a fresh instance to test so "go test -count 2" works correctly.
	return []conversionTest{
		// Exact conversions (destination pointer type matches source type)
		{s: "foo", d: &scanstr, wantstr: "foo"},
		{s: 123, d: &scanint, wantint: 123},
		{s: someTime, d: &scantime, wanttime: someTime},

		// To strings
		{s: "string", d: &scanstr, wantstr: "string"},
		{s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
		{s: 123, d: &scanstr, wantstr: "123"},
		{s: int8(123), d: &scanstr, wantstr: "123"},
		{s: int64(123), d: &scanstr, wantstr: "123"},
		{s: uint8(123), d: &scanstr, wantstr: "123"},
		{s: uint16(123), d: &scanstr, wantstr: "123"},
		{s: uint32(123), d: &scanstr, wantstr: "123"},
		{s: uint64(123), d: &scanstr, wantstr: "123"},
		{s: 1.5, d: &scanstr, wantstr: "1.5"},

		// From time.Time:
		{s: time.Unix(1, 0).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01Z"},
		{s: time.Unix(1453874597, 0).In(time.FixedZone("here", -3600*8)), d: &scanstr, wantstr: "2016-01-26T22:03:17-08:00"},
		{s: time.Unix(1, 2).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01.000000002Z"},
		{s: time.Time{}, d: &scanstr, wantstr: "0001-01-01T00:00:00Z"},
		{s: time.Unix(1, 2).UTC(), d: &scanbytes, wantbytes: []byte("1970-01-01T00:00:01.000000002Z")},
		{s: time.Unix(1, 2).UTC(), d: &scaniface, wantiface: time.Unix(1, 2).UTC()},

		// To []byte
		// NOTE(review): a nil wantbytes disables the byte comparison in
		// TestConversions, so this case only asserts that no error occurs.
		{s: nil, d: &scanbytes, wantbytes: nil},
		{s: "string", d: &scanbytes, wantbytes: []byte("string")},
		{s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")},
		{s: 123, d: &scanbytes, wantbytes: []byte("123")},
		{s: int8(123), d: &scanbytes, wantbytes: []byte("123")},
		{s: int64(123), d: &scanbytes, wantbytes: []byte("123")},
		{s: uint8(123), d: &scanbytes, wantbytes: []byte("123")},
		{s: uint16(123), d: &scanbytes, wantbytes: []byte("123")},
		{s: uint32(123), d: &scanbytes, wantbytes: []byte("123")},
		{s: uint64(123), d: &scanbytes, wantbytes: []byte("123")},
		{s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")},

		// To RawBytes
		// NOTE(review): a nil wantraw likewise skips the value comparison,
		// so this case only asserts that no error occurs.
		{s: nil, d: &scanraw, wantraw: nil},
		{s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")},
		{s: "string", d: &scanraw, wantraw: RawBytes("string")},
		{s: 123, d: &scanraw, wantraw: RawBytes("123")},
		{s: int8(123), d: &scanraw, wantraw: RawBytes("123")},
		{s: int64(123), d: &scanraw, wantraw: RawBytes("123")},
		{s: uint8(123), d: &scanraw, wantraw: RawBytes("123")},
		{s: uint16(123), d: &scanraw, wantraw: RawBytes("123")},
		{s: uint32(123), d: &scanraw, wantraw: RawBytes("123")},
		{s: uint64(123), d: &scanraw, wantraw: RawBytes("123")},
		{s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")},
		// time.Time has been placed here to check that the RawBytes slice gets
		// correctly reset when calling time.Time.AppendFormat.
		{s: time.Unix(2, 5).UTC(), d: &scanraw, wantraw: RawBytes("1970-01-01T00:00:02.000000005Z")},

		// Strings to integers
		{s: "255", d: &scanuint8, wantuint: 255},
		{s: "256", d: &scanuint8, wanterr: "converting driver.Value type string (\"256\") to a uint8: value out of range"},
		{s: "256", d: &scanuint16, wantuint: 256},
		{s: "-1", d: &scanint, wantint: -1},
		{s: "foo", d: &scanint, wanterr: "converting driver.Value type string (\"foo\") to a int: invalid syntax"},

		// int64 to smaller integers
		{s: int64(5), d: &scanuint8, wantuint: 5},
		{s: int64(256), d: &scanuint8, wanterr: "converting driver.Value type int64 (\"256\") to a uint8: value out of range"},
		{s: int64(256), d: &scanuint16, wantuint: 256},
		{s: int64(65536), d: &scanuint16, wanterr: "converting driver.Value type int64 (\"65536\") to a uint16: value out of range"},

		// True bools
		{s: true, d: &scanbool, wantbool: true},
		{s: "True", d: &scanbool, wantbool: true},
		{s: "TRUE", d: &scanbool, wantbool: true},
		{s: "1", d: &scanbool, wantbool: true},
		{s: 1, d: &scanbool, wantbool: true},
		{s: int64(1), d: &scanbool, wantbool: true},
		{s: uint16(1), d: &scanbool, wantbool: true},

		// False bools
		// These rely on TestConversions checking *bool destinations even
		// when wantbool is false.
		{s: false, d: &scanbool, wantbool: false},
		{s: "false", d: &scanbool, wantbool: false},
		{s: "FALSE", d: &scanbool, wantbool: false},
		{s: "0", d: &scanbool, wantbool: false},
		{s: 0, d: &scanbool, wantbool: false},
		{s: int64(0), d: &scanbool, wantbool: false},
		{s: uint16(0), d: &scanbool, wantbool: false},

		// Not bools
		{s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`},
		{s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`},

		// Floats
		{s: float64(1.5), d: &scanf64, wantf64: float64(1.5)},
		{s: int64(1), d: &scanf64, wantf64: float64(1)},
		{s: float64(1.5), d: &scanf32, wantf32: float32(1.5)},
		{s: "1.5", d: &scanf32, wantf32: float32(1.5)},
		{s: "1.5", d: &scanf64, wantf64: float64(1.5)},

		// Pointers
		{s: interface{}(nil), d: &scanptr, wantnil: true},
		{s: int64(42), d: &scanptr, wantptr: &answer},

		// To interface{}
		{s: float64(1.5), d: &scaniface, wantiface: float64(1.5)},
		{s: int64(1), d: &scaniface, wantiface: int64(1)},
		{s: "str", d: &scaniface, wantiface: "str"},
		{s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")},
		{s: true, d: &scaniface, wantiface: true},
		// wantiface is compared for every *interface{} destination, so this
		// asserts that scaniface ends up nil.
		{s: nil, d: &scaniface},
		{s: []byte(nil), d: &scaniface, wantiface: []byte(nil)},

		// To a user-defined type
		{s: 1.5, d: new(userDefined), wantusrdef: 1.5},
		{s: int64(123), d: new(userDefined), wantusrdef: 123},
		{s: "1.5", d: new(userDefined), wantusrdef: 1.5},
		{s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *sql.userDefinedSlice`},
		{s: "str", d: new(userDefinedString), wantusrstr: "str"},

		// Other errors
		{s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`},
	}
}
   190  
   191  func intPtrValue(intptr interface{}) interface{} {
   192  	return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int()
   193  }
   194  
   195  func intValue(intptr interface{}) int64 {
   196  	return reflect.Indirect(reflect.ValueOf(intptr)).Int()
   197  }
   198  
   199  func uintValue(intptr interface{}) uint64 {
   200  	return reflect.Indirect(reflect.ValueOf(intptr)).Uint()
   201  }
   202  
   203  func float64Value(ptr interface{}) float64 {
   204  	return *(ptr.(*float64))
   205  }
   206  
   207  func float32Value(ptr interface{}) float32 {
   208  	return *(ptr.(*float32))
   209  }
   210  
   211  func timeValue(ptr interface{}) time.Time {
   212  	return *(ptr.(*time.Time))
   213  }
   214  
   215  func TestConversions(t *testing.T) {
   216  	for n, ct := range conversionTests() {
   217  		err := convertAssign(ct.d, ct.s)
   218  		errstr := ""
   219  		if err != nil {
   220  			errstr = err.Error()
   221  		}
   222  		errf := func(format string, args ...interface{}) {
   223  			base := fmt.Sprintf("convertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d)
   224  			t.Errorf(base+format, args...)
   225  		}
   226  		if errstr != ct.wanterr {
   227  			errf("got error %q, want error %q", errstr, ct.wanterr)
   228  		}
   229  		if ct.wantstr != "" && ct.wantstr != scanstr {
   230  			errf("want string %q, got %q", ct.wantstr, scanstr)
   231  		}
   232  		if ct.wantbytes != nil && string(ct.wantbytes) != string(scanbytes) {
   233  			errf("want byte %q, got %q", ct.wantbytes, scanbytes)
   234  		}
   235  		if ct.wantraw != nil && string(ct.wantraw) != string(scanraw) {
   236  			errf("want RawBytes %q, got %q", ct.wantraw, scanraw)
   237  		}
   238  		if ct.wantint != 0 && ct.wantint != intValue(ct.d) {
   239  			errf("want int %d, got %d", ct.wantint, intValue(ct.d))
   240  		}
   241  		if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) {
   242  			errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d))
   243  		}
   244  		if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) {
   245  			errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d))
   246  		}
   247  		if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) {
   248  			errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d))
   249  		}
   250  		if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
   251  			errf("want bool %v, got %v", ct.wantbool, *bp)
   252  		}
   253  		if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
   254  			errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
   255  		}
   256  		if ct.wantnil && *ct.d.(**int64) != nil {
   257  			errf("want nil, got %v", intPtrValue(ct.d))
   258  		}
   259  		if ct.wantptr != nil {
   260  			if *ct.d.(**int64) == nil {
   261  				errf("want pointer to %v, got nil", *ct.wantptr)
   262  			} else if *ct.wantptr != intPtrValue(ct.d) {
   263  				errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d))
   264  			}
   265  		}
   266  		if ifptr, ok := ct.d.(*interface{}); ok {
   267  			if !reflect.DeepEqual(ct.wantiface, scaniface) {
   268  				errf("want interface %#v, got %#v", ct.wantiface, scaniface)
   269  				continue
   270  			}
   271  			if srcBytes, ok := ct.s.([]byte); ok {
   272  				dstBytes := (*ifptr).([]byte)
   273  				if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] {
   274  					errf("copy into interface{} didn't copy []byte data")
   275  				}
   276  			}
   277  		}
   278  		if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) {
   279  			errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined))
   280  		}
   281  		if len(ct.wantusrstr) != 0 && ct.wantusrstr != *ct.d.(*userDefinedString) {
   282  			errf("want userDefined %q, got %q", ct.wantusrstr, *ct.d.(*userDefinedString))
   283  		}
   284  	}
   285  }
   286  
   287  func TestNullString(t *testing.T) {
   288  	var ns NullString
   289  	convertAssign(&ns, []byte("foo"))
   290  	if !ns.Valid {
   291  		t.Errorf("expecting not null")
   292  	}
   293  	if ns.String != "foo" {
   294  		t.Errorf("expecting foo; got %q", ns.String)
   295  	}
   296  	convertAssign(&ns, nil)
   297  	if ns.Valid {
   298  		t.Errorf("expecting null on nil")
   299  	}
   300  	if ns.String != "" {
   301  		t.Errorf("expecting blank on nil; got %q", ns.String)
   302  	}
   303  }
   304  
   305  type valueConverterTest struct {
   306  	c       driver.ValueConverter
   307  	in, out interface{}
   308  	err     string
   309  }
   310  
   311  var valueConverterTests = []valueConverterTest{
   312  	{driver.DefaultParameterConverter, NullString{"hi", true}, "hi", ""},
   313  	{driver.DefaultParameterConverter, NullString{"", false}, nil, ""},
   314  }
   315  
   316  func TestValueConverters(t *testing.T) {
   317  	for i, tt := range valueConverterTests {
   318  		out, err := tt.c.ConvertValue(tt.in)
   319  		goterr := ""
   320  		if err != nil {
   321  			goterr = err.Error()
   322  		}
   323  		if goterr != tt.err {
   324  			t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q",
   325  				i, tt.c, tt.in, tt.in, goterr, tt.err)
   326  		}
   327  		if tt.err != "" {
   328  			continue
   329  		}
   330  		if !reflect.DeepEqual(out, tt.out) {
   331  			t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)",
   332  				i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out)
   333  		}
   334  	}
   335  }
   336  
// TestRawBytesAllocs checks that convertAssign into a RawBytes destination
// produces the expected text for each listed source type. On amd64 with the
// gc compiler, it also checks that doing so does not allocate.
func TestRawBytesAllocs(t *testing.T) {
	var tests = []struct {
		name string
		in   interface{}
		want string
	}{
		{"uint64", uint64(12345678), "12345678"},
		{"uint32", uint32(1234), "1234"},
		{"uint16", uint16(12), "12"},
		{"uint8", uint8(1), "1"},
		{"uint", uint(123), "123"},
		{"int", int(123), "123"},
		{"int8", int8(1), "1"},
		{"int16", int16(12), "12"},
		{"int32", int32(1234), "1234"},
		{"int64", int64(12345678), "12345678"},
		{"float32", float32(1.5), "1.5"},
		{"float64", float64(64), "64"},
		{"bool", false, "false"},
		{"time", time.Unix(2, 5).UTC(), "1970-01-01T00:00:02.000000005Z"},
	}

	// buf is shared by every conversion below and starts with length and
	// capacity 10. Longer results (e.g. the 30-byte time string) presumably
	// grow it during AllocsPerRun's warm-up call, and the grown capacity is
	// then reused. TODO confirm against convertAssign's RawBytes handling.
	buf := make(RawBytes, 10)
	// test converts in into buf and fails the test unless buf holds exactly
	// want. The byte-by-byte comparison presumably avoids converting buf to a
	// string inside the measured region.
	test := func(name string, in interface{}, want string) {
		if err := convertAssign(&buf, in); err != nil {
			t.Fatalf("%s: convertAssign = %v", name, err)
		}
		match := len(buf) == len(want)
		if match {
			for i, b := range buf {
				if want[i] != b {
					match = false
					break
				}
			}
		}
		if !match {
			t.Fatalf("%s: got %q (len %d); want %q (len %d)", name, buf, len(buf), want, len(want))
		}
	}

	// n is the average number of allocations per pass over the whole table.
	n := testing.AllocsPerRun(100, func() {
		for _, tt := range tests {
			test(tt.name, tt.in, tt.want)
		}
	})

	// The numbers below are only valid for 64-bit interface word sizes,
	// and gc. With 32-bit words there are more convT2E allocs, and
	// with gccgo, only pointers currently go in interface data.
	// So only care on amd64 gc for now.
	measureAllocs := runtime.GOARCH == "amd64" && runtime.Compiler == "gc"

	if n > 0.5 && measureAllocs {
		t.Fatalf("allocs = %v; want 0", n)
	}

	// This one involves a convT2E allocation, string -> interface{}
	// The 1.5 threshold tolerates that single allocation per run.
	n = testing.AllocsPerRun(100, func() {
		test("string", "foo", "foo")
	})
	if n > 1.5 && measureAllocs {
		t.Fatalf("allocs = %v; want max 1", n)
	}
}
   403  
   404  // https://golang.org/issues/13905
   405  func TestUserDefinedBytes(t *testing.T) {
   406  	type userDefinedBytes []byte
   407  	var u userDefinedBytes
   408  	v := []byte("foo")
   409  
   410  	convertAssign(&u, v)
   411  	if &u[0] == &v[0] {
   412  		t.Fatal("userDefinedBytes got potentially dirty driver memory")
   413  	}
   414  }
   415  
   416  type Valuer_V string
   417  
   418  func (v Valuer_V) Value() (driver.Value, error) {
   419  	return strings.ToUpper(string(v)), nil
   420  }
   421  
   422  type Valuer_P string
   423  
   424  func (p *Valuer_P) Value() (driver.Value, error) {
   425  	if p == nil {
   426  		return "nil-to-str", nil
   427  	}
   428  	return strings.ToUpper(string(*p)), nil
   429  }
   430  
   431  func TestDriverArgs(t *testing.T) {
   432  	var nilValuerVPtr *Valuer_V
   433  	var nilValuerPPtr *Valuer_P
   434  	var nilStrPtr *string
   435  	tests := []struct {
   436  		args []interface{}
   437  		want []driver.NamedValue
   438  	}{
   439  		0: {
   440  			args: []interface{}{Valuer_V("foo")},
   441  			want: []driver.NamedValue{
   442  				driver.NamedValue{
   443  					Ordinal: 1,
   444  					Value:   "FOO",
   445  				},
   446  			},
   447  		},
   448  		1: {
   449  			args: []interface{}{nilValuerVPtr},
   450  			want: []driver.NamedValue{
   451  				driver.NamedValue{
   452  					Ordinal: 1,
   453  					Value:   nil,
   454  				},
   455  			},
   456  		},
   457  		2: {
   458  			args: []interface{}{nilValuerPPtr},
   459  			want: []driver.NamedValue{
   460  				driver.NamedValue{
   461  					Ordinal: 1,
   462  					Value:   "nil-to-str",
   463  				},
   464  			},
   465  		},
   466  		3: {
   467  			args: []interface{}{"plain-str"},
   468  			want: []driver.NamedValue{
   469  				driver.NamedValue{
   470  					Ordinal: 1,
   471  					Value:   "plain-str",
   472  				},
   473  			},
   474  		},
   475  		4: {
   476  			args: []interface{}{nilStrPtr},
   477  			want: []driver.NamedValue{
   478  				driver.NamedValue{
   479  					Ordinal: 1,
   480  					Value:   nil,
   481  				},
   482  			},
   483  		},
   484  	}
   485  	for i, tt := range tests {
   486  		ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
   487  		got, err := driverArgsConnLocked(nil, ds, tt.args)
   488  		if err != nil {
   489  			t.Errorf("test[%d]: %v", i, err)
   490  			continue
   491  		}
   492  		if !reflect.DeepEqual(got, tt.want) {
   493  			t.Errorf("test[%d]: got %v, want %v", i, got, tt.want)
   494  		}
   495  	}
   496  }