google.golang.org/grpc@v1.72.2/rpc_util_test.go (about)

     1  /*
     2   *
     3   * Copyright 2014 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package grpc
    20  
    21  import (
    22  	"bytes"
    23  	"compress/gzip"
    24  	"context"
    25  	"errors"
    26  	"io"
    27  	"math"
    28  	"reflect"
    29  	"sync"
    30  	"testing"
    31  
    32  	"github.com/google/go-cmp/cmp"
    33  	"github.com/google/go-cmp/cmp/cmpopts"
    34  	"google.golang.org/grpc/codes"
    35  	"google.golang.org/grpc/encoding"
    36  	_ "google.golang.org/grpc/encoding/gzip"
    37  	protoenc "google.golang.org/grpc/encoding/proto"
    38  	"google.golang.org/grpc/internal/testutils"
    39  	"google.golang.org/grpc/internal/transport"
    40  	"google.golang.org/grpc/mem"
    41  	"google.golang.org/grpc/status"
    42  	perfpb "google.golang.org/grpc/test/codec_perf"
    43  	"google.golang.org/protobuf/proto"
    44  )
    45  
    46  const (
    47  	defaultDecompressedData = "default decompressed data"
    48  	decompressionErrorMsg   = "invalid compression format"
    49  )
    50  
    51  type fullReader struct {
    52  	data []byte
    53  }
    54  
    55  func (f *fullReader) ReadMessageHeader(header []byte) error {
    56  	buf, err := f.Read(len(header))
    57  	defer buf.Free()
    58  	if err != nil {
    59  		return err
    60  	}
    61  
    62  	buf.CopyTo(header)
    63  	return nil
    64  }
    65  
    66  func (f *fullReader) Read(n int) (mem.BufferSlice, error) {
    67  	if n == 0 {
    68  		return nil, nil
    69  	}
    70  
    71  	if len(f.data) == 0 {
    72  		return nil, io.EOF
    73  	}
    74  
    75  	if len(f.data) < n {
    76  		data := f.data
    77  		f.data = nil
    78  		return mem.BufferSlice{mem.SliceBuffer(data)}, io.ErrUnexpectedEOF
    79  	}
    80  
    81  	buf := f.data[:n]
    82  	f.data = f.data[n:]
    83  
    84  	return mem.BufferSlice{mem.SliceBuffer(buf)}, nil
    85  }
    86  
    87  var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface
    88  
    89  func (s) TestSimpleParsing(t *testing.T) {
    90  	bigMsg := bytes.Repeat([]byte{'x'}, 1<<24)
    91  	for _, test := range []struct {
    92  		// input
    93  		p []byte
    94  		// outputs
    95  		err error
    96  		b   []byte
    97  		pt  payloadFormat
    98  	}{
    99  		{nil, io.EOF, nil, compressionNone},
   100  		{[]byte{0, 0, 0, 0, 0}, nil, nil, compressionNone},
   101  		{[]byte{0, 0, 0, 0, 1, 'a'}, nil, []byte{'a'}, compressionNone},
   102  		{[]byte{1, 0}, io.ErrUnexpectedEOF, nil, compressionNone},
   103  		{[]byte{0, 0, 0, 0, 10, 'a'}, io.ErrUnexpectedEOF, nil, compressionNone},
   104  		// Check that messages with length >= 2^24 are parsed.
   105  		{append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone},
   106  	} {
   107  		buf := &fullReader{test.p}
   108  		parser := &parser{r: buf, bufferPool: mem.DefaultBufferPool()}
   109  		pt, b, err := parser.recvMsg(math.MaxInt32)
   110  		if err != test.err || !bytes.Equal(b.Materialize(), test.b) || pt != test.pt {
   111  			t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
   112  		}
   113  	}
   114  }
   115  
   116  func (s) TestMultipleParsing(t *testing.T) {
   117  	// Set a byte stream consists of 3 messages with their headers.
   118  	p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'}
   119  	b := &fullReader{p}
   120  	parser := &parser{r: b, bufferPool: mem.DefaultBufferPool()}
   121  
   122  	wantRecvs := []struct {
   123  		pt   payloadFormat
   124  		data []byte
   125  	}{
   126  		{compressionNone, []byte("a")},
   127  		{compressionNone, []byte("bc")},
   128  		{compressionNone, []byte("d")},
   129  	}
   130  	for i, want := range wantRecvs {
   131  		pt, data, err := parser.recvMsg(math.MaxInt32)
   132  		if err != nil || pt != want.pt || !reflect.DeepEqual(data.Materialize(), want.data) {
   133  			t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, <nil>",
   134  				i, p, pt, data, err, want.pt, want.data)
   135  		}
   136  	}
   137  
   138  	pt, data, err := parser.recvMsg(math.MaxInt32)
   139  	if err != io.EOF {
   140  		t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v",
   141  			len(wantRecvs), p, pt, data, err, io.EOF)
   142  	}
   143  }
   144  
   145  func (s) TestEncode(t *testing.T) {
   146  	for _, test := range []struct {
   147  		// input
   148  		msg proto.Message
   149  		// outputs
   150  		hdr  []byte
   151  		data []byte
   152  		err  error
   153  	}{
   154  		{nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil},
   155  	} {
   156  		data, err := encode(getCodec(protoenc.Name), test.msg)
   157  		if err != test.err || !bytes.Equal(data.Materialize(), test.data) {
   158  			t.Errorf("encode(_, %v) = %v, %v; want %v, %v", test.msg, data, err, test.data, test.err)
   159  			continue
   160  		}
   161  		if hdr, _ := msgHeader(data, nil, compressionNone); !bytes.Equal(hdr, test.hdr) {
   162  			t.Errorf("msgHeader(%v, false) = %v; want %v", data, hdr, test.hdr)
   163  		}
   164  	}
   165  }
   166  
   167  func (s) TestCompress(t *testing.T) {
   168  	bestCompressor, err := NewGZIPCompressorWithLevel(gzip.BestCompression)
   169  	if err != nil {
   170  		t.Fatalf("Could not initialize gzip compressor with best compression.")
   171  	}
   172  	bestSpeedCompressor, err := NewGZIPCompressorWithLevel(gzip.BestSpeed)
   173  	if err != nil {
   174  		t.Fatalf("Could not initialize gzip compressor with best speed compression.")
   175  	}
   176  
   177  	defaultCompressor, err := NewGZIPCompressorWithLevel(gzip.BestSpeed)
   178  	if err != nil {
   179  		t.Fatalf("Could not initialize gzip compressor with default compression.")
   180  	}
   181  
   182  	level5, err := NewGZIPCompressorWithLevel(5)
   183  	if err != nil {
   184  		t.Fatalf("Could not initialize gzip compressor with level 5 compression.")
   185  	}
   186  
   187  	for _, test := range []struct {
   188  		// input
   189  		data []byte
   190  		cp   Compressor
   191  		dc   Decompressor
   192  		// outputs
   193  		err error
   194  	}{
   195  		{make([]byte, 1024), NewGZIPCompressor(), NewGZIPDecompressor(), nil},
   196  		{make([]byte, 1024), bestCompressor, NewGZIPDecompressor(), nil},
   197  		{make([]byte, 1024), bestSpeedCompressor, NewGZIPDecompressor(), nil},
   198  		{make([]byte, 1024), defaultCompressor, NewGZIPDecompressor(), nil},
   199  		{make([]byte, 1024), level5, NewGZIPDecompressor(), nil},
   200  	} {
   201  		b := new(bytes.Buffer)
   202  		if err := test.cp.Do(b, test.data); err != test.err {
   203  			t.Fatalf("Compressor.Do(_, %v) = %v, want %v", test.data, err, test.err)
   204  		}
   205  		if b.Len() >= len(test.data) {
   206  			t.Fatalf("The compressor fails to compress data.")
   207  		}
   208  		if p, err := test.dc.Do(b); err != nil || !bytes.Equal(test.data, p) {
   209  			t.Fatalf("Decompressor.Do(%v) = %v, %v, want %v, <nil>", b, p, err, test.data)
   210  		}
   211  	}
   212  }
   213  
   214  func (s) TestToRPCErr(t *testing.T) {
   215  	for _, test := range []struct {
   216  		// input
   217  		errIn error
   218  		// outputs
   219  		errOut error
   220  	}{
   221  		{transport.ErrConnClosing, status.Error(codes.Unavailable, transport.ErrConnClosing.Desc)},
   222  		{io.ErrUnexpectedEOF, status.Error(codes.Internal, io.ErrUnexpectedEOF.Error())},
   223  	} {
   224  		err := toRPCErr(test.errIn)
   225  		if _, ok := status.FromError(err); !ok {
   226  			t.Errorf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, status.Error)
   227  		}
   228  		if !testutils.StatusErrEqual(err, test.errOut) {
   229  			t.Errorf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut)
   230  		}
   231  	}
   232  }
   233  
   234  // bmEncode benchmarks encoding a Protocol Buffer message containing mSize
   235  // bytes.
   236  func bmEncode(b *testing.B, mSize int) {
   237  	cdc := getCodec(protoenc.Name)
   238  	msg := &perfpb.Buffer{Body: make([]byte, mSize)}
   239  	encodeData, _ := encode(cdc, msg)
   240  	encodedSz := int64(len(encodeData))
   241  	b.ReportAllocs()
   242  	b.ResetTimer()
   243  	for i := 0; i < b.N; i++ {
   244  		encode(cdc, msg)
   245  	}
   246  	b.SetBytes(encodedSz)
   247  }
   248  
   249  func BenchmarkEncode1B(b *testing.B) {
   250  	bmEncode(b, 1)
   251  }
   252  
   253  func BenchmarkEncode1KiB(b *testing.B) {
   254  	bmEncode(b, 1024)
   255  }
   256  
   257  func BenchmarkEncode8KiB(b *testing.B) {
   258  	bmEncode(b, 8*1024)
   259  }
   260  
   261  func BenchmarkEncode64KiB(b *testing.B) {
   262  	bmEncode(b, 64*1024)
   263  }
   264  
   265  func BenchmarkEncode512KiB(b *testing.B) {
   266  	bmEncode(b, 512*1024)
   267  }
   268  
   269  func BenchmarkEncode1MiB(b *testing.B) {
   270  	bmEncode(b, 1024*1024)
   271  }
   272  
   273  // bmCompressor benchmarks a compressor of a Protocol Buffer message containing
   274  // mSize bytes.
   275  func bmCompressor(b *testing.B, mSize int, cp Compressor) {
   276  	payload := make([]byte, mSize)
   277  	cBuf := bytes.NewBuffer(make([]byte, mSize))
   278  	b.ReportAllocs()
   279  	b.ResetTimer()
   280  	for i := 0; i < b.N; i++ {
   281  		cp.Do(cBuf, payload)
   282  		cBuf.Reset()
   283  	}
   284  }
   285  
   286  func BenchmarkGZIPCompressor1B(b *testing.B) {
   287  	bmCompressor(b, 1, NewGZIPCompressor())
   288  }
   289  
   290  func BenchmarkGZIPCompressor1KiB(b *testing.B) {
   291  	bmCompressor(b, 1024, NewGZIPCompressor())
   292  }
   293  
   294  func BenchmarkGZIPCompressor8KiB(b *testing.B) {
   295  	bmCompressor(b, 8*1024, NewGZIPCompressor())
   296  }
   297  
   298  func BenchmarkGZIPCompressor64KiB(b *testing.B) {
   299  	bmCompressor(b, 64*1024, NewGZIPCompressor())
   300  }
   301  
   302  func BenchmarkGZIPCompressor512KiB(b *testing.B) {
   303  	bmCompressor(b, 512*1024, NewGZIPCompressor())
   304  }
   305  
   306  func BenchmarkGZIPCompressor1MiB(b *testing.B) {
   307  	bmCompressor(b, 1024*1024, NewGZIPCompressor())
   308  }
   309  
   310  // compressWithDeterministicError compresses the input data and returns a BufferSlice.
   311  func compressWithDeterministicError(t *testing.T, input []byte) mem.BufferSlice {
   312  	t.Helper()
   313  	var buf bytes.Buffer
   314  	gz := gzip.NewWriter(&buf)
   315  	if _, err := gz.Write(input); err != nil {
   316  		t.Fatalf("compressInput() failed to write data: %v", err)
   317  	}
   318  	if err := gz.Close(); err != nil {
   319  		t.Fatalf("compressInput() failed to close gzip writer: %v", err)
   320  	}
   321  	compressedData := buf.Bytes()
   322  	return mem.BufferSlice{mem.NewBuffer(&compressedData, nil)}
   323  }
   324  
   325  // MockDecompressor is a mock implementation of a decompressor used for testing purposes.
   326  // It simulates decompression behavior, returning either decompressed data or an error based on the ShouldError flag.
   327  type MockDecompressor struct {
   328  	ShouldError bool // Flag to control whether the decompression should simulate an error.
   329  }
   330  
   331  // Do simulates decompression. It returns a predefined error if ShouldError is true,
   332  // or a fixed set of decompressed data if ShouldError is false.
   333  func (m *MockDecompressor) Do(_ io.Reader) ([]byte, error) {
   334  	if m.ShouldError {
   335  		return nil, errors.New(decompressionErrorMsg)
   336  	}
   337  	return []byte(defaultDecompressedData), nil
   338  }
   339  
   340  // Type returns the string identifier for the MockDecompressor.
   341  func (m *MockDecompressor) Type() string {
   342  	return "MockDecompressor"
   343  }
   344  
   345  // TestDecompress tests the decompress function behaves correctly for following scenarios
   346  // decompress successfully when message is <= maxReceiveMessageSize
   347  // errors when message > maxReceiveMessageSize
   348  // decompress successfully when maxReceiveMessageSize is MaxInt
   349  // errors when the decompressed message has an invalid format
   350  // errors when the decompressed message exceeds the maxReceiveMessageSize.
   351  func (s) TestDecompress(t *testing.T) {
   352  	compressor := encoding.GetCompressor("gzip")
   353  	validDecompressor := &MockDecompressor{ShouldError: false}
   354  	invalidFormatDecompressor := &MockDecompressor{ShouldError: true}
   355  
   356  	testCases := []struct {
   357  		name                  string
   358  		input                 mem.BufferSlice
   359  		dc                    Decompressor
   360  		maxReceiveMessageSize int
   361  		want                  []byte
   362  		wantErr               error
   363  	}{
   364  		{
   365  			name:                  "Decompresses successfully with sufficient buffer size",
   366  			input:                 compressWithDeterministicError(t, []byte("decompressed data")),
   367  			dc:                    nil,
   368  			maxReceiveMessageSize: 50,
   369  			want:                  []byte("decompressed data"),
   370  			wantErr:               nil,
   371  		},
   372  		{
   373  			name:                  "Fails due to exceeding maxReceiveMessageSize",
   374  			input:                 compressWithDeterministicError(t, []byte("message that is too large")),
   375  			dc:                    nil,
   376  			maxReceiveMessageSize: len("message that is too large") - 1,
   377  			want:                  nil,
   378  			wantErr:               status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", len("message that is too large")-1),
   379  		},
   380  		{
   381  			name:                  "Decompresses to exactly maxReceiveMessageSize",
   382  			input:                 compressWithDeterministicError(t, []byte("exact size message")),
   383  			dc:                    nil,
   384  			maxReceiveMessageSize: len("exact size message"),
   385  			want:                  []byte("exact size message"),
   386  			wantErr:               nil,
   387  		},
   388  		{
   389  			name:                  "Decompresses successfully with maxReceiveMessageSize MaxInt",
   390  			input:                 compressWithDeterministicError(t, []byte("large message")),
   391  			dc:                    nil,
   392  			maxReceiveMessageSize: math.MaxInt,
   393  			want:                  []byte("large message"),
   394  			wantErr:               nil,
   395  		},
   396  		{
   397  			name:                  "Fails with decompression error due to invalid format",
   398  			input:                 compressWithDeterministicError(t, []byte("invalid compressed data")),
   399  			dc:                    invalidFormatDecompressor,
   400  			maxReceiveMessageSize: 50,
   401  			want:                  nil,
   402  			wantErr:               status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", errors.New(decompressionErrorMsg)),
   403  		},
   404  		{
   405  			name:                  "Fails with resourceExhausted error when decompressed message exceeds maxReceiveMessageSize",
   406  			input:                 compressWithDeterministicError(t, []byte("large compressed data")),
   407  			dc:                    validDecompressor,
   408  			maxReceiveMessageSize: 20,
   409  			want:                  nil,
   410  			wantErr:               status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", 25, 20),
   411  		},
   412  	}
   413  
   414  	for _, tc := range testCases {
   415  		t.Run(tc.name, func(t *testing.T) {
   416  			output, err := decompress(compressor, tc.input, tc.dc, tc.maxReceiveMessageSize, mem.DefaultBufferPool())
   417  			if !cmp.Equal(err, tc.wantErr, cmpopts.EquateErrors()) {
   418  				t.Fatalf("decompress() err = %v, wantErr = %v", err, tc.wantErr)
   419  			}
   420  			if !cmp.Equal(tc.want, output.Materialize()) {
   421  				t.Fatalf("decompress() output mismatch: got = %v, want = %v", output.Materialize(), tc.want)
   422  			}
   423  		})
   424  	}
   425  }
   426  
   427  type mockCompressor struct {
   428  	// Written to by the io.Reader on every call to Read.
   429  	ch chan<- struct{}
   430  }
   431  
   432  func (m *mockCompressor) Compress(io.Writer) (io.WriteCloser, error) {
   433  	panic("unimplemented")
   434  }
   435  
   436  func (m *mockCompressor) Decompress(io.Reader) (io.Reader, error) {
   437  	return m, nil
   438  }
   439  
   440  func (m *mockCompressor) Read([]byte) (int, error) {
   441  	m.ch <- struct{}{}
   442  	return 1, io.EOF
   443  }
   444  
   445  func (m *mockCompressor) Name() string { return "" }
   446  
   447  // Tests that the decompressor's Read method is not called after it returns EOF.
   448  func (s) TestDecompress_NoReadAfterEOF(t *testing.T) {
   449  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   450  	defer cancel()
   451  
   452  	ch := make(chan struct{}, 10)
   453  	mc := &mockCompressor{ch: ch}
   454  	in := mem.BufferSlice{mem.NewBuffer(&[]byte{1, 2, 3}, nil)}
   455  	wg := sync.WaitGroup{}
   456  	wg.Add(1)
   457  	go func() {
   458  		defer wg.Done()
   459  		out, err := decompress(mc, in, nil, 1, mem.DefaultBufferPool())
   460  		if err != nil {
   461  			t.Errorf("Unexpected error from decompress: %v", err)
   462  			return
   463  		}
   464  		out.Free()
   465  	}()
   466  	select {
   467  	case <-ch:
   468  	case <-ctx.Done():
   469  		t.Fatalf("Timed out waiting for call to compressor")
   470  	}
   471  	ctx, cancel = context.WithTimeout(ctx, defaultTestShortTimeout)
   472  	defer cancel()
   473  	select {
   474  	case <-ch:
   475  		t.Fatalf("Unexpected second compressor.Read call detected")
   476  	case <-ctx.Done():
   477  	}
   478  	wg.Wait()
   479  }