vitess.io/vitess@v0.16.2/go/vt/vtgate/grpcvtgateservice/server.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package grpcvtgateservice provides the gRPC glue for vtgate
    18  package grpcvtgateservice
    19  
    20  import (
    21  	"context"
    22  
    23  	"github.com/spf13/pflag"
    24  	"google.golang.org/grpc"
    25  	"google.golang.org/grpc/credentials"
    26  	"google.golang.org/grpc/peer"
    27  
    28  	"vitess.io/vitess/go/sqltypes"
    29  	"vitess.io/vitess/go/vt/callerid"
    30  	"vitess.io/vitess/go/vt/callinfo"
    31  	"vitess.io/vitess/go/vt/servenv"
    32  	"vitess.io/vitess/go/vt/vterrors"
    33  	"vitess.io/vitess/go/vt/vtgate"
    34  	"vitess.io/vitess/go/vt/vtgate/vtgateservice"
    35  
    36  	binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
    37  	querypb "vitess.io/vitess/go/vt/proto/query"
    38  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    39  	vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
    40  	vtgateservicepb "vitess.io/vitess/go/vt/proto/vtgateservice"
    41  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    42  )
    43  
    44  const (
    45  	unsecureClient = "unsecure_grpc_client"
    46  )
    47  
    48  var (
    49  	useEffective                    bool
    50  	useEffectiveGroups              bool
    51  	useStaticAuthenticationIdentity bool
    52  )
    53  
    54  func registerFlags(fs *pflag.FlagSet) {
    55  	fs.BoolVar(&useEffective, "grpc_use_effective_callerid", false, "If set, and SSL is not used, will set the immediate caller id from the effective caller id's principal.")
    56  	fs.BoolVar(&useEffectiveGroups, "grpc-use-effective-groups", false, "If set, and SSL is not used, will set the immediate caller's security groups from the effective caller id's groups.")
    57  	fs.BoolVar(&useStaticAuthenticationIdentity, "grpc-use-static-authentication-callerid", false, "If set, will set the immediate caller id to the username authenticated by the static auth plugin.")
    58  }
    59  
    60  func init() {
    61  	servenv.OnParseFor("vtgate", registerFlags)
    62  	servenv.OnParseFor("vtcombo", registerFlags)
    63  }
    64  
    65  // VTGate is the public structure that is exported via gRPC
    66  type VTGate struct {
    67  	vtgateservicepb.UnimplementedVitessServer
    68  	server vtgateservice.VTGateService
    69  }
    70  
    71  // immediateCallerIDFromCert tries to extract the common name as well as the (domain) subject
    72  // alternative names of the certificate that was used to connect to vtgate.
    73  // If it fails for any reason, it will return "".
    74  // That immediate caller id is then inserted into a Context,
    75  // and will be used when talking to vttablet.
    76  // vttablet in turn can use table ACLs to validate access is authorized.
    77  func immediateCallerIDFromCert(ctx context.Context) (string, []string) {
    78  	p, ok := peer.FromContext(ctx)
    79  	if !ok {
    80  		return "", nil
    81  	}
    82  	if p.AuthInfo == nil {
    83  		return "", nil
    84  	}
    85  	tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo)
    86  	if !ok {
    87  		return "", nil
    88  	}
    89  	if len(tlsInfo.State.VerifiedChains) < 1 {
    90  		return "", nil
    91  	}
    92  	if len(tlsInfo.State.VerifiedChains[0]) < 1 {
    93  		return "", nil
    94  	}
    95  	cert := tlsInfo.State.VerifiedChains[0][0]
    96  	return cert.Subject.CommonName, cert.DNSNames
    97  }
    98  
    99  // immediateCallerIdFromStaticAuthentication extracts the username of the current
   100  // static authentication context and returns that to the caller.
   101  func immediateCallerIdFromStaticAuthentication(ctx context.Context) (string, []string) {
   102  	if immediate := servenv.StaticAuthUsernameFromContext(ctx); immediate != "" {
   103  		return immediate, nil
   104  	}
   105  
   106  	return "", nil
   107  }
   108  
   109  // withCallerIDContext creates a context that extracts what we need
   110  // from the incoming call and can be forwarded for use when talking to vttablet.
   111  func withCallerIDContext(ctx context.Context, effectiveCallerID *vtrpcpb.CallerID) context.Context {
   112  	// The client cert common name (if using mTLS)
   113  	immediate, securityGroups := immediateCallerIDFromCert(ctx)
   114  
   115  	// The effective caller id (if --grpc_use_effective_callerid=true)
   116  	if immediate == "" && useEffective && effectiveCallerID != nil {
   117  		immediate = effectiveCallerID.Principal
   118  		if useEffectiveGroups && len(effectiveCallerID.Groups) > 0 {
   119  			securityGroups = effectiveCallerID.Groups
   120  		}
   121  	}
   122  
   123  	// The static auth username (if --grpc-use-static-authentication-callerid=true)
   124  	if immediate == "" && useStaticAuthenticationIdentity {
   125  		immediate, securityGroups = immediateCallerIdFromStaticAuthentication(ctx)
   126  	}
   127  
   128  	if immediate == "" {
   129  		immediate = unsecureClient
   130  	}
   131  	return callerid.NewContext(callinfo.GRPCCallInfo(ctx),
   132  		effectiveCallerID,
   133  		&querypb.VTGateCallerID{Username: immediate, Groups: securityGroups})
   134  }
   135  
   136  // Execute is the RPC version of vtgateservice.VTGateService method
   137  func (vtg *VTGate) Execute(ctx context.Context, request *vtgatepb.ExecuteRequest) (response *vtgatepb.ExecuteResponse, err error) {
   138  	defer vtg.server.HandlePanic(&err)
   139  	ctx = withCallerIDContext(ctx, request.CallerId)
   140  
   141  	// Handle backward compatibility.
   142  	session := request.Session
   143  	if session == nil {
   144  		session = &vtgatepb.Session{Autocommit: true}
   145  	}
   146  	session, result, err := vtg.server.Execute(ctx, session, request.Query.Sql, request.Query.BindVariables)
   147  	return &vtgatepb.ExecuteResponse{
   148  		Result:  sqltypes.ResultToProto3(result),
   149  		Session: session,
   150  		Error:   vterrors.ToVTRPC(err),
   151  	}, nil
   152  }
   153  
   154  // ExecuteBatch is the RPC version of vtgateservice.VTGateService method
   155  func (vtg *VTGate) ExecuteBatch(ctx context.Context, request *vtgatepb.ExecuteBatchRequest) (response *vtgatepb.ExecuteBatchResponse, err error) {
   156  	defer vtg.server.HandlePanic(&err)
   157  	ctx = withCallerIDContext(ctx, request.CallerId)
   158  	sqlQueries := make([]string, len(request.Queries))
   159  	bindVars := make([]map[string]*querypb.BindVariable, len(request.Queries))
   160  	for queryNum, query := range request.Queries {
   161  		sqlQueries[queryNum] = query.Sql
   162  		bindVars[queryNum] = query.BindVariables
   163  	}
   164  	// Handle backward compatibility.
   165  	session := request.Session
   166  	if session == nil {
   167  		session = &vtgatepb.Session{Autocommit: true}
   168  	}
   169  	session, results, err := vtg.server.ExecuteBatch(ctx, session, sqlQueries, bindVars)
   170  	return &vtgatepb.ExecuteBatchResponse{
   171  		Results: sqltypes.QueryResponsesToProto3(results),
   172  		Session: session,
   173  		Error:   vterrors.ToVTRPC(err),
   174  	}, nil
   175  }
   176  
   177  // StreamExecute is the RPC version of vtgateservice.VTGateService method
   178  func (vtg *VTGate) StreamExecute(request *vtgatepb.StreamExecuteRequest, stream vtgateservicepb.Vitess_StreamExecuteServer) (err error) {
   179  	defer vtg.server.HandlePanic(&err)
   180  	ctx := withCallerIDContext(stream.Context(), request.CallerId)
   181  
   182  	// Handle backward compatibility.
   183  	session := request.Session
   184  	if session == nil {
   185  		session = &vtgatepb.Session{Autocommit: true}
   186  	}
   187  
   188  	vtgErr := vtg.server.StreamExecute(ctx, session, request.Query.Sql, request.Query.BindVariables, func(value *sqltypes.Result) error {
   189  		// Send is not safe to call concurrently, but vtgate
   190  		// guarantees that it's not.
   191  		return stream.Send(&vtgatepb.StreamExecuteResponse{
   192  			Result: sqltypes.ResultToProto3(value),
   193  		})
   194  	})
   195  	return vterrors.ToGRPC(vtgErr)
   196  }
   197  
   198  // Prepare is the RPC version of vtgateservice.VTGateService method
   199  func (vtg *VTGate) Prepare(ctx context.Context, request *vtgatepb.PrepareRequest) (response *vtgatepb.PrepareResponse, err error) {
   200  	defer vtg.server.HandlePanic(&err)
   201  	ctx = withCallerIDContext(ctx, request.CallerId)
   202  
   203  	session := request.Session
   204  	if session == nil {
   205  		session = &vtgatepb.Session{Autocommit: true}
   206  	}
   207  
   208  	session, fields, err := vtg.server.Prepare(ctx, session, request.Query.Sql, request.Query.BindVariables)
   209  	return &vtgatepb.PrepareResponse{
   210  		Fields:  fields,
   211  		Session: session,
   212  		Error:   vterrors.ToVTRPC(err),
   213  	}, nil
   214  }
   215  
   216  // CloseSession is the RPC version of vtgateservice.VTGateService method
   217  func (vtg *VTGate) CloseSession(ctx context.Context, request *vtgatepb.CloseSessionRequest) (response *vtgatepb.CloseSessionResponse, err error) {
   218  	defer vtg.server.HandlePanic(&err)
   219  	ctx = withCallerIDContext(ctx, request.CallerId)
   220  
   221  	session := request.Session
   222  	if session == nil {
   223  		session = &vtgatepb.Session{Autocommit: true}
   224  	}
   225  	err = vtg.server.CloseSession(ctx, session)
   226  	return &vtgatepb.CloseSessionResponse{
   227  		Error: vterrors.ToVTRPC(err),
   228  	}, nil
   229  }
   230  
   231  // ResolveTransaction is the RPC version of vtgateservice.VTGateService method
   232  func (vtg *VTGate) ResolveTransaction(ctx context.Context, request *vtgatepb.ResolveTransactionRequest) (response *vtgatepb.ResolveTransactionResponse, err error) {
   233  	defer vtg.server.HandlePanic(&err)
   234  	ctx = withCallerIDContext(ctx, request.CallerId)
   235  	vtgErr := vtg.server.ResolveTransaction(ctx, request.Dtid)
   236  	response = &vtgatepb.ResolveTransactionResponse{}
   237  	if vtgErr == nil {
   238  		return response, nil
   239  	}
   240  	return nil, vterrors.ToGRPC(vtgErr)
   241  }
   242  
   243  // VStream is the RPC version of vtgateservice.VTGateService method
   244  func (vtg *VTGate) VStream(request *vtgatepb.VStreamRequest, stream vtgateservicepb.Vitess_VStreamServer) (err error) {
   245  	defer vtg.server.HandlePanic(&err)
   246  	ctx := withCallerIDContext(stream.Context(), request.CallerId)
   247  
   248  	// For backward compatibility.
   249  	// The mysql query equivalent has logic to use topodatapb.TabletType_PRIMARY if tablet_type is not set.
   250  	tabletType := request.TabletType
   251  	if tabletType == topodatapb.TabletType_UNKNOWN {
   252  		tabletType = topodatapb.TabletType_PRIMARY
   253  	}
   254  	vtgErr := vtg.server.VStream(ctx,
   255  		tabletType,
   256  		request.Vgtid,
   257  		request.Filter,
   258  		request.Flags,
   259  		func(events []*binlogdatapb.VEvent) error {
   260  			return stream.Send(&vtgatepb.VStreamResponse{
   261  				Events: events,
   262  			})
   263  		})
   264  	return vterrors.ToGRPC(vtgErr)
   265  }
   266  
   267  func init() {
   268  	vtgate.RegisterVTGates = append(vtgate.RegisterVTGates, func(vtGate vtgateservice.VTGateService) {
   269  		if servenv.GRPCCheckServiceMap("vtgateservice") {
   270  			vtgateservicepb.RegisterVitessServer(servenv.GRPCServer, &VTGate{server: vtGate})
   271  		}
   272  	})
   273  }
   274  
   275  // RegisterForTest registers the gRPC implementation on the gRPC
   276  // server.  Useful for unit tests only, for real use, the init()
   277  // function does the registration.
   278  func RegisterForTest(s *grpc.Server, service vtgateservice.VTGateService) {
   279  	vtgateservicepb.RegisterVitessServer(s, &VTGate{server: service})
   280  }