google.golang.org/grpc@v1.74.2/credentials/alts/internal/conn/record_test.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package conn
    20  
    21  import (
    22  	"bytes"
    23  	"encoding/binary"
    24  	"fmt"
    25  	"io"
    26  	"math"
    27  	"net"
    28  	"reflect"
    29  	"strings"
    30  	"testing"
    31  
    32  	core "google.golang.org/grpc/credentials/alts/internal"
    33  	"google.golang.org/grpc/internal/grpctest"
    34  )
    35  
    36  type s struct {
    37  	grpctest.Tester
    38  }
    39  
    40  func Test(t *testing.T) {
    41  	grpctest.RunSubTests(t, s{})
    42  }
    43  
    44  const (
    45  	rekeyRecordProtocol = "ALTSRP_GCM_AES128_REKEY"
    46  )
    47  
    48  var (
    49  	recordProtocols = []string{rekeyRecordProtocol}
    50  	altsRecordFuncs = map[string]ALTSRecordFunc{
    51  		// ALTS handshaker protocols.
    52  		rekeyRecordProtocol: func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) {
    53  			return NewAES128GCM(s, keyData)
    54  		},
    55  	}
    56  )
    57  
    58  func init() {
    59  	for protocol, f := range altsRecordFuncs {
    60  		if err := RegisterProtocol(protocol, f); err != nil {
    61  			panic(err)
    62  		}
    63  	}
    64  }
    65  
    66  // testConn mimics a net.Conn to the peer.
    67  type testConn struct {
    68  	net.Conn
    69  	in  *bytes.Buffer
    70  	out *bytes.Buffer
    71  }
    72  
    73  func (c *testConn) Read(b []byte) (n int, err error) {
    74  	return c.in.Read(b)
    75  }
    76  
    77  func (c *testConn) Write(b []byte) (n int, err error) {
    78  	return c.out.Write(b)
    79  }
    80  
    81  func (c *testConn) Close() error {
    82  	return nil
    83  }
    84  
    85  func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, rp string, protected []byte) *conn {
    86  	key := []byte{
    87  		// 16 arbitrary bytes.
    88  		0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49}
    89  	tc := testConn{
    90  		in:  in,
    91  		out: out,
    92  	}
    93  	c, err := NewConn(&tc, side, rp, key, protected)
    94  	if err != nil {
    95  		panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err))
    96  	}
    97  	return c.(*conn)
    98  }
    99  
   100  func newConnPair(rp string, clientProtected []byte, serverProtected []byte) (client, server *conn) {
   101  	clientBuf := new(bytes.Buffer)
   102  	serverBuf := new(bytes.Buffer)
   103  	clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, rp, clientProtected)
   104  	serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, rp, serverProtected)
   105  	return clientConn, serverConn
   106  }
   107  
   108  func testPingPong(t *testing.T, rp string) {
   109  	clientConn, serverConn := newConnPair(rp, nil, nil)
   110  	clientMsg := []byte("Client Message")
   111  	if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
   112  		t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
   113  	}
   114  	rcvClientMsg := make([]byte, len(clientMsg))
   115  	if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil {
   116  		t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg))
   117  	}
   118  	if !reflect.DeepEqual(clientMsg, rcvClientMsg) {
   119  		t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg)
   120  	}
   121  
   122  	serverMsg := []byte("Server Message")
   123  	if n, err := serverConn.Write(serverMsg); n != len(serverMsg) || err != nil {
   124  		t.Fatalf("Server Write() = %v, %v; want %v, <nil>", n, err, len(serverMsg))
   125  	}
   126  	rcvServerMsg := make([]byte, len(serverMsg))
   127  	if n, err := clientConn.Read(rcvServerMsg); n != len(rcvServerMsg) || err != nil {
   128  		t.Fatalf("Client Read() = %v, %v; want %v, <nil>", n, err, len(rcvServerMsg))
   129  	}
   130  	if !reflect.DeepEqual(serverMsg, rcvServerMsg) {
   131  		t.Fatalf("Server Write()/Client Read() = %v, want %v", rcvServerMsg, serverMsg)
   132  	}
   133  }
   134  
   135  func (s) TestPingPong(t *testing.T) {
   136  	for _, rp := range recordProtocols {
   137  		testPingPong(t, rp)
   138  	}
   139  }
   140  
   141  func testSmallReadBuffer(t *testing.T, rp string) {
   142  	clientConn, serverConn := newConnPair(rp, nil, nil)
   143  	msg := []byte("Very Important Message")
   144  	if n, err := clientConn.Write(msg); err != nil {
   145  		t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
   146  	}
   147  	rcvMsg := make([]byte, len(msg))
   148  	n := 2 // Arbitrary index to break rcvMsg in two.
   149  	rcvMsg1 := rcvMsg[:n]
   150  	rcvMsg2 := rcvMsg[n:]
   151  	if n, err := serverConn.Read(rcvMsg1); n != len(rcvMsg1) || err != nil {
   152  		t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg1))
   153  	}
   154  	if n, err := serverConn.Read(rcvMsg2); n != len(rcvMsg2) || err != nil {
   155  		t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg2))
   156  	}
   157  	if !reflect.DeepEqual(msg, rcvMsg) {
   158  		t.Fatalf("Write()/Read() = %v, want %v", rcvMsg, msg)
   159  	}
   160  }
   161  
   162  func (s) TestSmallReadBuffer(t *testing.T) {
   163  	for _, rp := range recordProtocols {
   164  		testSmallReadBuffer(t, rp)
   165  	}
   166  }
   167  
   168  func testLargeMsg(t *testing.T, rp string) {
   169  	clientConn, serverConn := newConnPair(rp, nil, nil)
   170  	// msgLen is such that the length in the framing is larger than the
   171  	// default size of one frame.
   172  	msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
   173  	msg := make([]byte, msgLen)
   174  	if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
   175  		t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
   176  	}
   177  	rcvMsg := make([]byte, len(msg))
   178  	if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
   179  		t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
   180  	}
   181  	if !reflect.DeepEqual(msg, rcvMsg) {
   182  		t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
   183  	}
   184  }
   185  
   186  func (s) TestLargeMsg(t *testing.T) {
   187  	for _, rp := range recordProtocols {
   188  		testLargeMsg(t, rp)
   189  	}
   190  }
   191  
   192  // TestLargeRecord writes a very large ALTS record and verifies that the server
   193  // receives it correctly. The large ALTS record should cause the reader to
   194  // expand it's read buffer to hold the entire record and store the decrypted
   195  // message until the receiver reads all of the bytes.
   196  func (s) TestLargeRecord(t *testing.T) {
   197  	clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
   198  	msg := []byte(strings.Repeat("a", 2*altsReadBufferInitialSize))
   199  	// Increase the size of ALTS records written by the client.
   200  	clientConn.payloadLengthLimit = math.MaxInt32
   201  	if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
   202  		t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
   203  	}
   204  	rcvMsg := make([]byte, len(msg))
   205  	if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
   206  		t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
   207  	}
   208  	if !reflect.DeepEqual(msg, rcvMsg) {
   209  		t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
   210  	}
   211  }
   212  
   213  // BenchmarkLargeMessage measures the performance of ALTS conns for sending and
   214  // receiving a large message.
   215  func BenchmarkLargeMessage(b *testing.B) {
   216  	msgLen := 20 * 1024 * 1024 // 20 MiB
   217  	msg := make([]byte, msgLen)
   218  	rcvMsg := make([]byte, len(msg))
   219  	b.ResetTimer()
   220  	clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
   221  	for range b.N {
   222  		// Write 20 MiB 5 times to transfer a total of 100 MiB.
   223  		for range 5 {
   224  			if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
   225  				b.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
   226  			}
   227  			if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
   228  				b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
   229  			}
   230  		}
   231  	}
   232  }
   233  
   234  func testIncorrectMsgType(t *testing.T, rp string) {
   235  	// framedMsg is an empty ciphertext with correct framing but wrong
   236  	// message type.
   237  	framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize)
   238  	binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], msgTypeFieldSize)
   239  	wrongMsgType := uint32(0x22)
   240  	binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType)
   241  
   242  	in := bytes.NewBuffer(framedMsg)
   243  	c := newTestALTSRecordConn(in, nil, core.ClientSide, rp, nil)
   244  	b := make([]byte, 1)
   245  	if n, err := c.Read(b); n != 0 || err == nil {
   246  		t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType))
   247  	}
   248  }
   249  
   250  func (s) TestIncorrectMsgType(t *testing.T) {
   251  	for _, rp := range recordProtocols {
   252  		testIncorrectMsgType(t, rp)
   253  	}
   254  }
   255  
   256  func testFrameTooLarge(t *testing.T, rp string) {
   257  	buf := new(bytes.Buffer)
   258  	clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, rp, nil)
   259  	serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, rp, nil)
   260  	// payloadLen is such that the length in the framing is larger than
   261  	// allowed in one frame.
   262  	payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
   263  	payload := make([]byte, payloadLen)
   264  	c, err := clientConn.crypto.Encrypt(nil, payload)
   265  	if err != nil {
   266  		t.Fatalf("Error encrypting message: %v", err)
   267  	}
   268  	msgLen := msgTypeFieldSize + len(c)
   269  	framedMsg := make([]byte, MsgLenFieldSize+msgLen)
   270  	binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], uint32(msgTypeFieldSize+len(c)))
   271  	msg := framedMsg[MsgLenFieldSize:]
   272  	binary.LittleEndian.PutUint32(msg[:msgTypeFieldSize], altsRecordMsgType)
   273  	copy(msg[msgTypeFieldSize:], c)
   274  	if _, err = buf.Write(framedMsg); err != nil {
   275  		t.Fatalf("Unexpected error writing to buffer: %v", err)
   276  	}
   277  	b := make([]byte, 1)
   278  	if n, err := serverConn.Read(b); n != 0 || err == nil {
   279  		t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received the frame length %d larger than the limit %d", altsRecordLengthLimit+1, altsRecordLengthLimit))
   280  	}
   281  }
   282  
   283  func (s) TestFrameTooLarge(t *testing.T) {
   284  	for _, rp := range recordProtocols {
   285  		testFrameTooLarge(t, rp)
   286  	}
   287  }
   288  
   289  func testWriteLargeData(t *testing.T, rp string) {
   290  	// Test sending and receiving messages larger than the maximum write
   291  	// buffer size.
   292  	clientConn, serverConn := newConnPair(rp, nil, nil)
   293  	// Message size is intentionally chosen to not be multiple of
   294  	// payloadLengthLimit.
   295  	msgSize := altsWriteBufferMaxSize + (100 * 1024)
   296  	clientMsg := make([]byte, msgSize)
   297  	for i := 0; i < msgSize; i++ {
   298  		clientMsg[i] = 0xAA
   299  	}
   300  	if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
   301  		t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
   302  	}
   303  	// We need to keep reading until the entire message is received. The
   304  	// reason we set all bytes of the message to a value other than zero is
   305  	// to avoid ambiguous zero-init value of rcvClientMsg buffer and the
   306  	// actual received data.
   307  	rcvClientMsg := make([]byte, 0, msgSize)
   308  	numberOfExpectedFrames := int(math.Ceil(float64(msgSize) / float64(serverConn.payloadLengthLimit)))
   309  	for i := 0; i < numberOfExpectedFrames; i++ {
   310  		expectedRcvSize := serverConn.payloadLengthLimit
   311  		if i == numberOfExpectedFrames-1 {
   312  			// Last frame might be smaller.
   313  			expectedRcvSize = msgSize % serverConn.payloadLengthLimit
   314  		}
   315  		tmpBuf := make([]byte, expectedRcvSize)
   316  		if n, err := serverConn.Read(tmpBuf); n != len(tmpBuf) || err != nil {
   317  			t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(tmpBuf))
   318  		}
   319  		rcvClientMsg = append(rcvClientMsg, tmpBuf...)
   320  	}
   321  	if !reflect.DeepEqual(clientMsg, rcvClientMsg) {
   322  		t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg)
   323  	}
   324  }
   325  
   326  func (s) TestWriteLargeData(t *testing.T) {
   327  	for _, rp := range recordProtocols {
   328  		testWriteLargeData(t, rp)
   329  	}
   330  }
   331  
   332  func testProtectedBuffer(t *testing.T, rp string) {
   333  	key := []byte{
   334  		// 16 arbitrary bytes.
   335  		0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49}
   336  
   337  	// Encrypt a message to be passed to NewConn as a client-side protected
   338  	// buffer.
   339  	newCrypto := protocols[rp]
   340  	if newCrypto == nil {
   341  		t.Fatalf("Unknown record protocol %q", rp)
   342  	}
   343  	crypto, err := newCrypto(core.ClientSide, key)
   344  	if err != nil {
   345  		t.Fatalf("Failed to create a crypter for protocol %q: %v", rp, err)
   346  	}
   347  	msg := []byte("Client Protected Message")
   348  	encryptedMsg, err := crypto.Encrypt(nil, msg)
   349  	if err != nil {
   350  		t.Fatalf("Failed to encrypt the client protected message: %v", err)
   351  	}
   352  	protectedMsg := make([]byte, 8)                                          // 8 bytes = 4 length + 4 type
   353  	binary.LittleEndian.PutUint32(protectedMsg, uint32(len(encryptedMsg))+4) // 4 bytes for the type
   354  	binary.LittleEndian.PutUint32(protectedMsg[4:], altsRecordMsgType)
   355  	protectedMsg = append(protectedMsg, encryptedMsg...)
   356  
   357  	_, serverConn := newConnPair(rp, nil, protectedMsg)
   358  	rcvClientMsg := make([]byte, len(msg))
   359  	if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil {
   360  		t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg))
   361  	}
   362  	if !reflect.DeepEqual(msg, rcvClientMsg) {
   363  		t.Fatalf("Client protected/Server Read() = %v, want %v", rcvClientMsg, msg)
   364  	}
   365  }
   366  
   367  func (s) TestProtectedBuffer(t *testing.T) {
   368  	for _, rp := range recordProtocols {
   369  		testProtectedBuffer(t, rp)
   370  	}
   371  }