vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/fake_primitive_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"reflect"
    23  	"strings"
    24  	"testing"
    25  
    26  	"vitess.io/vitess/go/sqltypes"
    27  
    28  	querypb "vitess.io/vitess/go/vt/proto/query"
    29  )
    30  
    31  // fakePrimitive fakes a primitive. For every call, it sends the
    32  // next result from the results. If the next result is nil, it
    33  // returns sendErr. For streaming calls, it sends the field info
    34  // first and two rows at a time till all rows are sent.
    35  type fakePrimitive struct {
    36  	results   []*sqltypes.Result
    37  	curResult int
    38  	// sendErr is sent at the end of the stream if it's set.
    39  	sendErr error
    40  
    41  	log []string
    42  
    43  	allResultsInOneCall bool
    44  }
    45  
    46  func (f *fakePrimitive) Inputs() []Primitive {
    47  	return []Primitive{}
    48  }
    49  
    50  var _ Primitive = (*fakePrimitive)(nil)
    51  
    52  func (f *fakePrimitive) rewind() {
    53  	f.curResult = 0
    54  	f.log = nil
    55  }
    56  
    57  func (f *fakePrimitive) RouteType() string {
    58  	return "Fake"
    59  }
    60  
    61  func (f *fakePrimitive) GetKeyspaceName() string {
    62  	return "fakeKs"
    63  }
    64  
    65  func (f *fakePrimitive) GetTableName() string {
    66  	return "fakeTable"
    67  }
    68  
    69  func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
    70  	f.log = append(f.log, fmt.Sprintf("Execute %v %v", printBindVars(bindVars), wantfields))
    71  	if f.results == nil {
    72  		return nil, f.sendErr
    73  	}
    74  
    75  	r := f.results[f.curResult]
    76  	f.curResult++
    77  	if r == nil {
    78  		return nil, f.sendErr
    79  	}
    80  	return r, nil
    81  }
    82  
    83  func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
    84  	f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
    85  	if f.results == nil {
    86  		return f.sendErr
    87  	}
    88  
    89  	readMoreResults := true
    90  	for readMoreResults && f.curResult < len(f.results) {
    91  		readMoreResults = f.allResultsInOneCall
    92  		r := f.results[f.curResult]
    93  		f.curResult++
    94  		if r == nil {
    95  			return f.sendErr
    96  		}
    97  		if wantfields {
    98  			if err := callback(&sqltypes.Result{Fields: r.Fields}); err != nil {
    99  				return err
   100  			}
   101  		}
   102  		result := &sqltypes.Result{}
   103  		for i := 0; i < len(r.Rows); i++ {
   104  			result.Rows = append(result.Rows, r.Rows[i])
   105  			// Send only two rows at a time.
   106  			if i%2 == 1 {
   107  				if err := callback(result); err != nil {
   108  					return err
   109  				}
   110  				result = &sqltypes.Result{}
   111  			}
   112  		}
   113  		if len(result.Rows) != 0 {
   114  			if err := callback(result); err != nil {
   115  				return err
   116  			}
   117  		}
   118  	}
   119  
   120  	return nil
   121  }
   122  func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
   123  	f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars)))
   124  	return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */)
   125  }
   126  
   127  func (f *fakePrimitive) ExpectLog(t *testing.T, want []string) {
   128  	t.Helper()
   129  	if !reflect.DeepEqual(f.log, want) {
   130  		t.Errorf("vc.log got:\n%v\nwant:\n%v", strings.Join(f.log, "\n"), strings.Join(want, "\n"))
   131  	}
   132  }
   133  
   134  func (f *fakePrimitive) NeedsTransaction() bool {
   135  	return false
   136  }
   137  
   138  func wrapStreamExecute(prim Primitive, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
   139  	var result *sqltypes.Result
   140  	err := prim.TryStreamExecute(context.Background(), vcursor, bindVars, wantfields, func(r *sqltypes.Result) error {
   141  		if result == nil {
   142  			result = r
   143  		} else {
   144  			result.Rows = append(result.Rows, r.Rows...)
   145  		}
   146  		return nil
   147  	})
   148  	return result, err
   149  }
   150  
   151  func (f *fakePrimitive) description() PrimitiveDescription {
   152  	return PrimitiveDescription{OperatorType: "fake"}
   153  }