vitess.io/vitess@v0.16.2/go/vt/vitessdriver/streaming_rows_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vitessdriver
    18  
    19  import (
    20  	"database/sql/driver"
    21  	"errors"
    22  	"io"
    23  	"reflect"
    24  	"strings"
    25  	"testing"
    26  
    27  	"github.com/stretchr/testify/require"
    28  
    29  	"vitess.io/vitess/go/sqltypes"
    30  	querypb "vitess.io/vitess/go/vt/proto/query"
    31  )
    32  
    33  var packet1 = sqltypes.Result{
    34  	Fields: []*querypb.Field{
    35  		{
    36  			Name: "field1",
    37  			Type: sqltypes.Int32,
    38  		},
    39  		{
    40  			Name: "field2",
    41  			Type: sqltypes.Float32,
    42  		},
    43  		{
    44  			Name: "field3",
    45  			Type: sqltypes.VarChar,
    46  		},
    47  	},
    48  }
    49  
    50  var packet2 = sqltypes.Result{
    51  	Rows: [][]sqltypes.Value{
    52  		{
    53  			sqltypes.NewInt32(1),
    54  			sqltypes.TestValue(sqltypes.Float32, "1.1"),
    55  			sqltypes.NewVarChar("value1"),
    56  		},
    57  	},
    58  }
    59  
    60  var packet3 = sqltypes.Result{
    61  	Rows: [][]sqltypes.Value{
    62  		{
    63  			sqltypes.NewInt32(2),
    64  			sqltypes.TestValue(sqltypes.Float32, "2.2"),
    65  			sqltypes.NewVarChar("value2"),
    66  		},
    67  	},
    68  }
    69  
    70  type adapter struct {
    71  	c   chan *sqltypes.Result
    72  	err error
    73  }
    74  
    75  func (a *adapter) Recv() (*sqltypes.Result, error) {
    76  	r, ok := <-a.c
    77  	if !ok {
    78  		return nil, a.err
    79  	}
    80  	return r, nil
    81  }
    82  
    83  func TestStreamingRows(t *testing.T) {
    84  	c := make(chan *sqltypes.Result, 3)
    85  	c <- &packet1
    86  	c <- &packet2
    87  	c <- &packet3
    88  	close(c)
    89  	ri := newStreamingRows(&adapter{c: c, err: io.EOF}, &converter{})
    90  	wantCols := []string{
    91  		"field1",
    92  		"field2",
    93  		"field3",
    94  	}
    95  	gotCols := ri.Columns()
    96  	if !reflect.DeepEqual(gotCols, wantCols) {
    97  		t.Errorf("cols: %v, want %v", gotCols, wantCols)
    98  	}
    99  
   100  	wantRow := []driver.Value{
   101  		int64(1),
   102  		float64(1.1),
   103  		[]byte("value1"),
   104  	}
   105  	gotRow := make([]driver.Value, 3)
   106  	err := ri.Next(gotRow)
   107  	require.NoError(t, err)
   108  	if !reflect.DeepEqual(gotRow, wantRow) {
   109  		t.Errorf("row1: %v, want %v", gotRow, wantRow)
   110  	}
   111  
   112  	wantRow = []driver.Value{
   113  		int64(2),
   114  		float64(2.2),
   115  		[]byte("value2"),
   116  	}
   117  	err = ri.Next(gotRow)
   118  	require.NoError(t, err)
   119  	if !reflect.DeepEqual(gotRow, wantRow) {
   120  		t.Errorf("row1: %v, want %v", gotRow, wantRow)
   121  	}
   122  
   123  	err = ri.Next(gotRow)
   124  	if err != io.EOF {
   125  		t.Errorf("got: %v, want %v", err, io.EOF)
   126  	}
   127  
   128  	_ = ri.Close()
   129  }
   130  
   131  func TestStreamingRowsReversed(t *testing.T) {
   132  	c := make(chan *sqltypes.Result, 3)
   133  	c <- &packet1
   134  	c <- &packet2
   135  	c <- &packet3
   136  	close(c)
   137  	ri := newStreamingRows(&adapter{c: c, err: io.EOF}, &converter{})
   138  	defer ri.Close()
   139  
   140  	wantRow := []driver.Value{
   141  		int64(1),
   142  		float64(1.1),
   143  		[]byte("value1"),
   144  	}
   145  	gotRow := make([]driver.Value, 3)
   146  	err := ri.Next(gotRow)
   147  	require.NoError(t, err)
   148  	if !reflect.DeepEqual(gotRow, wantRow) {
   149  		t.Errorf("row1: %v, want %v", gotRow, wantRow)
   150  	}
   151  
   152  	wantCols := []string{
   153  		"field1",
   154  		"field2",
   155  		"field3",
   156  	}
   157  	gotCols := ri.Columns()
   158  	if !reflect.DeepEqual(gotCols, wantCols) {
   159  		t.Errorf("cols: %v, want %v", gotCols, wantCols)
   160  	}
   161  
   162  	_ = ri.Close()
   163  }
   164  
   165  func TestStreamingRowsError(t *testing.T) {
   166  	c := make(chan *sqltypes.Result)
   167  	close(c)
   168  	ri := newStreamingRows(&adapter{c: c, err: errors.New("error before fields")}, &converter{})
   169  
   170  	gotCols := ri.Columns()
   171  	if gotCols != nil {
   172  		t.Errorf("cols: %v, want nil", gotCols)
   173  	}
   174  	gotRow := make([]driver.Value, 3)
   175  	err := ri.Next(gotRow)
   176  	wantErr := "error before fields"
   177  	if err == nil || !strings.Contains(err.Error(), wantErr) {
   178  		t.Errorf("err: %v does not contain %v", err, wantErr)
   179  	}
   180  	_ = ri.Close()
   181  
   182  	c = make(chan *sqltypes.Result, 1)
   183  	c <- &packet1
   184  	close(c)
   185  	ri = newStreamingRows(&adapter{c: c, err: errors.New("error after fields")}, &converter{})
   186  	wantCols := []string{
   187  		"field1",
   188  		"field2",
   189  		"field3",
   190  	}
   191  	gotCols = ri.Columns()
   192  	if !reflect.DeepEqual(gotCols, wantCols) {
   193  		t.Errorf("cols: %v, want %v", gotCols, wantCols)
   194  	}
   195  	gotRow = make([]driver.Value, 3)
   196  	err = ri.Next(gotRow)
   197  	wantErr = "error after fields"
   198  	if err == nil || !strings.Contains(err.Error(), wantErr) {
   199  		t.Errorf("err: %v does not contain %v", err, wantErr)
   200  	}
   201  	// Ensure error persists.
   202  	err = ri.Next(gotRow)
   203  	if err == nil || !strings.Contains(err.Error(), wantErr) {
   204  		t.Errorf("err: %v does not contain %v", err, wantErr)
   205  	}
   206  	_ = ri.Close()
   207  
   208  	c = make(chan *sqltypes.Result, 2)
   209  	c <- &packet1
   210  	c <- &packet2
   211  	close(c)
   212  	ri = newStreamingRows(&adapter{c: c, err: errors.New("error after rows")}, &converter{})
   213  	gotRow = make([]driver.Value, 3)
   214  	err = ri.Next(gotRow)
   215  	require.NoError(t, err)
   216  	err = ri.Next(gotRow)
   217  	wantErr = "error after rows"
   218  	if err == nil || !strings.Contains(err.Error(), wantErr) {
   219  		t.Errorf("err: %v does not contain %v", err, wantErr)
   220  	}
   221  	_ = ri.Close()
   222  
   223  	c = make(chan *sqltypes.Result, 1)
   224  	c <- &packet2
   225  	close(c)
   226  	ri = newStreamingRows(&adapter{c: c, err: io.EOF}, &converter{})
   227  	gotRow = make([]driver.Value, 3)
   228  	err = ri.Next(gotRow)
   229  	wantErr = "first packet did not return fields"
   230  	if err == nil || !strings.Contains(err.Error(), wantErr) {
   231  		t.Errorf("err: %v does not contain %v", err, wantErr)
   232  	}
   233  	_ = ri.Close()
   234  }