vitess.io/vitess@v0.16.2/go/vt/vitessdriver/rows_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vitessdriver
    18  
    19  import (
    20  	"database/sql/driver"
    21  	"fmt"
    22  	"io"
    23  	"reflect"
    24  	"testing"
    25  
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  
    29  	"vitess.io/vitess/go/sqltypes"
    30  	querypb "vitess.io/vitess/go/vt/proto/query"
    31  )
    32  
    33  var rowsResult1 = sqltypes.Result{
    34  	Fields: []*querypb.Field{
    35  		{
    36  			Name: "field1",
    37  			Type: sqltypes.Int32,
    38  		},
    39  		{
    40  			Name: "field2",
    41  			Type: sqltypes.Float32,
    42  		},
    43  		{
    44  			Name: "field3",
    45  			Type: sqltypes.VarChar,
    46  		},
    47  		// Signed types which are smaller than uint64, will become an int64.
    48  		{
    49  			Name: "field4",
    50  			Type: sqltypes.Uint32,
    51  		},
    52  		// Signed uint64 values must be mapped to uint64.
    53  		{
    54  			Name: "field5",
    55  			Type: sqltypes.Uint64,
    56  		},
    57  	},
    58  	RowsAffected: 2,
    59  	InsertID:     0,
    60  	Rows: [][]sqltypes.Value{
    61  		{
    62  			sqltypes.NewInt32(1),
    63  			sqltypes.TestValue(sqltypes.Float32, "1.1"),
    64  			sqltypes.NewVarChar("value1"),
    65  			sqltypes.TestValue(sqltypes.Uint32, "2147483647"), // 2^31-1, NOT out of range for int32 => should become int64
    66  			sqltypes.NewUint64(9223372036854775807),           // 2^63-1, NOT out of range for int64
    67  		},
    68  		{
    69  			sqltypes.NewInt32(2),
    70  			sqltypes.TestValue(sqltypes.Float32, "2.2"),
    71  			sqltypes.NewVarChar("value2"),
    72  			sqltypes.TestValue(sqltypes.Uint32, "4294967295"), // 2^32-1, out of range for int32 => should become int64
    73  			sqltypes.NewUint64(18446744073709551615),          // 2^64-1, out of range for int64
    74  		},
    75  	},
    76  }
    77  
    78  func logMismatchedTypes(t *testing.T, gotRow, wantRow []driver.Value) {
    79  	for i := 1; i < len(wantRow); i++ {
    80  		got := gotRow[i]
    81  		want := wantRow[i]
    82  		v1 := reflect.ValueOf(got)
    83  		v2 := reflect.ValueOf(want)
    84  		if v1.Type() != v2.Type() {
    85  			t.Errorf("Wrong type: field: %d got: %T want: %T", i+1, got, want)
    86  		}
    87  	}
    88  }
    89  
    90  func TestRows(t *testing.T) {
    91  	ri := newRows(&rowsResult1, &converter{})
    92  	wantCols := []string{
    93  		"field1",
    94  		"field2",
    95  		"field3",
    96  		"field4",
    97  		"field5",
    98  	}
    99  	gotCols := ri.Columns()
   100  	if !reflect.DeepEqual(gotCols, wantCols) {
   101  		t.Errorf("cols: %v, want %v", gotCols, wantCols)
   102  	}
   103  
   104  	wantRow := []driver.Value{
   105  		int64(1),
   106  		float64(1.1),
   107  		[]byte("value1"),
   108  		uint64(2147483647),
   109  		uint64(9223372036854775807),
   110  	}
   111  	gotRow := make([]driver.Value, len(wantRow))
   112  	err := ri.Next(gotRow)
   113  	require.NoError(t, err)
   114  	if !reflect.DeepEqual(gotRow, wantRow) {
   115  		t.Errorf("row1: %#v, want %#v type: %T", gotRow, wantRow, wantRow[3])
   116  		logMismatchedTypes(t, gotRow, wantRow)
   117  	}
   118  
   119  	wantRow = []driver.Value{
   120  		int64(2),
   121  		float64(2.2),
   122  		[]byte("value2"),
   123  		uint64(4294967295),
   124  		uint64(18446744073709551615),
   125  	}
   126  	err = ri.Next(gotRow)
   127  	require.NoError(t, err)
   128  	if !reflect.DeepEqual(gotRow, wantRow) {
   129  		t.Errorf("row1: %v, want %v", gotRow, wantRow)
   130  		logMismatchedTypes(t, gotRow, wantRow)
   131  	}
   132  
   133  	err = ri.Next(gotRow)
   134  	if err != io.EOF {
   135  		t.Errorf("got: %v, want %v", err, io.EOF)
   136  	}
   137  
   138  	_ = ri.Close()
   139  }
   140  
   141  // Test that the ColumnTypeScanType function returns the correct reflection type for each
   142  // sql type. The sql type in turn comes from a table column's type.
   143  func TestColumnTypeScanType(t *testing.T) {
   144  	var r = sqltypes.Result{
   145  		Fields: []*querypb.Field{
   146  			{
   147  				Name: "field1",
   148  				Type: sqltypes.Int8,
   149  			},
   150  			{
   151  				Name: "field2",
   152  				Type: sqltypes.Uint8,
   153  			},
   154  			{
   155  				Name: "field3",
   156  				Type: sqltypes.Int16,
   157  			},
   158  			{
   159  				Name: "field4",
   160  				Type: sqltypes.Uint16,
   161  			},
   162  			{
   163  				Name: "field5",
   164  				Type: sqltypes.Int24,
   165  			},
   166  			{
   167  				Name: "field6",
   168  				Type: sqltypes.Uint24,
   169  			},
   170  			{
   171  				Name: "field7",
   172  				Type: sqltypes.Int32,
   173  			},
   174  			{
   175  				Name: "field8",
   176  				Type: sqltypes.Uint32,
   177  			},
   178  			{
   179  				Name: "field9",
   180  				Type: sqltypes.Int64,
   181  			},
   182  			{
   183  				Name: "field10",
   184  				Type: sqltypes.Uint64,
   185  			},
   186  			{
   187  				Name: "field11",
   188  				Type: sqltypes.Float32,
   189  			},
   190  			{
   191  				Name: "field12",
   192  				Type: sqltypes.Float64,
   193  			},
   194  			{
   195  				Name: "field13",
   196  				Type: sqltypes.VarBinary,
   197  			},
   198  			{
   199  				Name: "field14",
   200  				Type: sqltypes.Datetime,
   201  			},
   202  		},
   203  	}
   204  
   205  	ri := newRows(&r, &converter{}).(driver.RowsColumnTypeScanType)
   206  	defer ri.Close()
   207  
   208  	wantTypes := []reflect.Type{
   209  		typeInt8,
   210  		typeUint8,
   211  		typeInt16,
   212  		typeUint16,
   213  		typeInt32,
   214  		typeUint32,
   215  		typeInt32,
   216  		typeUint32,
   217  		typeInt64,
   218  		typeUint64,
   219  		typeFloat32,
   220  		typeFloat64,
   221  		typeRawBytes,
   222  		typeTime,
   223  	}
   224  
   225  	for i := 0; i < len(wantTypes); i++ {
   226  		assert.Equal(t, ri.ColumnTypeScanType(i), wantTypes[i], fmt.Sprintf("unexpected type %v, wanted %v", ri.ColumnTypeScanType(i), wantTypes[i]))
   227  	}
   228  }