vitess.io/vitess@v0.16.2/go/vt/vtgate/fakerpcvtgateconn/conn.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package fakerpcvtgateconn provides a fake implementation of
    18  // vtgateconn.Impl that doesn't do any RPC, but uses a local
    19  // map to return results.
    20  package fakerpcvtgateconn
    21  
    22  import (
    23  	"context"
    24  	"fmt"
    25  	"io"
    26  	"math/rand"
    27  	"reflect"
    28  
    29  	"vitess.io/vitess/go/sqltypes"
    30  	"vitess.io/vitess/go/vt/vtgate/vtgateconn"
    31  
    32  	binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
    33  	querypb "vitess.io/vitess/go/vt/proto/query"
    34  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    35  	vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
    36  )
    37  
    38  // queryExecute contains all the fields we use to test Execute
    39  type queryExecute struct {
    40  	SQL           string
    41  	BindVariables map[string]*querypb.BindVariable
    42  	Session       *vtgatepb.Session
    43  }
    44  
    45  type queryResponse struct {
    46  	execQuery *queryExecute
    47  	reply     *sqltypes.Result
    48  	err       error
    49  }
    50  
    51  // FakeVTGateConn provides a fake implementation of vtgateconn.Impl
    52  type FakeVTGateConn struct {
    53  	execMap map[string]*queryResponse
    54  }
    55  
    56  // RegisterFakeVTGateConnDialer registers the proper dialer for this fake,
    57  // and returns the underlying instance that will be returned by the dialer,
    58  // and the protocol to use to get this fake.
    59  func RegisterFakeVTGateConnDialer() (*FakeVTGateConn, string) {
    60  	protocol := "fake"
    61  	impl := &FakeVTGateConn{
    62  		execMap: make(map[string]*queryResponse),
    63  	}
    64  	vtgateconn.RegisterDialer(protocol, func(ctx context.Context, address string) (vtgateconn.Impl, error) {
    65  		return impl, nil
    66  	})
    67  	return impl, protocol
    68  }
    69  
    70  // AddQuery adds a query and expected result.
    71  func (conn *FakeVTGateConn) AddQuery(
    72  	sql string,
    73  	bindVariables map[string]*querypb.BindVariable,
    74  	session *vtgatepb.Session,
    75  	expectedResult *sqltypes.Result) {
    76  	conn.execMap[sql] = &queryResponse{
    77  		execQuery: &queryExecute{
    78  			SQL:           sql,
    79  			BindVariables: bindVariables,
    80  			Session:       session,
    81  		},
    82  		reply: expectedResult,
    83  	}
    84  }
    85  
    86  // Execute please see vtgateconn.Impl.Execute
    87  func (conn *FakeVTGateConn) Execute(ctx context.Context, session *vtgatepb.Session, sql string, bindVars map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
    88  	response, ok := conn.execMap[sql]
    89  	if !ok {
    90  		return nil, nil, fmt.Errorf("no match for: %s", sql)
    91  	}
    92  	query := &queryExecute{
    93  		SQL:           sql,
    94  		BindVariables: bindVars,
    95  		Session:       session,
    96  	}
    97  	if !reflect.DeepEqual(query, response.execQuery) {
    98  		return nil, nil, fmt.Errorf(
    99  			"Execute: %+v, want %+v", query, response.execQuery)
   100  	}
   101  	reply := *response.reply
   102  	s := newSession(true, "test_keyspace", []string{}, topodatapb.TabletType_PRIMARY)
   103  	return s, &reply, nil
   104  }
   105  
   106  // ExecuteBatch please see vtgateconn.Impl.ExecuteBatch
   107  func (conn *FakeVTGateConn) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVarsList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) {
   108  	panic("not implemented")
   109  }
   110  
   111  // StreamExecute please see vtgateconn.Impl.StreamExecute
   112  func (conn *FakeVTGateConn) StreamExecute(ctx context.Context, session *vtgatepb.Session, sql string, bindVars map[string]*querypb.BindVariable) (sqltypes.ResultStream, error) {
   113  	response, ok := conn.execMap[sql]
   114  	if !ok {
   115  		return nil, fmt.Errorf("no match for: %s", sql)
   116  	}
   117  	query := &queryExecute{
   118  		SQL:           sql,
   119  		BindVariables: bindVars,
   120  		Session:       session,
   121  	}
   122  	if !reflect.DeepEqual(query, response.execQuery) {
   123  		return nil, fmt.Errorf("StreamExecute: %+v, want %+v", sql, response.execQuery)
   124  	}
   125  	if response.err != nil {
   126  		return nil, response.err
   127  	}
   128  	var resultChan chan *sqltypes.Result
   129  	defer close(resultChan)
   130  	if response.reply != nil {
   131  		// create a result channel big enough to buffer all of
   132  		// the responses so we don't need to fork a go routine.
   133  		resultChan = make(chan *sqltypes.Result, len(response.reply.Rows)+1)
   134  		result := &sqltypes.Result{}
   135  		result.Fields = response.reply.Fields
   136  		resultChan <- result
   137  		for _, row := range response.reply.Rows {
   138  			result := &sqltypes.Result{}
   139  			result.Rows = [][]sqltypes.Value{row}
   140  			resultChan <- result
   141  		}
   142  	} else {
   143  		resultChan = make(chan *sqltypes.Result)
   144  	}
   145  	return &streamExecuteAdapter{resultChan}, nil
   146  }
   147  
   148  type streamExecuteAdapter struct {
   149  	c chan *sqltypes.Result
   150  }
   151  
   152  func (a *streamExecuteAdapter) Recv() (*sqltypes.Result, error) {
   153  	r, ok := <-a.c
   154  	if !ok {
   155  		return nil, io.EOF
   156  	}
   157  	return r, nil
   158  }
   159  
   160  // Prepare please see vtgateconn.Impl.Prepare
   161  func (conn *FakeVTGateConn) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVars map[string]*querypb.BindVariable) (*vtgatepb.Session, []*querypb.Field, error) {
   162  	response, ok := conn.execMap[sql]
   163  	if !ok {
   164  		return nil, nil, fmt.Errorf("no match for: %s", sql)
   165  	}
   166  	query := &queryExecute{
   167  		SQL:           sql,
   168  		BindVariables: bindVars,
   169  		Session:       session,
   170  	}
   171  	if !reflect.DeepEqual(query, response.execQuery) {
   172  		return nil, nil, fmt.Errorf(
   173  			"Prepare: %+v, want %+v", query, response.execQuery)
   174  	}
   175  	reply := *response.reply
   176  	s := newSession(true, "test_keyspace", []string{}, topodatapb.TabletType_PRIMARY)
   177  	return s, reply.Fields, nil
   178  }
   179  
   180  // CloseSession please see vtgateconn.Impl.CloseSession
   181  func (conn *FakeVTGateConn) CloseSession(ctx context.Context, session *vtgatepb.Session) error {
   182  	panic("not implemented")
   183  }
   184  
   185  // ResolveTransaction please see vtgateconn.Impl.ResolveTransaction
   186  func (conn *FakeVTGateConn) ResolveTransaction(ctx context.Context, dtid string) error {
   187  	return nil
   188  }
   189  
   190  // VStream streams binlog events.
   191  func (conn *FakeVTGateConn) VStream(ctx context.Context, tabletType topodatapb.TabletType, vgtid *binlogdatapb.VGtid,
   192  	filter *binlogdatapb.Filter, flags *vtgatepb.VStreamFlags) (vtgateconn.VStreamReader, error) {
   193  
   194  	return nil, fmt.Errorf("NYI")
   195  }
   196  
   197  // Close please see vtgateconn.Impl.Close
   198  func (conn *FakeVTGateConn) Close() {
   199  }
   200  
   201  func newSession(
   202  	inTransaction bool,
   203  	keyspace string,
   204  	shards []string,
   205  	tabletType topodatapb.TabletType) *vtgatepb.Session {
   206  	shardSessions := make([]*vtgatepb.Session_ShardSession, len(shards))
   207  	for _, shard := range shards {
   208  		shardSessions = append(shardSessions, &vtgatepb.Session_ShardSession{
   209  			Target: &querypb.Target{
   210  				Keyspace:   keyspace,
   211  				Shard:      shard,
   212  				TabletType: tabletType,
   213  			},
   214  			TransactionId: rand.Int63(),
   215  		})
   216  	}
   217  	return &vtgatepb.Session{
   218  		InTransaction: inTransaction,
   219  		ShardSessions: shardSessions,
   220  	}
   221  }
   222  
   223  // Make sure FakeVTGateConn implements vtgateconn.Impl
   224  var _ vtgateconn.Impl = (*FakeVTGateConn)(nil)