vitess.io/vitess@v0.16.2/go/vt/vitessdriver/streaming_rows_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package vitessdriver 18 19 import ( 20 "database/sql/driver" 21 "errors" 22 "io" 23 "reflect" 24 "strings" 25 "testing" 26 27 "github.com/stretchr/testify/require" 28 29 "vitess.io/vitess/go/sqltypes" 30 querypb "vitess.io/vitess/go/vt/proto/query" 31 ) 32 33 var packet1 = sqltypes.Result{ 34 Fields: []*querypb.Field{ 35 { 36 Name: "field1", 37 Type: sqltypes.Int32, 38 }, 39 { 40 Name: "field2", 41 Type: sqltypes.Float32, 42 }, 43 { 44 Name: "field3", 45 Type: sqltypes.VarChar, 46 }, 47 }, 48 } 49 50 var packet2 = sqltypes.Result{ 51 Rows: [][]sqltypes.Value{ 52 { 53 sqltypes.NewInt32(1), 54 sqltypes.TestValue(sqltypes.Float32, "1.1"), 55 sqltypes.NewVarChar("value1"), 56 }, 57 }, 58 } 59 60 var packet3 = sqltypes.Result{ 61 Rows: [][]sqltypes.Value{ 62 { 63 sqltypes.NewInt32(2), 64 sqltypes.TestValue(sqltypes.Float32, "2.2"), 65 sqltypes.NewVarChar("value2"), 66 }, 67 }, 68 } 69 70 type adapter struct { 71 c chan *sqltypes.Result 72 err error 73 } 74 75 func (a *adapter) Recv() (*sqltypes.Result, error) { 76 r, ok := <-a.c 77 if !ok { 78 return nil, a.err 79 } 80 return r, nil 81 } 82 83 func TestStreamingRows(t *testing.T) { 84 c := make(chan *sqltypes.Result, 3) 85 c <- &packet1 86 c <- &packet2 87 c <- &packet3 88 close(c) 89 ri := newStreamingRows(&adapter{c: c, err: io.EOF}, &converter{}) 90 wantCols := []string{ 91 "field1", 92 "field2", 93 "field3", 94 } 95 gotCols := ri.Columns() 96 if !reflect.DeepEqual(gotCols, wantCols) { 97 t.Errorf("cols: %v, want %v", gotCols, wantCols) 98 } 99 100 wantRow := []driver.Value{ 101 int64(1), 102 float64(1.1), 103 []byte("value1"), 104 } 105 gotRow := make([]driver.Value, 3) 106 err := ri.Next(gotRow) 107 require.NoError(t, err) 108 if !reflect.DeepEqual(gotRow, wantRow) { 109 t.Errorf("row1: %v, want %v", gotRow, wantRow) 110 } 111 112 wantRow = []driver.Value{ 113 int64(2), 114 float64(2.2), 115 []byte("value2"), 116 } 117 err = ri.Next(gotRow) 118 require.NoError(t, err) 119 if !reflect.DeepEqual(gotRow, wantRow) { 120 t.Errorf("row1: %v, want %v", gotRow, wantRow) 121 } 122 123 err = ri.Next(gotRow) 124 if err != io.EOF { 125 t.Errorf("got: %v, want %v", err, io.EOF) 126 } 127 128 _ = ri.Close() 129 } 130 131 func TestStreamingRowsReversed(t *testing.T) { 132 c := make(chan *sqltypes.Result, 3) 133 c <- &packet1 134 c <- &packet2 135 c <- &packet3 136 close(c) 137 ri := newStreamingRows(&adapter{c: c, err: io.EOF}, &converter{}) 138 defer ri.Close() 139 140 wantRow := []driver.Value{ 141 int64(1), 142 float64(1.1), 143 []byte("value1"), 144 } 145 gotRow := make([]driver.Value, 3) 146 err := ri.Next(gotRow) 147 require.NoError(t, err) 148 if !reflect.DeepEqual(gotRow, wantRow) { 149 t.Errorf("row1: %v, want %v", gotRow, wantRow) 150 } 151 152 wantCols := []string{ 153 "field1", 154 "field2", 155 "field3", 156 } 157 gotCols := ri.Columns() 158 if !reflect.DeepEqual(gotCols, wantCols) { 159 t.Errorf("cols: %v, want %v", gotCols, wantCols) 160 } 161 162 _ = ri.Close() 163 } 164 165 func TestStreamingRowsError(t *testing.T) { 166 c := make(chan *sqltypes.Result) 167 close(c) 168 ri := newStreamingRows(&adapter{c: c, err: errors.New("error before fields")}, &converter{}) 169 170 gotCols := ri.Columns() 171 if gotCols != nil { 172 t.Errorf("cols: %v, want nil", gotCols) 173 } 174 gotRow := make([]driver.Value, 3) 175 err := ri.Next(gotRow) 176 wantErr := "error before fields" 177 if err == nil || !strings.Contains(err.Error(), wantErr) { 178 t.Errorf("err: %v does not contain %v", err, wantErr) 179 } 180 _ = ri.Close() 181 182 c = make(chan *sqltypes.Result, 1) 183 c <- &packet1 184 close(c) 185 ri = newStreamingRows(&adapter{c: c, err: errors.New("error after fields")}, &converter{}) 186 wantCols := []string{ 187 "field1", 188 "field2", 189 "field3", 190 } 191 gotCols = ri.Columns() 192 if !reflect.DeepEqual(gotCols, wantCols) { 193 t.Errorf("cols: %v, want %v", gotCols, wantCols) 194 } 195 gotRow = make([]driver.Value, 3) 196 err = ri.Next(gotRow) 197 wantErr = "error after fields" 198 if err == nil || !strings.Contains(err.Error(), wantErr) { 199 t.Errorf("err: %v does not contain %v", err, wantErr) 200 } 201 // Ensure error persists. 202 err = ri.Next(gotRow) 203 if err == nil || !strings.Contains(err.Error(), wantErr) { 204 t.Errorf("err: %v does not contain %v", err, wantErr) 205 } 206 _ = ri.Close() 207 208 c = make(chan *sqltypes.Result, 2) 209 c <- &packet1 210 c <- &packet2 211 close(c) 212 ri = newStreamingRows(&adapter{c: c, err: errors.New("error after rows")}, &converter{}) 213 gotRow = make([]driver.Value, 3) 214 err = ri.Next(gotRow) 215 require.NoError(t, err) 216 err = ri.Next(gotRow) 217 wantErr = "error after rows" 218 if err == nil || !strings.Contains(err.Error(), wantErr) { 219 t.Errorf("err: %v does not contain %v", err, wantErr) 220 } 221 _ = ri.Close() 222 223 c = make(chan *sqltypes.Result, 1) 224 c <- &packet2 225 close(c) 226 ri = newStreamingRows(&adapter{c: c, err: io.EOF}, &converter{}) 227 gotRow = make([]driver.Value, 3) 228 err = ri.Next(gotRow) 229 wantErr = "first packet did not return fields" 230 if err == nil || !strings.Contains(err.Error(), wantErr) { 231 t.Errorf("err: %v does not contain %v", err, wantErr) 232 } 233 _ = ri.Close() 234 }