google.golang.org/grpc@v1.74.2/credentials/alts/internal/testutil/testutil.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
// Package testutil includes useful test utilities for the ALTS handshaker.
    20  package testutil
    21  
    22  import (
    23  	"bytes"
    24  	"encoding/binary"
    25  	"fmt"
    26  	"io"
    27  	"net"
    28  	"sync"
    29  	"time"
    30  
    31  	"google.golang.org/grpc/codes"
    32  	"google.golang.org/grpc/credentials/alts/internal/conn"
    33  	altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    34  	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    35  )
    36  
    37  // Stats is used to collect statistics about concurrent handshake calls.
    38  type Stats struct {
    39  	mu                 sync.Mutex
    40  	calls              int
    41  	MaxConcurrentCalls int
    42  }
    43  
    44  // Update updates the statistics by adding one call.
    45  func (s *Stats) Update() func() {
    46  	s.mu.Lock()
    47  	s.calls++
    48  	if s.calls > s.MaxConcurrentCalls {
    49  		s.MaxConcurrentCalls = s.calls
    50  	}
    51  	s.mu.Unlock()
    52  
    53  	return func() {
    54  		s.mu.Lock()
    55  		s.calls--
    56  		s.mu.Unlock()
    57  	}
    58  }
    59  
    60  // Reset resets the statistics.
    61  func (s *Stats) Reset() {
    62  	s.mu.Lock()
    63  	defer s.mu.Unlock()
    64  	s.calls = 0
    65  	s.MaxConcurrentCalls = 0
    66  }
    67  
    68  // testConn mimics a net.Conn to the peer.
    69  type testConn struct {
    70  	net.Conn
    71  	in          *bytes.Buffer
    72  	out         *bytes.Buffer
    73  	readLatency time.Duration
    74  }
    75  
    76  // NewTestConn creates a new instance of testConn object.
    77  func NewTestConn(in *bytes.Buffer, out *bytes.Buffer) net.Conn {
    78  	return &testConn{
    79  		in:          in,
    80  		out:         out,
    81  		readLatency: time.Duration(0),
    82  	}
    83  }
    84  
    85  // NewTestConnWithReadLatency creates a new instance of testConn object that
    86  // pauses for readLatency before any call to Read() returns.
    87  func NewTestConnWithReadLatency(in *bytes.Buffer, out *bytes.Buffer, readLatency time.Duration) net.Conn {
    88  	return &testConn{
    89  		in:          in,
    90  		out:         out,
    91  		readLatency: readLatency,
    92  	}
    93  }
    94  
    95  // Read reads from the in buffer.
    96  func (c *testConn) Read(b []byte) (n int, err error) {
    97  	time.Sleep(c.readLatency)
    98  	return c.in.Read(b)
    99  }
   100  
   101  // Write writes to the out buffer.
   102  func (c *testConn) Write(b []byte) (n int, err error) {
   103  	return c.out.Write(b)
   104  }
   105  
   106  // Close closes the testConn object.
   107  func (c *testConn) Close() error {
   108  	return nil
   109  }
   110  
   111  // unresponsiveTestConn mimics a net.Conn for an unresponsive peer. It is used
   112  // for testing the PeerNotResponding case.
   113  type unresponsiveTestConn struct {
   114  	net.Conn
   115  }
   116  
   117  // NewUnresponsiveTestConn creates a new instance of unresponsiveTestConn object.
   118  func NewUnresponsiveTestConn() net.Conn {
   119  	return &unresponsiveTestConn{}
   120  }
   121  
   122  // Read reads from the in buffer.
   123  func (c *unresponsiveTestConn) Read([]byte) (n int, err error) {
   124  	return 0, io.EOF
   125  }
   126  
   127  // Write writes to the out buffer.
   128  func (c *unresponsiveTestConn) Write([]byte) (n int, err error) {
   129  	return 0, nil
   130  }
   131  
   132  // Close closes the TestConn object.
   133  func (c *unresponsiveTestConn) Close() error {
   134  	return nil
   135  }
   136  
   137  // MakeFrame creates a handshake frame.
   138  func MakeFrame(pl string) []byte {
   139  	f := make([]byte, len(pl)+conn.MsgLenFieldSize)
   140  	binary.LittleEndian.PutUint32(f, uint32(len(pl)))
   141  	copy(f[conn.MsgLenFieldSize:], []byte(pl))
   142  	return f
   143  }
   144  
   145  // FakeHandshaker is a fake implementation of the ALTS handshaker service.
   146  type FakeHandshaker struct {
   147  	altsgrpc.HandshakerServiceServer
   148  	// ExpectedBoundAccessToken is the expected bound access token in the ClientStart request.
   149  	ExpectedBoundAccessToken string
   150  }
   151  
   152  // DoHandshake performs a fake ALTS handshake.
   153  func (h *FakeHandshaker) DoHandshake(stream altsgrpc.HandshakerService_DoHandshakeServer) error {
   154  	var isAssistingClient bool
   155  	var handshakeFramesReceivedSoFar []byte
   156  	for {
   157  		req, err := stream.Recv()
   158  		if err != nil {
   159  			if err == io.EOF {
   160  				return nil
   161  			}
   162  			return fmt.Errorf("stream recv failure: %v", err)
   163  		}
   164  		var resp *altspb.HandshakerResp
   165  		switch req := req.ReqOneof.(type) {
   166  		case *altspb.HandshakerReq_ClientStart:
   167  			isAssistingClient = true
   168  			resp, err = h.processStartClient(req.ClientStart)
   169  			if err != nil {
   170  				return fmt.Errorf("processStartClient failure: %v", err)
   171  			}
   172  		case *altspb.HandshakerReq_ServerStart:
   173  			// If we have received the full ClientInit, send the ServerInit and
   174  			// ServerFinished. Otherwise, wait for more bytes to arrive from the client.
   175  			isAssistingClient = false
   176  			handshakeFramesReceivedSoFar = append(handshakeFramesReceivedSoFar, req.ServerStart.InBytes...)
   177  			sendHandshakeFrame := bytes.Equal(handshakeFramesReceivedSoFar, []byte("ClientInit"))
   178  			resp, err = h.processServerStart(req.ServerStart, sendHandshakeFrame)
   179  			if err != nil {
   180  				return fmt.Errorf("processServerStart failure: %v", err)
   181  			}
   182  		case *altspb.HandshakerReq_Next:
   183  			// If we have received all handshake frames, send the handshake result.
   184  			// Otherwise, wait for more bytes to arrive from the peer.
   185  			oldHandshakesBytes := len(handshakeFramesReceivedSoFar)
   186  			handshakeFramesReceivedSoFar = append(handshakeFramesReceivedSoFar, req.Next.InBytes...)
   187  			isHandshakeComplete := false
   188  			if isAssistingClient {
   189  				isHandshakeComplete = bytes.HasPrefix(handshakeFramesReceivedSoFar, []byte("ServerInitServerFinished"))
   190  			} else {
   191  				isHandshakeComplete = bytes.HasPrefix(handshakeFramesReceivedSoFar, []byte("ClientInitClientFinished"))
   192  			}
   193  			if !isHandshakeComplete {
   194  				resp = &altspb.HandshakerResp{
   195  					BytesConsumed: uint32(len(handshakeFramesReceivedSoFar) - oldHandshakesBytes),
   196  					Status: &altspb.HandshakerStatus{
   197  						Code: uint32(codes.OK),
   198  					},
   199  				}
   200  				break
   201  			}
   202  			resp, err = h.getHandshakeResult(isAssistingClient)
   203  			if err != nil {
   204  				return fmt.Errorf("getHandshakeResult failure: %v", err)
   205  			}
   206  		default:
   207  			return fmt.Errorf("handshake request has unexpected type: %v", req)
   208  		}
   209  
   210  		if err = stream.Send(resp); err != nil {
   211  			return fmt.Errorf("stream send failure: %v", err)
   212  		}
   213  	}
   214  }
   215  
   216  func (h *FakeHandshaker) processStartClient(req *altspb.StartClientHandshakeReq) (*altspb.HandshakerResp, error) {
   217  	if req.HandshakeSecurityProtocol != altspb.HandshakeProtocol_ALTS {
   218  		return nil, fmt.Errorf("unexpected handshake security protocol: %v", req.HandshakeSecurityProtocol)
   219  	}
   220  	if len(req.ApplicationProtocols) != 1 || req.ApplicationProtocols[0] != "grpc" {
   221  		return nil, fmt.Errorf("unexpected application protocols: %v", req.ApplicationProtocols)
   222  	}
   223  	if len(req.RecordProtocols) != 1 || req.RecordProtocols[0] != "ALTSRP_GCM_AES128_REKEY" {
   224  		return nil, fmt.Errorf("unexpected record protocols: %v", req.RecordProtocols)
   225  	}
   226  	if h.ExpectedBoundAccessToken != req.GetAccessToken() {
   227  		return nil, fmt.Errorf("unexpected access token: %v", req.GetAccessToken())
   228  	}
   229  	return &altspb.HandshakerResp{
   230  		OutFrames:     []byte("ClientInit"),
   231  		BytesConsumed: 0,
   232  		Status: &altspb.HandshakerStatus{
   233  			Code: uint32(codes.OK),
   234  		},
   235  	}, nil
   236  }
   237  
   238  func (h *FakeHandshaker) processServerStart(req *altspb.StartServerHandshakeReq, sendHandshakeFrame bool) (*altspb.HandshakerResp, error) {
   239  	if len(req.ApplicationProtocols) != 1 || req.ApplicationProtocols[0] != "grpc" {
   240  		return nil, fmt.Errorf("unexpected application protocols: %v", req.ApplicationProtocols)
   241  	}
   242  	parameters, ok := req.GetHandshakeParameters()[int32(altspb.HandshakeProtocol_ALTS)]
   243  	if !ok {
   244  		return nil, fmt.Errorf("missing ALTS handshake parameters")
   245  	}
   246  	if len(parameters.RecordProtocols) != 1 || parameters.RecordProtocols[0] != "ALTSRP_GCM_AES128_REKEY" {
   247  		return nil, fmt.Errorf("unexpected record protocols: %v", parameters.RecordProtocols)
   248  	}
   249  	if sendHandshakeFrame {
   250  		return &altspb.HandshakerResp{
   251  			OutFrames:     []byte("ServerInitServerFinished"),
   252  			BytesConsumed: uint32(len(req.InBytes)),
   253  			Status: &altspb.HandshakerStatus{
   254  				Code: uint32(codes.OK),
   255  			},
   256  		}, nil
   257  	}
   258  	return &altspb.HandshakerResp{
   259  		OutFrames:     []byte("ServerInitServerFinished"),
   260  		BytesConsumed: 10,
   261  		Status: &altspb.HandshakerStatus{
   262  			Code: uint32(codes.OK),
   263  		},
   264  	}, nil
   265  }
   266  
   267  func (h *FakeHandshaker) getHandshakeResult(isAssistingClient bool) (*altspb.HandshakerResp, error) {
   268  	if isAssistingClient {
   269  		return &altspb.HandshakerResp{
   270  			OutFrames:     []byte("ClientFinished"),
   271  			BytesConsumed: 24,
   272  			Result: &altspb.HandshakerResult{
   273  				ApplicationProtocol: "grpc",
   274  				RecordProtocol:      "ALTSRP_GCM_AES128_REKEY",
   275  				KeyData:             []byte("negotiated-key-data-for-altsrp-gcm-aes128-rekey"),
   276  				PeerIdentity: &altspb.Identity{
   277  					IdentityOneof: &altspb.Identity_ServiceAccount{
   278  						ServiceAccount: "server@bar.com",
   279  					},
   280  				},
   281  				PeerRpcVersions: &altspb.RpcProtocolVersions{
   282  					MaxRpcVersion: &altspb.RpcProtocolVersions_Version{
   283  						Minor: 1,
   284  						Major: 2,
   285  					},
   286  					MinRpcVersion: &altspb.RpcProtocolVersions_Version{
   287  						Minor: 1,
   288  						Major: 2,
   289  					},
   290  				},
   291  			},
   292  			Status: &altspb.HandshakerStatus{
   293  				Code: uint32(codes.OK),
   294  			},
   295  		}, nil
   296  	}
   297  	return &altspb.HandshakerResp{
   298  		BytesConsumed: 14,
   299  		Result: &altspb.HandshakerResult{
   300  			ApplicationProtocol: "grpc",
   301  			RecordProtocol:      "ALTSRP_GCM_AES128_REKEY",
   302  			KeyData:             []byte("negotiated-key-data-for-altsrp-gcm-aes128-rekey"),
   303  			PeerIdentity: &altspb.Identity{
   304  				IdentityOneof: &altspb.Identity_ServiceAccount{
   305  					ServiceAccount: "client@baz.com",
   306  				},
   307  			},
   308  			PeerRpcVersions: &altspb.RpcProtocolVersions{
   309  				MaxRpcVersion: &altspb.RpcProtocolVersions_Version{
   310  					Minor: 1,
   311  					Major: 2,
   312  				},
   313  				MinRpcVersion: &altspb.RpcProtocolVersions_Version{
   314  					Minor: 1,
   315  					Major: 2,
   316  				},
   317  			},
   318  		},
   319  		Status: &altspb.HandshakerStatus{
   320  			Code: uint32(codes.OK),
   321  		},
   322  	}, nil
   323  }