google.golang.org/grpc@v1.72.2/rpc_util_test.go (about) 1 /* 2 * 3 * Copyright 2014 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package grpc 20 21 import ( 22 "bytes" 23 "compress/gzip" 24 "context" 25 "errors" 26 "io" 27 "math" 28 "reflect" 29 "sync" 30 "testing" 31 32 "github.com/google/go-cmp/cmp" 33 "github.com/google/go-cmp/cmp/cmpopts" 34 "google.golang.org/grpc/codes" 35 "google.golang.org/grpc/encoding" 36 _ "google.golang.org/grpc/encoding/gzip" 37 protoenc "google.golang.org/grpc/encoding/proto" 38 "google.golang.org/grpc/internal/testutils" 39 "google.golang.org/grpc/internal/transport" 40 "google.golang.org/grpc/mem" 41 "google.golang.org/grpc/status" 42 perfpb "google.golang.org/grpc/test/codec_perf" 43 "google.golang.org/protobuf/proto" 44 ) 45 46 const ( 47 defaultDecompressedData = "default decompressed data" 48 decompressionErrorMsg = "invalid compression format" 49 ) 50 51 type fullReader struct { 52 data []byte 53 } 54 55 func (f *fullReader) ReadMessageHeader(header []byte) error { 56 buf, err := f.Read(len(header)) 57 defer buf.Free() 58 if err != nil { 59 return err 60 } 61 62 buf.CopyTo(header) 63 return nil 64 } 65 66 func (f *fullReader) Read(n int) (mem.BufferSlice, error) { 67 if n == 0 { 68 return nil, nil 69 } 70 71 if len(f.data) == 0 { 72 return nil, io.EOF 73 } 74 75 if len(f.data) < n { 76 data := f.data 77 f.data = nil 78 return mem.BufferSlice{mem.SliceBuffer(data)}, io.ErrUnexpectedEOF 79 } 80 81 buf := f.data[:n] 82 f.data = f.data[n:] 83 84 return mem.BufferSlice{mem.SliceBuffer(buf)}, nil 85 } 86 87 var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface 88 89 func (s) TestSimpleParsing(t *testing.T) { 90 bigMsg := bytes.Repeat([]byte{'x'}, 1<<24) 91 for _, test := range []struct { 92 // input 93 p []byte 94 // outputs 95 err error 96 b []byte 97 pt payloadFormat 98 }{ 99 {nil, io.EOF, nil, compressionNone}, 100 {[]byte{0, 0, 0, 0, 0}, nil, nil, compressionNone}, 101 {[]byte{0, 0, 0, 0, 1, 'a'}, nil, []byte{'a'}, compressionNone}, 102 {[]byte{1, 0}, io.ErrUnexpectedEOF, nil, compressionNone}, 103 {[]byte{0, 0, 0, 0, 10, 'a'}, io.ErrUnexpectedEOF, nil, compressionNone}, 104 // Check that messages with length >= 2^24 are parsed. 105 {append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone}, 106 } { 107 buf := &fullReader{test.p} 108 parser := &parser{r: buf, bufferPool: mem.DefaultBufferPool()} 109 pt, b, err := parser.recvMsg(math.MaxInt32) 110 if err != test.err || !bytes.Equal(b.Materialize(), test.b) || pt != test.pt { 111 t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err) 112 } 113 } 114 } 115 116 func (s) TestMultipleParsing(t *testing.T) { 117 // Set a byte stream consists of 3 messages with their headers. 118 p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'} 119 b := &fullReader{p} 120 parser := &parser{r: b, bufferPool: mem.DefaultBufferPool()} 121 122 wantRecvs := []struct { 123 pt payloadFormat 124 data []byte 125 }{ 126 {compressionNone, []byte("a")}, 127 {compressionNone, []byte("bc")}, 128 {compressionNone, []byte("d")}, 129 } 130 for i, want := range wantRecvs { 131 pt, data, err := parser.recvMsg(math.MaxInt32) 132 if err != nil || pt != want.pt || !reflect.DeepEqual(data.Materialize(), want.data) { 133 t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, <nil>", 134 i, p, pt, data, err, want.pt, want.data) 135 } 136 } 137 138 pt, data, err := parser.recvMsg(math.MaxInt32) 139 if err != io.EOF { 140 t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v", 141 len(wantRecvs), p, pt, data, err, io.EOF) 142 } 143 } 144 145 func (s) TestEncode(t *testing.T) { 146 for _, test := range []struct { 147 // input 148 msg proto.Message 149 // outputs 150 hdr []byte 151 data []byte 152 err error 153 }{ 154 {nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil}, 155 } { 156 data, err := encode(getCodec(protoenc.Name), test.msg) 157 if err != test.err || !bytes.Equal(data.Materialize(), test.data) { 158 t.Errorf("encode(_, %v) = %v, %v; want %v, %v", test.msg, data, err, test.data, test.err) 159 continue 160 } 161 if hdr, _ := msgHeader(data, nil, compressionNone); !bytes.Equal(hdr, test.hdr) { 162 t.Errorf("msgHeader(%v, false) = %v; want %v", data, hdr, test.hdr) 163 } 164 } 165 } 166 167 func (s) TestCompress(t *testing.T) { 168 bestCompressor, err := NewGZIPCompressorWithLevel(gzip.BestCompression) 169 if err != nil { 170 t.Fatalf("Could not initialize gzip compressor with best compression.") 171 } 172 bestSpeedCompressor, err := NewGZIPCompressorWithLevel(gzip.BestSpeed) 173 if err != nil { 174 t.Fatalf("Could not initialize gzip compressor with best speed compression.") 175 } 176 177 defaultCompressor, err := NewGZIPCompressorWithLevel(gzip.BestSpeed) 178 if err != nil { 179 t.Fatalf("Could not initialize gzip compressor with default compression.") 180 } 181 182 level5, err := NewGZIPCompressorWithLevel(5) 183 if err != nil { 184 t.Fatalf("Could not initialize gzip compressor with level 5 compression.") 185 } 186 187 for _, test := range []struct { 188 // input 189 data []byte 190 cp Compressor 191 dc Decompressor 192 // outputs 193 err error 194 }{ 195 {make([]byte, 1024), NewGZIPCompressor(), NewGZIPDecompressor(), nil}, 196 {make([]byte, 1024), bestCompressor, NewGZIPDecompressor(), nil}, 197 {make([]byte, 1024), bestSpeedCompressor, NewGZIPDecompressor(), nil}, 198 {make([]byte, 1024), defaultCompressor, NewGZIPDecompressor(), nil}, 199 {make([]byte, 1024), level5, NewGZIPDecompressor(), nil}, 200 } { 201 b := new(bytes.Buffer) 202 if err := test.cp.Do(b, test.data); err != test.err { 203 t.Fatalf("Compressor.Do(_, %v) = %v, want %v", test.data, err, test.err) 204 } 205 if b.Len() >= len(test.data) { 206 t.Fatalf("The compressor fails to compress data.") 207 } 208 if p, err := test.dc.Do(b); err != nil || !bytes.Equal(test.data, p) { 209 t.Fatalf("Decompressor.Do(%v) = %v, %v, want %v, <nil>", b, p, err, test.data) 210 } 211 } 212 } 213 214 func (s) TestToRPCErr(t *testing.T) { 215 for _, test := range []struct { 216 // input 217 errIn error 218 // outputs 219 errOut error 220 }{ 221 {transport.ErrConnClosing, status.Error(codes.Unavailable, transport.ErrConnClosing.Desc)}, 222 {io.ErrUnexpectedEOF, status.Error(codes.Internal, io.ErrUnexpectedEOF.Error())}, 223 } { 224 err := toRPCErr(test.errIn) 225 if _, ok := status.FromError(err); !ok { 226 t.Errorf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, status.Error) 227 } 228 if !testutils.StatusErrEqual(err, test.errOut) { 229 t.Errorf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) 230 } 231 } 232 } 233 234 // bmEncode benchmarks encoding a Protocol Buffer message containing mSize 235 // bytes. 236 func bmEncode(b *testing.B, mSize int) { 237 cdc := getCodec(protoenc.Name) 238 msg := &perfpb.Buffer{Body: make([]byte, mSize)} 239 encodeData, _ := encode(cdc, msg) 240 encodedSz := int64(len(encodeData)) 241 b.ReportAllocs() 242 b.ResetTimer() 243 for i := 0; i < b.N; i++ { 244 encode(cdc, msg) 245 } 246 b.SetBytes(encodedSz) 247 } 248 249 func BenchmarkEncode1B(b *testing.B) { 250 bmEncode(b, 1) 251 } 252 253 func BenchmarkEncode1KiB(b *testing.B) { 254 bmEncode(b, 1024) 255 } 256 257 func BenchmarkEncode8KiB(b *testing.B) { 258 bmEncode(b, 8*1024) 259 } 260 261 func BenchmarkEncode64KiB(b *testing.B) { 262 bmEncode(b, 64*1024) 263 } 264 265 func BenchmarkEncode512KiB(b *testing.B) { 266 bmEncode(b, 512*1024) 267 } 268 269 func BenchmarkEncode1MiB(b *testing.B) { 270 bmEncode(b, 1024*1024) 271 } 272 273 // bmCompressor benchmarks a compressor of a Protocol Buffer message containing 274 // mSize bytes. 275 func bmCompressor(b *testing.B, mSize int, cp Compressor) { 276 payload := make([]byte, mSize) 277 cBuf := bytes.NewBuffer(make([]byte, mSize)) 278 b.ReportAllocs() 279 b.ResetTimer() 280 for i := 0; i < b.N; i++ { 281 cp.Do(cBuf, payload) 282 cBuf.Reset() 283 } 284 } 285 286 func BenchmarkGZIPCompressor1B(b *testing.B) { 287 bmCompressor(b, 1, NewGZIPCompressor()) 288 } 289 290 func BenchmarkGZIPCompressor1KiB(b *testing.B) { 291 bmCompressor(b, 1024, NewGZIPCompressor()) 292 } 293 294 func BenchmarkGZIPCompressor8KiB(b *testing.B) { 295 bmCompressor(b, 8*1024, NewGZIPCompressor()) 296 } 297 298 func BenchmarkGZIPCompressor64KiB(b *testing.B) { 299 bmCompressor(b, 64*1024, NewGZIPCompressor()) 300 } 301 302 func BenchmarkGZIPCompressor512KiB(b *testing.B) { 303 bmCompressor(b, 512*1024, NewGZIPCompressor()) 304 } 305 306 func BenchmarkGZIPCompressor1MiB(b *testing.B) { 307 bmCompressor(b, 1024*1024, NewGZIPCompressor()) 308 } 309 310 // compressWithDeterministicError compresses the input data and returns a BufferSlice. 311 func compressWithDeterministicError(t *testing.T, input []byte) mem.BufferSlice { 312 t.Helper() 313 var buf bytes.Buffer 314 gz := gzip.NewWriter(&buf) 315 if _, err := gz.Write(input); err != nil { 316 t.Fatalf("compressInput() failed to write data: %v", err) 317 } 318 if err := gz.Close(); err != nil { 319 t.Fatalf("compressInput() failed to close gzip writer: %v", err) 320 } 321 compressedData := buf.Bytes() 322 return mem.BufferSlice{mem.NewBuffer(&compressedData, nil)} 323 } 324 325 // MockDecompressor is a mock implementation of a decompressor used for testing purposes. 326 // It simulates decompression behavior, returning either decompressed data or an error based on the ShouldError flag. 327 type MockDecompressor struct { 328 ShouldError bool // Flag to control whether the decompression should simulate an error. 329 } 330 331 // Do simulates decompression. It returns a predefined error if ShouldError is true, 332 // or a fixed set of decompressed data if ShouldError is false. 333 func (m *MockDecompressor) Do(_ io.Reader) ([]byte, error) { 334 if m.ShouldError { 335 return nil, errors.New(decompressionErrorMsg) 336 } 337 return []byte(defaultDecompressedData), nil 338 } 339 340 // Type returns the string identifier for the MockDecompressor. 341 func (m *MockDecompressor) Type() string { 342 return "MockDecompressor" 343 } 344 345 // TestDecompress tests the decompress function behaves correctly for following scenarios 346 // decompress successfully when message is <= maxReceiveMessageSize 347 // errors when message > maxReceiveMessageSize 348 // decompress successfully when maxReceiveMessageSize is MaxInt 349 // errors when the decompressed message has an invalid format 350 // errors when the decompressed message exceeds the maxReceiveMessageSize. 351 func (s) TestDecompress(t *testing.T) { 352 compressor := encoding.GetCompressor("gzip") 353 validDecompressor := &MockDecompressor{ShouldError: false} 354 invalidFormatDecompressor := &MockDecompressor{ShouldError: true} 355 356 testCases := []struct { 357 name string 358 input mem.BufferSlice 359 dc Decompressor 360 maxReceiveMessageSize int 361 want []byte 362 wantErr error 363 }{ 364 { 365 name: "Decompresses successfully with sufficient buffer size", 366 input: compressWithDeterministicError(t, []byte("decompressed data")), 367 dc: nil, 368 maxReceiveMessageSize: 50, 369 want: []byte("decompressed data"), 370 wantErr: nil, 371 }, 372 { 373 name: "Fails due to exceeding maxReceiveMessageSize", 374 input: compressWithDeterministicError(t, []byte("message that is too large")), 375 dc: nil, 376 maxReceiveMessageSize: len("message that is too large") - 1, 377 want: nil, 378 wantErr: status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", len("message that is too large")-1), 379 }, 380 { 381 name: "Decompresses to exactly maxReceiveMessageSize", 382 input: compressWithDeterministicError(t, []byte("exact size message")), 383 dc: nil, 384 maxReceiveMessageSize: len("exact size message"), 385 want: []byte("exact size message"), 386 wantErr: nil, 387 }, 388 { 389 name: "Decompresses successfully with maxReceiveMessageSize MaxInt", 390 input: compressWithDeterministicError(t, []byte("large message")), 391 dc: nil, 392 maxReceiveMessageSize: math.MaxInt, 393 want: []byte("large message"), 394 wantErr: nil, 395 }, 396 { 397 name: "Fails with decompression error due to invalid format", 398 input: compressWithDeterministicError(t, []byte("invalid compressed data")), 399 dc: invalidFormatDecompressor, 400 maxReceiveMessageSize: 50, 401 want: nil, 402 wantErr: status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", errors.New(decompressionErrorMsg)), 403 }, 404 { 405 name: "Fails with resourceExhausted error when decompressed message exceeds maxReceiveMessageSize", 406 input: compressWithDeterministicError(t, []byte("large compressed data")), 407 dc: validDecompressor, 408 maxReceiveMessageSize: 20, 409 want: nil, 410 wantErr: status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", 25, 20), 411 }, 412 } 413 414 for _, tc := range testCases { 415 t.Run(tc.name, func(t *testing.T) { 416 output, err := decompress(compressor, tc.input, tc.dc, tc.maxReceiveMessageSize, mem.DefaultBufferPool()) 417 if !cmp.Equal(err, tc.wantErr, cmpopts.EquateErrors()) { 418 t.Fatalf("decompress() err = %v, wantErr = %v", err, tc.wantErr) 419 } 420 if !cmp.Equal(tc.want, output.Materialize()) { 421 t.Fatalf("decompress() output mismatch: got = %v, want = %v", output.Materialize(), tc.want) 422 } 423 }) 424 } 425 } 426 427 type mockCompressor struct { 428 // Written to by the io.Reader on every call to Read. 429 ch chan<- struct{} 430 } 431 432 func (m *mockCompressor) Compress(io.Writer) (io.WriteCloser, error) { 433 panic("unimplemented") 434 } 435 436 func (m *mockCompressor) Decompress(io.Reader) (io.Reader, error) { 437 return m, nil 438 } 439 440 func (m *mockCompressor) Read([]byte) (int, error) { 441 m.ch <- struct{}{} 442 return 1, io.EOF 443 } 444 445 func (m *mockCompressor) Name() string { return "" } 446 447 // Tests that the decompressor's Read method is not called after it returns EOF. 448 func (s) TestDecompress_NoReadAfterEOF(t *testing.T) { 449 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 450 defer cancel() 451 452 ch := make(chan struct{}, 10) 453 mc := &mockCompressor{ch: ch} 454 in := mem.BufferSlice{mem.NewBuffer(&[]byte{1, 2, 3}, nil)} 455 wg := sync.WaitGroup{} 456 wg.Add(1) 457 go func() { 458 defer wg.Done() 459 out, err := decompress(mc, in, nil, 1, mem.DefaultBufferPool()) 460 if err != nil { 461 t.Errorf("Unexpected error from decompress: %v", err) 462 return 463 } 464 out.Free() 465 }() 466 select { 467 case <-ch: 468 case <-ctx.Done(): 469 t.Fatalf("Timed out waiting for call to compressor") 470 } 471 ctx, cancel = context.WithTimeout(ctx, defaultTestShortTimeout) 472 defer cancel() 473 select { 474 case <-ch: 475 t.Fatalf("Unexpected second compressor.Read call detected") 476 case <-ctx.Done(): 477 } 478 wg.Wait() 479 }