vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/vindex_func_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"reflect"
    22  	"testing"
    23  
    24  	"github.com/stretchr/testify/require"
    25  
    26  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    27  
    28  	"vitess.io/vitess/go/sqltypes"
    29  	"vitess.io/vitess/go/vt/key"
    30  	"vitess.io/vitess/go/vt/vtgate/vindexes"
    31  
    32  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    33  )
    34  
    35  // uvindex is Unique.
    36  type uvindex struct{ matchid, matchkr bool }
    37  
    38  func (*uvindex) String() string     { return "uvindex" }
    39  func (*uvindex) Cost() int          { return 1 }
    40  func (*uvindex) IsUnique() bool     { return true }
    41  func (*uvindex) NeedsVCursor() bool { return false }
    42  func (*uvindex) Verify(context.Context, vindexes.VCursor, []sqltypes.Value, [][]byte) ([]bool, error) {
    43  	panic("unimplemented")
    44  }
    45  
    46  func (v *uvindex) Map(ctx context.Context, vcursor vindexes.VCursor, ids []sqltypes.Value) ([]key.Destination, error) {
    47  	destinations := make([]key.Destination, 0, len(ids))
    48  	dkid := []byte("foo")
    49  	for i := 0; i < len(ids); i++ {
    50  		if v.matchkr {
    51  			destinations = append(destinations,
    52  				key.DestinationKeyRange{
    53  					KeyRange: &topodatapb.KeyRange{
    54  						Start: []byte{0x40},
    55  						End:   []byte{0x60},
    56  					},
    57  				})
    58  		} else if v.matchid {
    59  			destinations = append(destinations, key.DestinationKeyspaceID(dkid))
    60  		} else {
    61  			destinations = append(destinations, key.DestinationNone{})
    62  		}
    63  	}
    64  	return destinations, nil
    65  }
    66  
    67  // nvindex is NonUnique.
    68  type nvindex struct{ matchid, matchkr bool }
    69  
    70  func (*nvindex) String() string     { return "nvindex" }
    71  func (*nvindex) Cost() int          { return 1 }
    72  func (*nvindex) IsUnique() bool     { return false }
    73  func (*nvindex) NeedsVCursor() bool { return false }
    74  func (*nvindex) Verify(context.Context, vindexes.VCursor, []sqltypes.Value, [][]byte) ([]bool, error) {
    75  	panic("unimplemented")
    76  }
    77  
    78  func (v *nvindex) Map(ctx context.Context, vcursor vindexes.VCursor, ids []sqltypes.Value) ([]key.Destination, error) {
    79  	destinations := make([]key.Destination, 0)
    80  	for i := 0; i < len(ids); i++ {
    81  		if v.matchid {
    82  			destinations = append(destinations,
    83  				[]key.Destination{
    84  					key.DestinationKeyspaceIDs([][]byte{
    85  						[]byte("foo"),
    86  						[]byte("bar"),
    87  					}),
    88  				}...)
    89  		} else if v.matchkr {
    90  			destinations = append(destinations,
    91  				[]key.Destination{
    92  					key.DestinationKeyRange{
    93  						KeyRange: &topodatapb.KeyRange{
    94  							Start: []byte{0x40},
    95  							End:   []byte{0x60},
    96  						},
    97  					},
    98  				}...)
    99  		} else {
   100  			destinations = append(destinations, []key.Destination{key.DestinationNone{}}...)
   101  		}
   102  	}
   103  	return destinations, nil
   104  }
   105  
   106  func TestVindexFuncMap(t *testing.T) {
   107  	// Unique Vindex returning 0 rows.
   108  	vf := testVindexFunc(&uvindex{})
   109  	got, err := vf.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   110  	require.NoError(t, err)
   111  	want := &sqltypes.Result{
   112  		Fields: sqltypes.MakeTestFields("id|keyspace_id|hex(keyspace_id)|range_start|range_end", "varbinary|varbinary|varbinary|varbinary|varbinary"),
   113  	}
   114  	require.Equal(t, got, want)
   115  
   116  	// Unique Vindex returning 1 row.
   117  	vf = testVindexFunc(&uvindex{matchid: true})
   118  	got, err = vf.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   119  	require.NoError(t, err)
   120  	want = sqltypes.MakeTestResult(
   121  		sqltypes.MakeTestFields("id|keyspace_id|hex(keyspace_id)|range_start|range_end", "varbinary|varbinary|varbinary|varbinary|varbinary"),
   122  		"1|foo|||666f6f",
   123  	)
   124  	for _, row := range want.Rows {
   125  		row[2] = sqltypes.NULL
   126  		row[3] = sqltypes.NULL
   127  	}
   128  	require.Equal(t, got, want)
   129  
   130  	// Unique Vindex returning 3 rows
   131  	vf = &VindexFunc{
   132  		Fields: sqltypes.MakeTestFields("id|keyspace_id|hex(keyspace_id)|range_start|range_end", "varbinary|varbinary|varbinary|varbinary|varbinary"),
   133  		Cols:   []int{0, 1, 2, 3, 4},
   134  		Opcode: VindexMap,
   135  		Vindex: &uvindex{matchid: true},
   136  		Value:  evalengine.TupleExpr{evalengine.NewLiteralInt(1), evalengine.NewLiteralInt(2), evalengine.NewLiteralInt(3)},
   137  	}
   138  	got, err = vf.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   139  	require.NoError(t, err)
   140  	want = sqltypes.MakeTestResult(
   141  		sqltypes.MakeTestFields("id|keyspace_id|hex(keyspace_id)|range_start|range_end", "varbinary|varbinary|varbinary|varbinary|varbinary"),
   142  		"1|foo|||666f6f",
   143  		"2|foo|||666f6f",
   144  		"3|foo|||666f6f",
   145  	)
   146  	for _, row := range want.Rows {
   147  		row[2] = sqltypes.NULL
   148  		row[3] = sqltypes.NULL
   149  	}
   150  	require.Equal(t, got, want)
   151  
   152  	// Unique Vindex returning keyrange.
   153  	vf = testVindexFunc(&uvindex{matchkr: true})
   154  	got, err = vf.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   155  	require.NoError(t, err)
   156  	want = &sqltypes.Result{
   157  		Fields: sqltypes.MakeTestFields("id|keyspace_id|hex(keyspace_id)|range_start|range_end", "varbinary|varbinary|varbinary|varbinary|varbinary"),
   158  		Rows: [][]sqltypes.Value{
   159  			{
   160  				sqltypes.NewVarBinary("1"),
   161  				sqltypes.NULL,
   162  				sqltypes.MakeTrusted(sqltypes.VarBinary, []byte{0x40}),
   163  				sqltypes.MakeTrusted(sqltypes.VarBinary, []byte{0x60}),
   164  				sqltypes.NULL,
   165  			},
   166  		},
   167  		RowsAffected: 0,
   168  	}
   169  	require.Equal(t, got, want)
   170  
   171  	// NonUnique Vindex returning 0 rows.
   172  	vf = testVindexFunc(&nvindex{})
   173  	got, err = vf.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   174  	require.NoError(t, err)
   175  	want = &sqltypes.Result{
   176  		Fields: sqltypes.MakeTestFields("id|keyspace_id|hex(keyspace_id)|range_start|range_end", "varbinary|varbinary|varbinary|varbinary|varbinary"),
   177  	}
   178  	require.Equal(t, got, want)
   179  
   180  	// NonUnique Vindex returning 2 rows.
   181  	vf = testVindexFunc(&nvindex{matchid: true})
   182  	got, err = vf.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   183  	require.NoError(t, err)
   184  	want = sqltypes.MakeTestResult(
   185  		sqltypes.MakeTestFields("id|keyspace_id|hex(keyspace_id)|range_start|range_end", "varbinary|varbinary|varbinary|varbinary|varbinary"),
   186  		"1|foo|||666f6f",
   187  		"1|bar|||626172",
   188  	)
   189  	// Massage the rows because MakeTestResult doesn't do NULL values.
   190  	for _, row := range want.Rows {
   191  		row[2] = sqltypes.NULL
   192  		row[3] = sqltypes.NULL
   193  	}
   194  	require.Equal(t, got, want)
   195  
   196  	// NonUnique Vindex returning keyrange
   197  	vf = testVindexFunc(&nvindex{matchkr: true})
   198  	got, err = vf.TryExecute(context.Background(), &noopVCursor{}, nil, false)
   199  	require.NoError(t, err)
   200  	want = &sqltypes.Result{
   201  		Fields: sqltypes.MakeTestFields("id|keyspace_id|hex(keyspace_id)|range_start|range_end", "varbinary|varbinary|varbinary|varbinary|varbinary"),
   202  		Rows: [][]sqltypes.Value{{
   203  			sqltypes.NewVarBinary("1"),
   204  			sqltypes.NULL,
   205  			sqltypes.MakeTrusted(sqltypes.VarBinary, []byte{0x40}),
   206  			sqltypes.MakeTrusted(sqltypes.VarBinary, []byte{0x60}),
   207  			sqltypes.NULL,
   208  		}},
   209  		RowsAffected: 0,
   210  	}
   211  	require.Equal(t, got, want)
   212  }
   213  
   214  func TestVindexFuncStreamExecute(t *testing.T) {
   215  	vf := testVindexFunc(&nvindex{matchid: true})
   216  	want := []*sqltypes.Result{{
   217  		Fields: sqltypes.MakeTestFields("id|keyspace_id|hex(keyspace_id)|range_start|range_end", "varbinary|varbinary|varbinary|varbinary|varbinary"),
   218  	}, {
   219  		Rows: [][]sqltypes.Value{{
   220  			sqltypes.NewVarBinary("1"), sqltypes.NewVarBinary("foo"), sqltypes.NULL, sqltypes.NULL, sqltypes.NewVarBinary("666f6f"),
   221  		}, {
   222  			sqltypes.NewVarBinary("1"), sqltypes.NewVarBinary("bar"), sqltypes.NULL, sqltypes.NULL, sqltypes.NewVarBinary("626172"),
   223  		}},
   224  	}}
   225  	i := 0
   226  	err := vf.TryStreamExecute(context.Background(), &noopVCursor{}, nil, false, func(qr *sqltypes.Result) error {
   227  		if !reflect.DeepEqual(qr, want[i]) {
   228  			t.Errorf("callback(%d):\n%v, want\n%v", i, qr, want[i])
   229  		}
   230  		i++
   231  		return nil
   232  	})
   233  	if err != nil {
   234  		t.Fatal(err)
   235  	}
   236  }
   237  
   238  func TestVindexFuncGetFields(t *testing.T) {
   239  	vf := testVindexFunc(&uvindex{matchid: true})
   240  	got, err := vf.GetFields(context.Background(), nil, nil)
   241  	if err != nil {
   242  		t.Fatal(err)
   243  	}
   244  	want := &sqltypes.Result{
   245  		Fields: sqltypes.MakeTestFields("id|keyspace_id|hex(keyspace_id)|range_start|range_end", "varbinary|varbinary|varbinary|varbinary|varbinary"),
   246  	}
   247  	if !reflect.DeepEqual(got, want) {
   248  		t.Errorf("Execute(Map, uvindex(none)):\n%v, want\n%v", got, want)
   249  	}
   250  }
   251  
   252  func TestFieldOrder(t *testing.T) {
   253  	vf := testVindexFunc(&nvindex{matchid: true})
   254  	vf.Fields = sqltypes.MakeTestFields("keyspace_id|id|keyspace_id", "varbinary|varbinary|varbinary")
   255  	vf.Cols = []int{1, 0, 1}
   256  	got, err := vf.TryExecute(context.Background(), &noopVCursor{}, nil, true)
   257  	if err != nil {
   258  		t.Fatal(err)
   259  	}
   260  	want := sqltypes.MakeTestResult(
   261  		vf.Fields,
   262  		"foo|1|foo",
   263  		"bar|1|bar",
   264  	)
   265  	if !reflect.DeepEqual(got, want) {
   266  		t.Errorf("Execute(Map, uvindex(none)):\n%v, want\n%v", got, want)
   267  	}
   268  }
   269  
   270  func testVindexFunc(v vindexes.SingleColumn) *VindexFunc {
   271  	return &VindexFunc{
   272  		Fields: sqltypes.MakeTestFields("id|keyspace_id|hex(keyspace_id)|range_start|range_end", "varbinary|varbinary|varbinary|varbinary|varbinary"),
   273  		Cols:   []int{0, 1, 2, 3, 4},
   274  		Opcode: VindexMap,
   275  		Vindex: v,
   276  		Value:  evalengine.NewLiteralInt(1),
   277  	}
   278  }