google.golang.org/grpc@v1.62.1/credentials/alts/internal/testutil/testutil.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
// Package testutil includes useful test utilities for the handshaker.
    20  package testutil
    21  
    22  import (
    23  	"bytes"
    24  	"encoding/binary"
    25  	"fmt"
    26  	"io"
    27  	"net"
    28  	"sync"
    29  	"time"
    30  
    31  	"google.golang.org/grpc/codes"
    32  	"google.golang.org/grpc/credentials/alts/internal/conn"
    33  	altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    34  	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
    35  )
    36  
    37  // Stats is used to collect statistics about concurrent handshake calls.
    38  type Stats struct {
    39  	mu                 sync.Mutex
    40  	calls              int
    41  	MaxConcurrentCalls int
    42  }
    43  
    44  // Update updates the statistics by adding one call.
    45  func (s *Stats) Update() func() {
    46  	s.mu.Lock()
    47  	s.calls++
    48  	if s.calls > s.MaxConcurrentCalls {
    49  		s.MaxConcurrentCalls = s.calls
    50  	}
    51  	s.mu.Unlock()
    52  
    53  	return func() {
    54  		s.mu.Lock()
    55  		s.calls--
    56  		s.mu.Unlock()
    57  	}
    58  }
    59  
    60  // Reset resets the statistics.
    61  func (s *Stats) Reset() {
    62  	s.mu.Lock()
    63  	defer s.mu.Unlock()
    64  	s.calls = 0
    65  	s.MaxConcurrentCalls = 0
    66  }
    67  
    68  // testConn mimics a net.Conn to the peer.
    69  type testConn struct {
    70  	net.Conn
    71  	in          *bytes.Buffer
    72  	out         *bytes.Buffer
    73  	readLatency time.Duration
    74  }
    75  
    76  // NewTestConn creates a new instance of testConn object.
    77  func NewTestConn(in *bytes.Buffer, out *bytes.Buffer) net.Conn {
    78  	return &testConn{
    79  		in:          in,
    80  		out:         out,
    81  		readLatency: time.Duration(0),
    82  	}
    83  }
    84  
    85  // NewTestConnWithReadLatency creates a new instance of testConn object that
    86  // pauses for readLatency before any call to Read() returns.
    87  func NewTestConnWithReadLatency(in *bytes.Buffer, out *bytes.Buffer, readLatency time.Duration) net.Conn {
    88  	return &testConn{
    89  		in:          in,
    90  		out:         out,
    91  		readLatency: readLatency,
    92  	}
    93  }
    94  
    95  // Read reads from the in buffer.
    96  func (c *testConn) Read(b []byte) (n int, err error) {
    97  	time.Sleep(c.readLatency)
    98  	return c.in.Read(b)
    99  }
   100  
   101  // Write writes to the out buffer.
   102  func (c *testConn) Write(b []byte) (n int, err error) {
   103  	return c.out.Write(b)
   104  }
   105  
   106  // Close closes the testConn object.
   107  func (c *testConn) Close() error {
   108  	return nil
   109  }
   110  
   111  // unresponsiveTestConn mimics a net.Conn for an unresponsive peer. It is used
   112  // for testing the PeerNotResponding case.
   113  type unresponsiveTestConn struct {
   114  	net.Conn
   115  }
   116  
   117  // NewUnresponsiveTestConn creates a new instance of unresponsiveTestConn object.
   118  func NewUnresponsiveTestConn() net.Conn {
   119  	return &unresponsiveTestConn{}
   120  }
   121  
   122  // Read reads from the in buffer.
   123  func (c *unresponsiveTestConn) Read(b []byte) (n int, err error) {
   124  	return 0, io.EOF
   125  }
   126  
   127  // Write writes to the out buffer.
   128  func (c *unresponsiveTestConn) Write(b []byte) (n int, err error) {
   129  	return 0, nil
   130  }
   131  
   132  // Close closes the TestConn object.
   133  func (c *unresponsiveTestConn) Close() error {
   134  	return nil
   135  }
   136  
   137  // MakeFrame creates a handshake frame.
   138  func MakeFrame(pl string) []byte {
   139  	f := make([]byte, len(pl)+conn.MsgLenFieldSize)
   140  	binary.LittleEndian.PutUint32(f, uint32(len(pl)))
   141  	copy(f[conn.MsgLenFieldSize:], []byte(pl))
   142  	return f
   143  }
   144  
   145  // FakeHandshaker is a fake implementation of the ALTS handshaker service.
   146  type FakeHandshaker struct {
   147  	altsgrpc.HandshakerServiceServer
   148  }
   149  
   150  // DoHandshake performs a fake ALTS handshake.
   151  func (h *FakeHandshaker) DoHandshake(stream altsgrpc.HandshakerService_DoHandshakeServer) error {
   152  	var isAssistingClient bool
   153  	var handshakeFramesReceivedSoFar []byte
   154  	for {
   155  		req, err := stream.Recv()
   156  		if err != nil {
   157  			if err == io.EOF {
   158  				return nil
   159  			}
   160  			return fmt.Errorf("stream recv failure: %v", err)
   161  		}
   162  		var resp *altspb.HandshakerResp
   163  		switch req := req.ReqOneof.(type) {
   164  		case *altspb.HandshakerReq_ClientStart:
   165  			isAssistingClient = true
   166  			resp, err = h.processStartClient(req.ClientStart)
   167  			if err != nil {
   168  				return fmt.Errorf("processStartClient failure: %v", err)
   169  			}
   170  		case *altspb.HandshakerReq_ServerStart:
   171  			// If we have received the full ClientInit, send the ServerInit and
   172  			// ServerFinished. Otherwise, wait for more bytes to arrive from the client.
   173  			isAssistingClient = false
   174  			handshakeFramesReceivedSoFar = append(handshakeFramesReceivedSoFar, req.ServerStart.InBytes...)
   175  			sendHandshakeFrame := bytes.Equal(handshakeFramesReceivedSoFar, []byte("ClientInit"))
   176  			resp, err = h.processServerStart(req.ServerStart, sendHandshakeFrame)
   177  			if err != nil {
   178  				return fmt.Errorf("processServerStart failure: %v", err)
   179  			}
   180  		case *altspb.HandshakerReq_Next:
   181  			// If we have received all handshake frames, send the handshake result.
   182  			// Otherwise, wait for more bytes to arrive from the peer.
   183  			oldHandshakesBytes := len(handshakeFramesReceivedSoFar)
   184  			handshakeFramesReceivedSoFar = append(handshakeFramesReceivedSoFar, req.Next.InBytes...)
   185  			isHandshakeComplete := false
   186  			if isAssistingClient {
   187  				isHandshakeComplete = bytes.HasPrefix(handshakeFramesReceivedSoFar, []byte("ServerInitServerFinished"))
   188  			} else {
   189  				isHandshakeComplete = bytes.HasPrefix(handshakeFramesReceivedSoFar, []byte("ClientInitClientFinished"))
   190  			}
   191  			if !isHandshakeComplete {
   192  				resp = &altspb.HandshakerResp{
   193  					BytesConsumed: uint32(len(handshakeFramesReceivedSoFar) - oldHandshakesBytes),
   194  					Status: &altspb.HandshakerStatus{
   195  						Code: uint32(codes.OK),
   196  					},
   197  				}
   198  				break
   199  			}
   200  			resp, err = h.getHandshakeResult(isAssistingClient)
   201  			if err != nil {
   202  				return fmt.Errorf("getHandshakeResult failure: %v", err)
   203  			}
   204  		default:
   205  			return fmt.Errorf("handshake request has unexpected type: %v", req)
   206  		}
   207  
   208  		if err = stream.Send(resp); err != nil {
   209  			return fmt.Errorf("stream send failure: %v", err)
   210  		}
   211  	}
   212  }
   213  
   214  func (h *FakeHandshaker) processStartClient(req *altspb.StartClientHandshakeReq) (*altspb.HandshakerResp, error) {
   215  	if req.HandshakeSecurityProtocol != altspb.HandshakeProtocol_ALTS {
   216  		return nil, fmt.Errorf("unexpected handshake security protocol: %v", req.HandshakeSecurityProtocol)
   217  	}
   218  	if len(req.ApplicationProtocols) != 1 || req.ApplicationProtocols[0] != "grpc" {
   219  		return nil, fmt.Errorf("unexpected application protocols: %v", req.ApplicationProtocols)
   220  	}
   221  	if len(req.RecordProtocols) != 1 || req.RecordProtocols[0] != "ALTSRP_GCM_AES128_REKEY" {
   222  		return nil, fmt.Errorf("unexpected record protocols: %v", req.RecordProtocols)
   223  	}
   224  	return &altspb.HandshakerResp{
   225  		OutFrames:     []byte("ClientInit"),
   226  		BytesConsumed: 0,
   227  		Status: &altspb.HandshakerStatus{
   228  			Code: uint32(codes.OK),
   229  		},
   230  	}, nil
   231  }
   232  
   233  func (h *FakeHandshaker) processServerStart(req *altspb.StartServerHandshakeReq, sendHandshakeFrame bool) (*altspb.HandshakerResp, error) {
   234  	if len(req.ApplicationProtocols) != 1 || req.ApplicationProtocols[0] != "grpc" {
   235  		return nil, fmt.Errorf("unexpected application protocols: %v", req.ApplicationProtocols)
   236  	}
   237  	parameters, ok := req.GetHandshakeParameters()[int32(altspb.HandshakeProtocol_ALTS)]
   238  	if !ok {
   239  		return nil, fmt.Errorf("missing ALTS handshake parameters")
   240  	}
   241  	if len(parameters.RecordProtocols) != 1 || parameters.RecordProtocols[0] != "ALTSRP_GCM_AES128_REKEY" {
   242  		return nil, fmt.Errorf("unexpected record protocols: %v", parameters.RecordProtocols)
   243  	}
   244  	if sendHandshakeFrame {
   245  		return &altspb.HandshakerResp{
   246  			OutFrames:     []byte("ServerInitServerFinished"),
   247  			BytesConsumed: uint32(len(req.InBytes)),
   248  			Status: &altspb.HandshakerStatus{
   249  				Code: uint32(codes.OK),
   250  			},
   251  		}, nil
   252  	}
   253  	return &altspb.HandshakerResp{
   254  		OutFrames:     []byte("ServerInitServerFinished"),
   255  		BytesConsumed: 10,
   256  		Status: &altspb.HandshakerStatus{
   257  			Code: uint32(codes.OK),
   258  		},
   259  	}, nil
   260  }
   261  
   262  func (h *FakeHandshaker) getHandshakeResult(isAssistingClient bool) (*altspb.HandshakerResp, error) {
   263  	if isAssistingClient {
   264  		return &altspb.HandshakerResp{
   265  			OutFrames:     []byte("ClientFinished"),
   266  			BytesConsumed: 24,
   267  			Result: &altspb.HandshakerResult{
   268  				ApplicationProtocol: "grpc",
   269  				RecordProtocol:      "ALTSRP_GCM_AES128_REKEY",
   270  				KeyData:             []byte("negotiated-key-data-for-altsrp-gcm-aes128-rekey"),
   271  				PeerIdentity: &altspb.Identity{
   272  					IdentityOneof: &altspb.Identity_ServiceAccount{
   273  						ServiceAccount: "server@bar.com",
   274  					},
   275  				},
   276  				PeerRpcVersions: &altspb.RpcProtocolVersions{
   277  					MaxRpcVersion: &altspb.RpcProtocolVersions_Version{
   278  						Minor: 1,
   279  						Major: 2,
   280  					},
   281  					MinRpcVersion: &altspb.RpcProtocolVersions_Version{
   282  						Minor: 1,
   283  						Major: 2,
   284  					},
   285  				},
   286  			},
   287  			Status: &altspb.HandshakerStatus{
   288  				Code: uint32(codes.OK),
   289  			},
   290  		}, nil
   291  	}
   292  	return &altspb.HandshakerResp{
   293  		BytesConsumed: 14,
   294  		Result: &altspb.HandshakerResult{
   295  			ApplicationProtocol: "grpc",
   296  			RecordProtocol:      "ALTSRP_GCM_AES128_REKEY",
   297  			KeyData:             []byte("negotiated-key-data-for-altsrp-gcm-aes128-rekey"),
   298  			PeerIdentity: &altspb.Identity{
   299  				IdentityOneof: &altspb.Identity_ServiceAccount{
   300  					ServiceAccount: "client@baz.com",
   301  				},
   302  			},
   303  			PeerRpcVersions: &altspb.RpcProtocolVersions{
   304  				MaxRpcVersion: &altspb.RpcProtocolVersions_Version{
   305  					Minor: 1,
   306  					Major: 2,
   307  				},
   308  				MinRpcVersion: &altspb.RpcProtocolVersions_Version{
   309  					Minor: 1,
   310  					Major: 2,
   311  				},
   312  			},
   313  		},
   314  		Status: &altspb.HandshakerStatus{
   315  			Code: uint32(codes.OK),
   316  		},
   317  	}, nil
   318  }