google.golang.org/grpc@v1.74.2/credentials/alts/internal/conn/record_test.go (about) 1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package conn 20 21 import ( 22 "bytes" 23 "encoding/binary" 24 "fmt" 25 "io" 26 "math" 27 "net" 28 "reflect" 29 "strings" 30 "testing" 31 32 core "google.golang.org/grpc/credentials/alts/internal" 33 "google.golang.org/grpc/internal/grpctest" 34 ) 35 36 type s struct { 37 grpctest.Tester 38 } 39 40 func Test(t *testing.T) { 41 grpctest.RunSubTests(t, s{}) 42 } 43 44 const ( 45 rekeyRecordProtocol = "ALTSRP_GCM_AES128_REKEY" 46 ) 47 48 var ( 49 recordProtocols = []string{rekeyRecordProtocol} 50 altsRecordFuncs = map[string]ALTSRecordFunc{ 51 // ALTS handshaker protocols. 52 rekeyRecordProtocol: func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) { 53 return NewAES128GCM(s, keyData) 54 }, 55 } 56 ) 57 58 func init() { 59 for protocol, f := range altsRecordFuncs { 60 if err := RegisterProtocol(protocol, f); err != nil { 61 panic(err) 62 } 63 } 64 } 65 66 // testConn mimics a net.Conn to the peer. 67 type testConn struct { 68 net.Conn 69 in *bytes.Buffer 70 out *bytes.Buffer 71 } 72 73 func (c *testConn) Read(b []byte) (n int, err error) { 74 return c.in.Read(b) 75 } 76 77 func (c *testConn) Write(b []byte) (n int, err error) { 78 return c.out.Write(b) 79 } 80 81 func (c *testConn) Close() error { 82 return nil 83 } 84 85 func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, rp string, protected []byte) *conn { 86 key := []byte{ 87 // 16 arbitrary bytes. 88 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49} 89 tc := testConn{ 90 in: in, 91 out: out, 92 } 93 c, err := NewConn(&tc, side, rp, key, protected) 94 if err != nil { 95 panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err)) 96 } 97 return c.(*conn) 98 } 99 100 func newConnPair(rp string, clientProtected []byte, serverProtected []byte) (client, server *conn) { 101 clientBuf := new(bytes.Buffer) 102 serverBuf := new(bytes.Buffer) 103 clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, rp, clientProtected) 104 serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, rp, serverProtected) 105 return clientConn, serverConn 106 } 107 108 func testPingPong(t *testing.T, rp string) { 109 clientConn, serverConn := newConnPair(rp, nil, nil) 110 clientMsg := []byte("Client Message") 111 if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil { 112 t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg)) 113 } 114 rcvClientMsg := make([]byte, len(clientMsg)) 115 if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil { 116 t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg)) 117 } 118 if !reflect.DeepEqual(clientMsg, rcvClientMsg) { 119 t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg) 120 } 121 122 serverMsg := []byte("Server Message") 123 if n, err := serverConn.Write(serverMsg); n != len(serverMsg) || err != nil { 124 t.Fatalf("Server Write() = %v, %v; want %v, <nil>", n, err, len(serverMsg)) 125 } 126 rcvServerMsg := make([]byte, len(serverMsg)) 127 if n, err := clientConn.Read(rcvServerMsg); n != len(rcvServerMsg) || err != nil { 128 t.Fatalf("Client Read() = %v, %v; want %v, <nil>", n, err, len(rcvServerMsg)) 129 } 130 if !reflect.DeepEqual(serverMsg, rcvServerMsg) { 131 t.Fatalf("Server Write()/Client Read() = %v, want %v", rcvServerMsg, serverMsg) 132 } 133 } 134 135 func (s) TestPingPong(t *testing.T) { 136 for _, rp := range recordProtocols { 137 testPingPong(t, rp) 138 } 139 } 140 141 func testSmallReadBuffer(t *testing.T, rp string) { 142 clientConn, serverConn := newConnPair(rp, nil, nil) 143 msg := []byte("Very Important Message") 144 if n, err := clientConn.Write(msg); err != nil { 145 t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg)) 146 } 147 rcvMsg := make([]byte, len(msg)) 148 n := 2 // Arbitrary index to break rcvMsg in two. 149 rcvMsg1 := rcvMsg[:n] 150 rcvMsg2 := rcvMsg[n:] 151 if n, err := serverConn.Read(rcvMsg1); n != len(rcvMsg1) || err != nil { 152 t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg1)) 153 } 154 if n, err := serverConn.Read(rcvMsg2); n != len(rcvMsg2) || err != nil { 155 t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg2)) 156 } 157 if !reflect.DeepEqual(msg, rcvMsg) { 158 t.Fatalf("Write()/Read() = %v, want %v", rcvMsg, msg) 159 } 160 } 161 162 func (s) TestSmallReadBuffer(t *testing.T) { 163 for _, rp := range recordProtocols { 164 testSmallReadBuffer(t, rp) 165 } 166 } 167 168 func testLargeMsg(t *testing.T, rp string) { 169 clientConn, serverConn := newConnPair(rp, nil, nil) 170 // msgLen is such that the length in the framing is larger than the 171 // default size of one frame. 172 msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 173 msg := make([]byte, msgLen) 174 if n, err := clientConn.Write(msg); n != len(msg) || err != nil { 175 t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg)) 176 } 177 rcvMsg := make([]byte, len(msg)) 178 if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil { 179 t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg)) 180 } 181 if !reflect.DeepEqual(msg, rcvMsg) { 182 t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg) 183 } 184 } 185 186 func (s) TestLargeMsg(t *testing.T) { 187 for _, rp := range recordProtocols { 188 testLargeMsg(t, rp) 189 } 190 } 191 192 // TestLargeRecord writes a very large ALTS record and verifies that the server 193 // receives it correctly. The large ALTS record should cause the reader to 194 // expand it's read buffer to hold the entire record and store the decrypted 195 // message until the receiver reads all of the bytes. 196 func (s) TestLargeRecord(t *testing.T) { 197 clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil) 198 msg := []byte(strings.Repeat("a", 2*altsReadBufferInitialSize)) 199 // Increase the size of ALTS records written by the client. 200 clientConn.payloadLengthLimit = math.MaxInt32 201 if n, err := clientConn.Write(msg); n != len(msg) || err != nil { 202 t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg)) 203 } 204 rcvMsg := make([]byte, len(msg)) 205 if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil { 206 t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg)) 207 } 208 if !reflect.DeepEqual(msg, rcvMsg) { 209 t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg) 210 } 211 } 212 213 // BenchmarkLargeMessage measures the performance of ALTS conns for sending and 214 // receiving a large message. 215 func BenchmarkLargeMessage(b *testing.B) { 216 msgLen := 20 * 1024 * 1024 // 20 MiB 217 msg := make([]byte, msgLen) 218 rcvMsg := make([]byte, len(msg)) 219 b.ResetTimer() 220 clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil) 221 for range b.N { 222 // Write 20 MiB 5 times to transfer a total of 100 MiB. 223 for range 5 { 224 if n, err := clientConn.Write(msg); n != len(msg) || err != nil { 225 b.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg)) 226 } 227 if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil { 228 b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg)) 229 } 230 } 231 } 232 } 233 234 func testIncorrectMsgType(t *testing.T, rp string) { 235 // framedMsg is an empty ciphertext with correct framing but wrong 236 // message type. 237 framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize) 238 binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], msgTypeFieldSize) 239 wrongMsgType := uint32(0x22) 240 binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType) 241 242 in := bytes.NewBuffer(framedMsg) 243 c := newTestALTSRecordConn(in, nil, core.ClientSide, rp, nil) 244 b := make([]byte, 1) 245 if n, err := c.Read(b); n != 0 || err == nil { 246 t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType)) 247 } 248 } 249 250 func (s) TestIncorrectMsgType(t *testing.T) { 251 for _, rp := range recordProtocols { 252 testIncorrectMsgType(t, rp) 253 } 254 } 255 256 func testFrameTooLarge(t *testing.T, rp string) { 257 buf := new(bytes.Buffer) 258 clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, rp, nil) 259 serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, rp, nil) 260 // payloadLen is such that the length in the framing is larger than 261 // allowed in one frame. 262 payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 263 payload := make([]byte, payloadLen) 264 c, err := clientConn.crypto.Encrypt(nil, payload) 265 if err != nil { 266 t.Fatalf("Error encrypting message: %v", err) 267 } 268 msgLen := msgTypeFieldSize + len(c) 269 framedMsg := make([]byte, MsgLenFieldSize+msgLen) 270 binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], uint32(msgTypeFieldSize+len(c))) 271 msg := framedMsg[MsgLenFieldSize:] 272 binary.LittleEndian.PutUint32(msg[:msgTypeFieldSize], altsRecordMsgType) 273 copy(msg[msgTypeFieldSize:], c) 274 if _, err = buf.Write(framedMsg); err != nil { 275 t.Fatalf("Unexpected error writing to buffer: %v", err) 276 } 277 b := make([]byte, 1) 278 if n, err := serverConn.Read(b); n != 0 || err == nil { 279 t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received the frame length %d larger than the limit %d", altsRecordLengthLimit+1, altsRecordLengthLimit)) 280 } 281 } 282 283 func (s) TestFrameTooLarge(t *testing.T) { 284 for _, rp := range recordProtocols { 285 testFrameTooLarge(t, rp) 286 } 287 } 288 289 func testWriteLargeData(t *testing.T, rp string) { 290 // Test sending and receiving messages larger than the maximum write 291 // buffer size. 292 clientConn, serverConn := newConnPair(rp, nil, nil) 293 // Message size is intentionally chosen to not be multiple of 294 // payloadLengthLimit. 295 msgSize := altsWriteBufferMaxSize + (100 * 1024) 296 clientMsg := make([]byte, msgSize) 297 for i := 0; i < msgSize; i++ { 298 clientMsg[i] = 0xAA 299 } 300 if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil { 301 t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg)) 302 } 303 // We need to keep reading until the entire message is received. The 304 // reason we set all bytes of the message to a value other than zero is 305 // to avoid ambiguous zero-init value of rcvClientMsg buffer and the 306 // actual received data. 307 rcvClientMsg := make([]byte, 0, msgSize) 308 numberOfExpectedFrames := int(math.Ceil(float64(msgSize) / float64(serverConn.payloadLengthLimit))) 309 for i := 0; i < numberOfExpectedFrames; i++ { 310 expectedRcvSize := serverConn.payloadLengthLimit 311 if i == numberOfExpectedFrames-1 { 312 // Last frame might be smaller. 313 expectedRcvSize = msgSize % serverConn.payloadLengthLimit 314 } 315 tmpBuf := make([]byte, expectedRcvSize) 316 if n, err := serverConn.Read(tmpBuf); n != len(tmpBuf) || err != nil { 317 t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(tmpBuf)) 318 } 319 rcvClientMsg = append(rcvClientMsg, tmpBuf...) 320 } 321 if !reflect.DeepEqual(clientMsg, rcvClientMsg) { 322 t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg) 323 } 324 } 325 326 func (s) TestWriteLargeData(t *testing.T) { 327 for _, rp := range recordProtocols { 328 testWriteLargeData(t, rp) 329 } 330 } 331 332 func testProtectedBuffer(t *testing.T, rp string) { 333 key := []byte{ 334 // 16 arbitrary bytes. 335 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49} 336 337 // Encrypt a message to be passed to NewConn as a client-side protected 338 // buffer. 339 newCrypto := protocols[rp] 340 if newCrypto == nil { 341 t.Fatalf("Unknown record protocol %q", rp) 342 } 343 crypto, err := newCrypto(core.ClientSide, key) 344 if err != nil { 345 t.Fatalf("Failed to create a crypter for protocol %q: %v", rp, err) 346 } 347 msg := []byte("Client Protected Message") 348 encryptedMsg, err := crypto.Encrypt(nil, msg) 349 if err != nil { 350 t.Fatalf("Failed to encrypt the client protected message: %v", err) 351 } 352 protectedMsg := make([]byte, 8) // 8 bytes = 4 length + 4 type 353 binary.LittleEndian.PutUint32(protectedMsg, uint32(len(encryptedMsg))+4) // 4 bytes for the type 354 binary.LittleEndian.PutUint32(protectedMsg[4:], altsRecordMsgType) 355 protectedMsg = append(protectedMsg, encryptedMsg...) 356 357 _, serverConn := newConnPair(rp, nil, protectedMsg) 358 rcvClientMsg := make([]byte, len(msg)) 359 if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil { 360 t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg)) 361 } 362 if !reflect.DeepEqual(msg, rcvClientMsg) { 363 t.Fatalf("Client protected/Server Read() = %v, want %v", rcvClientMsg, msg) 364 } 365 } 366 367 func (s) TestProtectedBuffer(t *testing.T) { 368 for _, rp := range recordProtocols { 369 testProtectedBuffer(t, rp) 370 } 371 }