gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/grpc/credentials/alts/internal/handshaker/handshaker_test.go (about) 1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package handshaker 20 21 import ( 22 "bytes" 23 "context" 24 "errors" 25 "testing" 26 "time" 27 28 grpc "gitee.com/zhaochuninhefei/gmgo/grpc" 29 core "gitee.com/zhaochuninhefei/gmgo/grpc/credentials/alts/internal" 30 altspb "gitee.com/zhaochuninhefei/gmgo/grpc/credentials/alts/internal/proto/grpc_gcp" 31 "gitee.com/zhaochuninhefei/gmgo/grpc/credentials/alts/internal/testutil" 32 "gitee.com/zhaochuninhefei/gmgo/grpc/internal/grpctest" 33 ) 34 35 type s struct { 36 grpctest.Tester 37 } 38 39 func Test(t *testing.T) { 40 grpctest.RunSubTests(t, s{}) 41 } 42 43 var ( 44 testRecordProtocol = rekeyRecordProtocolName 45 testKey = []byte{ 46 // 44 arbitrary bytes. 47 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 48 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 0x1f, 0x8b, 49 0xd2, 0x4c, 0xce, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 50 } 51 testServiceAccount = "test_service_account" 52 testTargetServiceAccounts = []string{testServiceAccount} 53 testClientIdentity = &altspb.Identity{ 54 IdentityOneof: &altspb.Identity_Hostname{ 55 Hostname: "i_am_a_client", 56 }, 57 } 58 ) 59 60 const defaultTestTimeout = 10 * time.Second 61 62 // testRPCStream mimics a altspb.HandshakerService_DoHandshakeClient object. 63 type testRPCStream struct { 64 grpc.ClientStream 65 t *testing.T 66 isClient bool 67 // The resp expected to be returned by Recv(). Make sure this is set to 68 // the content the test requires before Recv() is invoked. 69 recvBuf *altspb.HandshakerResp 70 // false if it is the first access to Handshaker service on Envelope. 71 first bool 72 // useful for testing concurrent calls. 73 delay time.Duration 74 } 75 76 func (t *testRPCStream) Recv() (*altspb.HandshakerResp, error) { 77 resp := t.recvBuf 78 t.recvBuf = nil 79 return resp, nil 80 } 81 82 func (t *testRPCStream) Send(req *altspb.HandshakerReq) error { 83 var resp *altspb.HandshakerResp 84 if !t.first { 85 // Generate the bytes to be returned by Recv() for the initial 86 // handshaking. 87 t.first = true 88 if t.isClient { 89 resp = &altspb.HandshakerResp{ 90 OutFrames: testutil.MakeFrame("ClientInit"), 91 // Simulate consuming ServerInit. 92 BytesConsumed: 14, 93 } 94 } else { 95 resp = &altspb.HandshakerResp{ 96 OutFrames: testutil.MakeFrame("ServerInit"), 97 // Simulate consuming ClientInit. 98 BytesConsumed: 14, 99 } 100 } 101 } else { 102 // Add delay to test concurrent calls. 103 cleanup := stat.Update() 104 defer cleanup() 105 time.Sleep(t.delay) 106 107 // Generate the response to be returned by Recv() for the 108 // follow-up handshaking. 109 result := &altspb.HandshakerResult{ 110 RecordProtocol: testRecordProtocol, 111 KeyData: testKey, 112 } 113 resp = &altspb.HandshakerResp{ 114 Result: result, 115 // Simulate consuming ClientFinished or ServerFinished. 116 BytesConsumed: 18, 117 } 118 } 119 t.recvBuf = resp 120 return nil 121 } 122 123 func (t *testRPCStream) CloseSend() error { 124 return nil 125 } 126 127 var stat testutil.Stats 128 129 func (s) TestClientHandshake(t *testing.T) { 130 for _, testCase := range []struct { 131 delay time.Duration 132 numberOfHandshakes int 133 }{ 134 {0 * time.Millisecond, 1}, 135 {100 * time.Millisecond, 10 * maxPendingHandshakes}, 136 } { 137 errc := make(chan error) 138 stat.Reset() 139 140 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 141 defer cancel() 142 143 for i := 0; i < testCase.numberOfHandshakes; i++ { 144 stream := &testRPCStream{ 145 t: t, 146 isClient: true, 147 } 148 // Preload the inbound frames. 149 f1 := testutil.MakeFrame("ServerInit") 150 f2 := testutil.MakeFrame("ServerFinished") 151 in := bytes.NewBuffer(f1) 152 in.Write(f2) 153 out := new(bytes.Buffer) 154 tc := testutil.NewTestConn(in, out) 155 chs := &altsHandshaker{ 156 stream: stream, 157 conn: tc, 158 clientOpts: &ClientHandshakerOptions{ 159 TargetServiceAccounts: testTargetServiceAccounts, 160 ClientIdentity: testClientIdentity, 161 }, 162 side: core.ClientSide, 163 } 164 go func() { 165 _, context, err := chs.ClientHandshake(ctx) 166 if err == nil && context == nil { 167 errc <- errors.New("expected non-nil ALTS context") 168 return 169 } 170 errc <- err 171 chs.Close() 172 }() 173 } 174 175 // Ensure all errors are expected. 176 for i := 0; i < testCase.numberOfHandshakes; i++ { 177 if err := <-errc; err != nil && err != errDropped { 178 t.Errorf("ClientHandshake() = _, %v, want _, <nil> or %v", err, errDropped) 179 } 180 } 181 182 // Ensure that there are no concurrent calls more than the limit. 183 if stat.MaxConcurrentCalls > maxPendingHandshakes { 184 t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes) 185 } 186 } 187 } 188 189 func (s) TestServerHandshake(t *testing.T) { 190 for _, testCase := range []struct { 191 delay time.Duration 192 numberOfHandshakes int 193 }{ 194 {0 * time.Millisecond, 1}, 195 {100 * time.Millisecond, 10 * maxPendingHandshakes}, 196 } { 197 errc := make(chan error) 198 stat.Reset() 199 200 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 201 defer cancel() 202 203 for i := 0; i < testCase.numberOfHandshakes; i++ { 204 stream := &testRPCStream{ 205 t: t, 206 isClient: false, 207 } 208 // Preload the inbound frames. 209 f1 := testutil.MakeFrame("ClientInit") 210 f2 := testutil.MakeFrame("ClientFinished") 211 in := bytes.NewBuffer(f1) 212 in.Write(f2) 213 out := new(bytes.Buffer) 214 tc := testutil.NewTestConn(in, out) 215 shs := &altsHandshaker{ 216 stream: stream, 217 conn: tc, 218 serverOpts: DefaultServerHandshakerOptions(), 219 side: core.ServerSide, 220 } 221 go func() { 222 _, context, err := shs.ServerHandshake(ctx) 223 if err == nil && context == nil { 224 errc <- errors.New("expected non-nil ALTS context") 225 return 226 } 227 errc <- err 228 shs.Close() 229 }() 230 } 231 232 // Ensure all errors are expected. 233 for i := 0; i < testCase.numberOfHandshakes; i++ { 234 if err := <-errc; err != nil && err != errDropped { 235 t.Errorf("ServerHandshake() = _, %v, want _, <nil> or %v", err, errDropped) 236 } 237 } 238 239 // Ensure that there are no concurrent calls more than the limit. 240 if stat.MaxConcurrentCalls > maxPendingHandshakes { 241 t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes) 242 } 243 } 244 } 245 246 // testUnresponsiveRPCStream is used for testing the PeerNotResponding case. 247 type testUnresponsiveRPCStream struct { 248 grpc.ClientStream 249 } 250 251 func (t *testUnresponsiveRPCStream) Recv() (*altspb.HandshakerResp, error) { 252 return &altspb.HandshakerResp{}, nil 253 } 254 255 func (t *testUnresponsiveRPCStream) Send(req *altspb.HandshakerReq) error { 256 return nil 257 } 258 259 func (t *testUnresponsiveRPCStream) CloseSend() error { 260 return nil 261 } 262 263 func (s) TestPeerNotResponding(t *testing.T) { 264 stream := &testUnresponsiveRPCStream{} 265 chs := &altsHandshaker{ 266 stream: stream, 267 conn: testutil.NewUnresponsiveTestConn(), 268 clientOpts: &ClientHandshakerOptions{ 269 TargetServiceAccounts: testTargetServiceAccounts, 270 ClientIdentity: testClientIdentity, 271 }, 272 side: core.ClientSide, 273 } 274 275 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 276 defer cancel() 277 _, context, err := chs.ClientHandshake(ctx) 278 chs.Close() 279 if context != nil { 280 t.Error("expected non-nil ALTS context") 281 } 282 if got, want := err, core.PeerNotRespondingError; got != want { 283 t.Errorf("ClientHandshake() = %v, want %v", got, want) 284 } 285 }