gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/credentials/alts/internal/conn/record_test.go (about) 1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package conn 20 21 import ( 22 "bytes" 23 "encoding/binary" 24 "fmt" 25 "io" 26 "math" 27 "net" 28 "reflect" 29 "testing" 30 31 core "gitee.com/ks-custle/core-gm/grpc/credentials/alts/internal" 32 "gitee.com/ks-custle/core-gm/grpc/internal/grpctest" 33 ) 34 35 type s struct { 36 grpctest.Tester 37 } 38 39 func Test(t *testing.T) { 40 grpctest.RunSubTests(t, s{}) 41 } 42 43 const ( 44 rekeyRecordProtocol = "ALTSRP_GCM_AES128_REKEY" 45 ) 46 47 var ( 48 recordProtocols = []string{rekeyRecordProtocol} 49 altsRecordFuncs = map[string]ALTSRecordFunc{ 50 // ALTS handshaker protocols. 51 rekeyRecordProtocol: func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) { 52 return NewAES128GCM(s, keyData) 53 }, 54 } 55 ) 56 57 func init() { 58 for protocol, f := range altsRecordFuncs { 59 if err := RegisterProtocol(protocol, f); err != nil { 60 panic(err) 61 } 62 } 63 } 64 65 // testConn mimics a net.Conn to the peer. 66 type testConn struct { 67 net.Conn 68 in *bytes.Buffer 69 out *bytes.Buffer 70 } 71 72 func (c *testConn) Read(b []byte) (n int, err error) { 73 return c.in.Read(b) 74 } 75 76 func (c *testConn) Write(b []byte) (n int, err error) { 77 return c.out.Write(b) 78 } 79 80 func (c *testConn) Close() error { 81 return nil 82 } 83 84 func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, rp string, protected []byte) *conn { 85 key := []byte{ 86 // 16 arbitrary bytes. 87 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49} 88 tc := testConn{ 89 in: in, 90 out: out, 91 } 92 c, err := NewConn(&tc, side, rp, key, protected) 93 if err != nil { 94 panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err)) 95 } 96 return c.(*conn) 97 } 98 99 func newConnPair(rp string, clientProtected []byte, serverProtected []byte) (client, server *conn) { 100 clientBuf := new(bytes.Buffer) 101 serverBuf := new(bytes.Buffer) 102 clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, rp, clientProtected) 103 serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, rp, serverProtected) 104 return clientConn, serverConn 105 } 106 107 func testPingPong(t *testing.T, rp string) { 108 clientConn, serverConn := newConnPair(rp, nil, nil) 109 clientMsg := []byte("Client Message") 110 if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil { 111 t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg)) 112 } 113 rcvClientMsg := make([]byte, len(clientMsg)) 114 if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil { 115 t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg)) 116 } 117 if !reflect.DeepEqual(clientMsg, rcvClientMsg) { 118 t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg) 119 } 120 121 serverMsg := []byte("Server Message") 122 if n, err := serverConn.Write(serverMsg); n != len(serverMsg) || err != nil { 123 t.Fatalf("Server Write() = %v, %v; want %v, <nil>", n, err, len(serverMsg)) 124 } 125 rcvServerMsg := make([]byte, len(serverMsg)) 126 if n, err := clientConn.Read(rcvServerMsg); n != len(rcvServerMsg) || err != nil { 127 t.Fatalf("Client Read() = %v, %v; want %v, <nil>", n, err, len(rcvServerMsg)) 128 } 129 if !reflect.DeepEqual(serverMsg, rcvServerMsg) { 130 t.Fatalf("Server Write()/Client Read() = %v, want %v", rcvServerMsg, serverMsg) 131 } 132 } 133 134 func (s) TestPingPong(t *testing.T) { 135 for _, rp := range recordProtocols { 136 testPingPong(t, rp) 137 } 138 } 139 140 func testSmallReadBuffer(t *testing.T, rp string) { 141 clientConn, serverConn := newConnPair(rp, nil, nil) 142 msg := []byte("Very Important Message") 143 if n, err := clientConn.Write(msg); err != nil { 144 t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg)) 145 } 146 rcvMsg := make([]byte, len(msg)) 147 n := 2 // Arbitrary index to break rcvMsg in two. 148 rcvMsg1 := rcvMsg[:n] 149 rcvMsg2 := rcvMsg[n:] 150 if n, err := serverConn.Read(rcvMsg1); n != len(rcvMsg1) || err != nil { 151 t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg1)) 152 } 153 if n, err := serverConn.Read(rcvMsg2); n != len(rcvMsg2) || err != nil { 154 t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg2)) 155 } 156 if !reflect.DeepEqual(msg, rcvMsg) { 157 t.Fatalf("Write()/Read() = %v, want %v", rcvMsg, msg) 158 } 159 } 160 161 func (s) TestSmallReadBuffer(t *testing.T) { 162 for _, rp := range recordProtocols { 163 testSmallReadBuffer(t, rp) 164 } 165 } 166 167 func testLargeMsg(t *testing.T, rp string) { 168 clientConn, serverConn := newConnPair(rp, nil, nil) 169 // msgLen is such that the length in the framing is larger than the 170 // default size of one frame. 171 msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 172 msg := make([]byte, msgLen) 173 if n, err := clientConn.Write(msg); n != len(msg) || err != nil { 174 t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg)) 175 } 176 rcvMsg := make([]byte, len(msg)) 177 if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil { 178 t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg)) 179 } 180 if !reflect.DeepEqual(msg, rcvMsg) { 181 t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg) 182 } 183 } 184 185 func (s) TestLargeMsg(t *testing.T) { 186 for _, rp := range recordProtocols { 187 testLargeMsg(t, rp) 188 } 189 } 190 191 func testIncorrectMsgType(t *testing.T, rp string) { 192 // framedMsg is an empty ciphertext with correct framing but wrong 193 // message type. 194 framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize) 195 binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], msgTypeFieldSize) 196 wrongMsgType := uint32(0x22) 197 binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType) 198 199 in := bytes.NewBuffer(framedMsg) 200 c := newTestALTSRecordConn(in, nil, core.ClientSide, rp, nil) 201 b := make([]byte, 1) 202 if n, err := c.Read(b); n != 0 || err == nil { 203 t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType)) 204 } 205 } 206 207 func (s) TestIncorrectMsgType(t *testing.T) { 208 for _, rp := range recordProtocols { 209 testIncorrectMsgType(t, rp) 210 } 211 } 212 213 func testFrameTooLarge(t *testing.T, rp string) { 214 buf := new(bytes.Buffer) 215 clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, rp, nil) 216 serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, rp, nil) 217 // payloadLen is such that the length in the framing is larger than 218 // allowed in one frame. 219 payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 220 payload := make([]byte, payloadLen) 221 c, err := clientConn.crypto.Encrypt(nil, payload) 222 if err != nil { 223 t.Fatalf(fmt.Sprintf("Error encrypting message: %v", err)) 224 } 225 msgLen := msgTypeFieldSize + len(c) 226 framedMsg := make([]byte, MsgLenFieldSize+msgLen) 227 binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], uint32(msgTypeFieldSize+len(c))) 228 msg := framedMsg[MsgLenFieldSize:] 229 binary.LittleEndian.PutUint32(msg[:msgTypeFieldSize], altsRecordMsgType) 230 copy(msg[msgTypeFieldSize:], c) 231 if _, err = buf.Write(framedMsg); err != nil { 232 t.Fatal(fmt.Sprintf("Unexpected error writing to buffer: %v", err)) 233 } 234 b := make([]byte, 1) 235 if n, err := serverConn.Read(b); n != 0 || err == nil { 236 t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received the frame length %d larger than the limit %d", altsRecordLengthLimit+1, altsRecordLengthLimit)) 237 } 238 } 239 240 func (s) TestFrameTooLarge(t *testing.T) { 241 for _, rp := range recordProtocols { 242 testFrameTooLarge(t, rp) 243 } 244 } 245 246 func testWriteLargeData(t *testing.T, rp string) { 247 // Test sending and receiving messages larger than the maximum write 248 // buffer size. 249 clientConn, serverConn := newConnPair(rp, nil, nil) 250 // Message size is intentionally chosen to not be multiple of 251 // payloadLengthLimtit. 252 msgSize := altsWriteBufferMaxSize + (100 * 1024) 253 clientMsg := make([]byte, msgSize) 254 for i := 0; i < msgSize; i++ { 255 clientMsg[i] = 0xAA 256 } 257 if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil { 258 t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg)) 259 } 260 // We need to keep reading until the entire message is received. The 261 // reason we set all bytes of the message to a value other than zero is 262 // to avoid ambiguous zero-init value of rcvClientMsg buffer and the 263 // actual received data. 264 rcvClientMsg := make([]byte, 0, msgSize) 265 numberOfExpectedFrames := int(math.Ceil(float64(msgSize) / float64(serverConn.payloadLengthLimit))) 266 for i := 0; i < numberOfExpectedFrames; i++ { 267 expectedRcvSize := serverConn.payloadLengthLimit 268 if i == numberOfExpectedFrames-1 { 269 // Last frame might be smaller. 270 expectedRcvSize = msgSize % serverConn.payloadLengthLimit 271 } 272 tmpBuf := make([]byte, expectedRcvSize) 273 if n, err := serverConn.Read(tmpBuf); n != len(tmpBuf) || err != nil { 274 t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(tmpBuf)) 275 } 276 rcvClientMsg = append(rcvClientMsg, tmpBuf...) 277 } 278 if !reflect.DeepEqual(clientMsg, rcvClientMsg) { 279 t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg) 280 } 281 } 282 283 func (s) TestWriteLargeData(t *testing.T) { 284 for _, rp := range recordProtocols { 285 testWriteLargeData(t, rp) 286 } 287 } 288 289 func testProtectedBuffer(t *testing.T, rp string) { 290 key := []byte{ 291 // 16 arbitrary bytes. 292 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49} 293 294 // Encrypt a message to be passed to NewConn as a client-side protected 295 // buffer. 296 newCrypto := protocols[rp] 297 if newCrypto == nil { 298 t.Fatalf("Unknown record protocol %q", rp) 299 } 300 crypto, err := newCrypto(core.ClientSide, key) 301 if err != nil { 302 t.Fatalf("Failed to create a crypter for protocol %q: %v", rp, err) 303 } 304 msg := []byte("Client Protected Message") 305 encryptedMsg, err := crypto.Encrypt(nil, msg) 306 if err != nil { 307 t.Fatalf("Failed to encrypt the client protected message: %v", err) 308 } 309 protectedMsg := make([]byte, 8) // 8 bytes = 4 length + 4 type 310 binary.LittleEndian.PutUint32(protectedMsg, uint32(len(encryptedMsg))+4) // 4 bytes for the type 311 binary.LittleEndian.PutUint32(protectedMsg[4:], altsRecordMsgType) 312 protectedMsg = append(protectedMsg, encryptedMsg...) 313 314 _, serverConn := newConnPair(rp, nil, protectedMsg) 315 rcvClientMsg := make([]byte, len(msg)) 316 if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil { 317 t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg)) 318 } 319 if !reflect.DeepEqual(msg, rcvClientMsg) { 320 t.Fatalf("Client protected/Server Read() = %v, want %v", rcvClientMsg, msg) 321 } 322 } 323 324 func (s) TestProtectedBuffer(t *testing.T) { 325 for _, rp := range recordProtocols { 326 testProtectedBuffer(t, rp) 327 } 328 }