vitess.io/vitess@v0.16.2/go/vt/vtgate/grpcvtgateconn/conn.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package grpcvtgateconn provides gRPC connectivity for VTGate.
    18  package grpcvtgateconn
    19  
    20  import (
    21  	"context"
    22  
    23  	"github.com/spf13/pflag"
    24  	"google.golang.org/grpc"
    25  
    26  	"vitess.io/vitess/go/sqltypes"
    27  	"vitess.io/vitess/go/vt/callerid"
    28  	"vitess.io/vitess/go/vt/grpcclient"
    29  	"vitess.io/vitess/go/vt/servenv"
    30  	"vitess.io/vitess/go/vt/vterrors"
    31  	"vitess.io/vitess/go/vt/vtgate/vtgateconn"
    32  
    33  	binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
    34  	querypb "vitess.io/vitess/go/vt/proto/query"
    35  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    36  	vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
    37  	vtgateservicepb "vitess.io/vitess/go/vt/proto/vtgateservice"
    38  )
    39  
    40  var (
    41  	cert string
    42  	key  string
    43  	ca   string
    44  	crl  string
    45  	name string
    46  )
    47  
    48  func init() {
    49  	vtgateconn.RegisterDialer("grpc", dial)
    50  
    51  	for _, cmd := range []string{
    52  		"vtbench",
    53  		"vtclient",
    54  		"vtcombo",
    55  		"vtctl",
    56  		"vttestserver",
    57  	} {
    58  		servenv.OnParseFor(cmd, registerFlags)
    59  	}
    60  }
    61  
    62  func registerFlags(fs *pflag.FlagSet) {
    63  	fs.StringVar(&cert, "vtgate_grpc_cert", "", "the cert to use to connect")
    64  	fs.StringVar(&key, "vtgate_grpc_key", "", "the key to use to connect")
    65  	fs.StringVar(&ca, "vtgate_grpc_ca", "", "the server ca to use to validate servers when connecting")
    66  	fs.StringVar(&crl, "vtgate_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
    67  	fs.StringVar(&name, "vtgate_grpc_server_name", "", "the server name to use to validate server certificate")
    68  }
    69  
    70  type vtgateConn struct {
    71  	cc *grpc.ClientConn
    72  	c  vtgateservicepb.VitessClient
    73  }
    74  
    75  func dial(ctx context.Context, addr string) (vtgateconn.Impl, error) {
    76  	return DialWithOpts(ctx)(ctx, addr)
    77  }
    78  
    79  // DialWithOpts allows for custom dial options to be set on a vtgateConn.
    80  func DialWithOpts(ctx context.Context, opts ...grpc.DialOption) vtgateconn.DialerFunc {
    81  	return func(ctx context.Context, address string) (vtgateconn.Impl, error) {
    82  		opt, err := grpcclient.SecureDialOption(cert, key, ca, crl, name)
    83  		if err != nil {
    84  			return nil, err
    85  		}
    86  
    87  		opts = append(opts, opt)
    88  
    89  		cc, err := grpcclient.Dial(address, grpcclient.FailFast(false), opts...)
    90  		if err != nil {
    91  			return nil, err
    92  		}
    93  
    94  		c := vtgateservicepb.NewVitessClient(cc)
    95  		return &vtgateConn{
    96  			cc: cc,
    97  			c:  c,
    98  		}, nil
    99  	}
   100  }
   101  
   102  func (conn *vtgateConn) Execute(ctx context.Context, session *vtgatepb.Session, query string, bindVars map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
   103  	request := &vtgatepb.ExecuteRequest{
   104  		CallerId: callerid.EffectiveCallerIDFromContext(ctx),
   105  		Session:  session,
   106  		Query: &querypb.BoundQuery{
   107  			Sql:           query,
   108  			BindVariables: bindVars,
   109  		},
   110  	}
   111  	response, err := conn.c.Execute(ctx, request)
   112  	if err != nil {
   113  		return session, nil, vterrors.FromGRPC(err)
   114  	}
   115  	if response.Error != nil {
   116  		return response.Session, nil, vterrors.FromVTRPC(response.Error)
   117  	}
   118  	return response.Session, sqltypes.Proto3ToResult(response.Result), nil
   119  }
   120  
   121  func (conn *vtgateConn) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, queryList []string, bindVarsList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) {
   122  	queries := make([]*querypb.BoundQuery, len(queryList))
   123  	for i, query := range queryList {
   124  		bq := &querypb.BoundQuery{Sql: query}
   125  		if len(bindVarsList) != 0 {
   126  			bq.BindVariables = bindVarsList[i]
   127  		}
   128  		queries[i] = bq
   129  	}
   130  	request := &vtgatepb.ExecuteBatchRequest{
   131  		CallerId: callerid.EffectiveCallerIDFromContext(ctx),
   132  		Session:  session,
   133  		Queries:  queries,
   134  	}
   135  	response, err := conn.c.ExecuteBatch(ctx, request)
   136  	if err != nil {
   137  		return session, nil, vterrors.FromGRPC(err)
   138  	}
   139  	if response.Error != nil {
   140  		return response.Session, nil, vterrors.FromVTRPC(response.Error)
   141  	}
   142  	return response.Session, sqltypes.Proto3ToQueryReponses(response.Results), nil
   143  }
   144  
   145  type streamExecuteAdapter struct {
   146  	recv   func() (*querypb.QueryResult, error)
   147  	fields []*querypb.Field
   148  }
   149  
   150  func (a *streamExecuteAdapter) Recv() (*sqltypes.Result, error) {
   151  	qr, err := a.recv()
   152  	if err != nil {
   153  		return nil, vterrors.FromGRPC(err)
   154  	}
   155  	if a.fields == nil {
   156  		a.fields = qr.Fields
   157  	}
   158  	return sqltypes.CustomProto3ToResult(a.fields, qr), nil
   159  }
   160  
   161  func (conn *vtgateConn) StreamExecute(ctx context.Context, session *vtgatepb.Session, query string, bindVars map[string]*querypb.BindVariable) (sqltypes.ResultStream, error) {
   162  	req := &vtgatepb.StreamExecuteRequest{
   163  		CallerId: callerid.EffectiveCallerIDFromContext(ctx),
   164  		Query: &querypb.BoundQuery{
   165  			Sql:           query,
   166  			BindVariables: bindVars,
   167  		},
   168  		Session: session,
   169  	}
   170  	stream, err := conn.c.StreamExecute(ctx, req)
   171  	if err != nil {
   172  		return nil, vterrors.FromGRPC(err)
   173  	}
   174  	return &streamExecuteAdapter{
   175  		recv: func() (*querypb.QueryResult, error) {
   176  			ser, err := stream.Recv()
   177  			if err != nil {
   178  				return nil, err
   179  			}
   180  			return ser.Result, nil
   181  		},
   182  	}, nil
   183  }
   184  
   185  func (conn *vtgateConn) Prepare(ctx context.Context, session *vtgatepb.Session, query string, bindVars map[string]*querypb.BindVariable) (*vtgatepb.Session, []*querypb.Field, error) {
   186  	request := &vtgatepb.PrepareRequest{
   187  		CallerId: callerid.EffectiveCallerIDFromContext(ctx),
   188  		Session:  session,
   189  		Query: &querypb.BoundQuery{
   190  			Sql:           query,
   191  			BindVariables: bindVars,
   192  		},
   193  	}
   194  	response, err := conn.c.Prepare(ctx, request)
   195  	if err != nil {
   196  		return session, nil, vterrors.FromGRPC(err)
   197  	}
   198  	if response.Error != nil {
   199  		return response.Session, nil, vterrors.FromVTRPC(response.Error)
   200  	}
   201  	return response.Session, response.Fields, nil
   202  }
   203  
   204  func (conn *vtgateConn) CloseSession(ctx context.Context, session *vtgatepb.Session) error {
   205  	request := &vtgatepb.CloseSessionRequest{
   206  		CallerId: callerid.EffectiveCallerIDFromContext(ctx),
   207  		Session:  session,
   208  	}
   209  	response, err := conn.c.CloseSession(ctx, request)
   210  	if err != nil {
   211  		return vterrors.FromGRPC(err)
   212  	}
   213  	if response.Error != nil {
   214  		return vterrors.FromVTRPC(response.Error)
   215  	}
   216  	return nil
   217  }
   218  
   219  func (conn *vtgateConn) ResolveTransaction(ctx context.Context, dtid string) error {
   220  	request := &vtgatepb.ResolveTransactionRequest{
   221  		CallerId: callerid.EffectiveCallerIDFromContext(ctx),
   222  		Dtid:     dtid,
   223  	}
   224  	_, err := conn.c.ResolveTransaction(ctx, request)
   225  	return vterrors.FromGRPC(err)
   226  }
   227  
   228  type vstreamAdapter struct {
   229  	stream vtgateservicepb.Vitess_VStreamClient
   230  }
   231  
   232  func (a *vstreamAdapter) Recv() ([]*binlogdatapb.VEvent, error) {
   233  	r, err := a.stream.Recv()
   234  	if err != nil {
   235  		return nil, vterrors.FromGRPC(err)
   236  	}
   237  	return r.Events, nil
   238  }
   239  
   240  func (conn *vtgateConn) VStream(ctx context.Context, tabletType topodatapb.TabletType, vgtid *binlogdatapb.VGtid,
   241  	filter *binlogdatapb.Filter, flags *vtgatepb.VStreamFlags) (vtgateconn.VStreamReader, error) {
   242  
   243  	req := &vtgatepb.VStreamRequest{
   244  		CallerId:   callerid.EffectiveCallerIDFromContext(ctx),
   245  		TabletType: tabletType,
   246  		Vgtid:      vgtid,
   247  		Filter:     filter,
   248  		Flags:      flags,
   249  	}
   250  	stream, err := conn.c.VStream(ctx, req)
   251  	if err != nil {
   252  		return nil, vterrors.FromGRPC(err)
   253  	}
   254  	return &vstreamAdapter{
   255  		stream: stream,
   256  	}, nil
   257  }
   258  
   259  func (conn *vtgateConn) Close() {
   260  	conn.cc.Close()
   261  }
   262  
   263  // Make sure vtgateConn implements vtgateconn.Impl
   264  var _ vtgateconn.Impl = (*vtgateConn)(nil)