google.golang.org/grpc@v1.72.2/test/compressor_test.go (about)

     1  /*
     2   *
     3   * Copyright 2023 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package test
    20  
    21  import (
    22  	"bytes"
    23  	"compress/gzip"
    24  	"context"
    25  	"io"
    26  	"reflect"
    27  	"strings"
    28  	"sync/atomic"
    29  	"testing"
    30  
    31  	"google.golang.org/grpc"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/credentials/insecure"
    34  	"google.golang.org/grpc/encoding"
    35  	"google.golang.org/grpc/internal/stubserver"
    36  	"google.golang.org/grpc/metadata"
    37  	"google.golang.org/grpc/status"
    38  
    39  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    40  	testpb "google.golang.org/grpc/interop/grpc_testing"
    41  )
    42  
    43  // TestUnsupportedEncodingResponse validates gRPC status codes
    44  // for different client-server compression setups
    45  // ensuring the correct behavior when compression is enabled or disabled on either side.
    46  func (s) TestUnsupportedEncodingResponse(t *testing.T) {
    47  	tests := []struct {
    48  		name           string
    49  		clientCompress bool
    50  		serverCompress bool
    51  		wantStatus     codes.Code
    52  	}{
    53  		{
    54  			name:           "client_server_compression",
    55  			clientCompress: true,
    56  			serverCompress: true,
    57  			wantStatus:     codes.OK,
    58  		},
    59  		{
    60  			name:           "client_compression",
    61  			clientCompress: true,
    62  			serverCompress: false,
    63  			wantStatus:     codes.Unimplemented,
    64  		},
    65  		{
    66  			name:           "server_compression",
    67  			clientCompress: false,
    68  			serverCompress: true,
    69  			wantStatus:     codes.Internal,
    70  		},
    71  	}
    72  
    73  	for _, test := range tests {
    74  		t.Run(test.name, func(t *testing.T) {
    75  			ss := &stubserver.StubServer{
    76  				UnaryCallF: func(_ context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
    77  					return &testpb.SimpleResponse{Payload: in.Payload}, nil
    78  				},
    79  			}
    80  			sopts := []grpc.ServerOption{}
    81  			if test.serverCompress {
    82  				// Using deprecated methods to selectively apply compression
    83  				// only on the server side. With encoding.registerCompressor(),
    84  				// the compressor is applied globally, affecting client and server
    85  				sopts = append(sopts, grpc.RPCCompressor(newNopCompressor()), grpc.RPCDecompressor(newNopDecompressor()))
    86  			}
    87  			if err := ss.StartServer(sopts...); err != nil {
    88  				t.Fatalf("Error starting server: %v", err)
    89  			}
    90  			defer ss.Stop()
    91  
    92  			dopts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
    93  			if test.clientCompress {
    94  				// UseCompressor() requires the compressor to be registered
    95  				// using encoding.RegisterCompressor() which applies compressor globally,
    96  				// Hence, using deprecated WithCompressor() and WithDecompressor()
    97  				// to apply compression only on client.
    98  				dopts = append(dopts, grpc.WithCompressor(newNopCompressor()), grpc.WithDecompressor(newNopDecompressor()))
    99  			}
   100  			if err := ss.StartClient(dopts...); err != nil {
   101  				t.Fatalf("Error starting client: %v", err)
   102  			}
   103  
   104  			payload := &testpb.SimpleRequest{
   105  				Payload: &testpb.Payload{
   106  					Body: []byte("test message"),
   107  				},
   108  			}
   109  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   110  			defer cancel()
   111  			_, err := ss.Client.UnaryCall(ctx, payload)
   112  			if got, want := status.Code(err), test.wantStatus; got != want {
   113  				t.Errorf("Client.UnaryCall() = %v, want %v", got, want)
   114  			}
   115  		})
   116  	}
   117  }
   118  
   119  func (s) TestCompressServerHasNoSupport(t *testing.T) {
   120  	for _, e := range listTestEnv() {
   121  		testCompressServerHasNoSupport(t, e)
   122  	}
   123  }
   124  
   125  func testCompressServerHasNoSupport(t *testing.T, e env) {
   126  	te := newTest(t, e)
   127  	te.serverCompression = false
   128  	te.clientCompression = false
   129  	te.clientNopCompression = true
   130  	te.startServer(&testServer{security: e.security})
   131  	defer te.tearDown()
   132  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   133  
   134  	const argSize = 271828
   135  	const respSize = 314159
   136  	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
   137  	if err != nil {
   138  		t.Fatal(err)
   139  	}
   140  	req := &testpb.SimpleRequest{
   141  		ResponseType: testpb.PayloadType_COMPRESSABLE,
   142  		ResponseSize: respSize,
   143  		Payload:      payload,
   144  	}
   145  
   146  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   147  	defer cancel()
   148  	if _, err := tc.UnaryCall(ctx, req); err == nil || status.Code(err) != codes.Unimplemented {
   149  		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code %s", err, codes.Unimplemented)
   150  	}
   151  	// Streaming RPC
   152  	stream, err := tc.FullDuplexCall(ctx)
   153  	if err != nil {
   154  		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
   155  	}
   156  	if _, err := stream.Recv(); err == nil || status.Code(err) != codes.Unimplemented {
   157  		t.Fatalf("%v.Recv() = %v, want error code %s", stream, err, codes.Unimplemented)
   158  	}
   159  }
   160  
   161  func (s) TestCompressOK(t *testing.T) {
   162  	for _, e := range listTestEnv() {
   163  		testCompressOK(t, e)
   164  	}
   165  }
   166  
   167  func testCompressOK(t *testing.T, e env) {
   168  	te := newTest(t, e)
   169  	te.serverCompression = true
   170  	te.clientCompression = true
   171  	te.startServer(&testServer{security: e.security})
   172  	defer te.tearDown()
   173  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   174  
   175  	// Unary call
   176  	const argSize = 271828
   177  	const respSize = 314159
   178  	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
   179  	if err != nil {
   180  		t.Fatal(err)
   181  	}
   182  	req := &testpb.SimpleRequest{
   183  		ResponseType: testpb.PayloadType_COMPRESSABLE,
   184  		ResponseSize: respSize,
   185  		Payload:      payload,
   186  	}
   187  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   188  	defer cancel()
   189  	ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("something", "something"))
   190  	if _, err := tc.UnaryCall(ctx, req); err != nil {
   191  		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
   192  	}
   193  	// Streaming RPC
   194  	stream, err := tc.FullDuplexCall(ctx)
   195  	if err != nil {
   196  		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
   197  	}
   198  	respParam := []*testpb.ResponseParameters{
   199  		{
   200  			Size: 31415,
   201  		},
   202  	}
   203  	payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
   204  	if err != nil {
   205  		t.Fatal(err)
   206  	}
   207  	sreq := &testpb.StreamingOutputCallRequest{
   208  		ResponseType:       testpb.PayloadType_COMPRESSABLE,
   209  		ResponseParameters: respParam,
   210  		Payload:            payload,
   211  	}
   212  	if err := stream.Send(sreq); err != nil {
   213  		t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
   214  	}
   215  	stream.CloseSend()
   216  	if _, err := stream.Recv(); err != nil {
   217  		t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
   218  	}
   219  	if _, err := stream.Recv(); err != io.EOF {
   220  		t.Fatalf("%v.Recv() = %v, want io.EOF", stream, err)
   221  	}
   222  }
   223  
   224  func (s) TestIdentityEncoding(t *testing.T) {
   225  	for _, e := range listTestEnv() {
   226  		testIdentityEncoding(t, e)
   227  	}
   228  }
   229  
   230  func testIdentityEncoding(t *testing.T, e env) {
   231  	te := newTest(t, e)
   232  	te.startServer(&testServer{security: e.security})
   233  	defer te.tearDown()
   234  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   235  
   236  	// Unary call
   237  	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 5)
   238  	if err != nil {
   239  		t.Fatal(err)
   240  	}
   241  	req := &testpb.SimpleRequest{
   242  		ResponseType: testpb.PayloadType_COMPRESSABLE,
   243  		ResponseSize: 10,
   244  		Payload:      payload,
   245  	}
   246  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   247  	defer cancel()
   248  	ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("something", "something"))
   249  	if _, err := tc.UnaryCall(ctx, req); err != nil {
   250  		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
   251  	}
   252  	// Streaming RPC
   253  	stream, err := tc.FullDuplexCall(ctx, grpc.UseCompressor("identity"))
   254  	if err != nil {
   255  		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
   256  	}
   257  	payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
   258  	if err != nil {
   259  		t.Fatal(err)
   260  	}
   261  	sreq := &testpb.StreamingOutputCallRequest{
   262  		ResponseType:       testpb.PayloadType_COMPRESSABLE,
   263  		ResponseParameters: []*testpb.ResponseParameters{{Size: 10}},
   264  		Payload:            payload,
   265  	}
   266  	if err := stream.Send(sreq); err != nil {
   267  		t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
   268  	}
   269  	stream.CloseSend()
   270  	if _, err := stream.Recv(); err != nil {
   271  		t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
   272  	}
   273  	if _, err := stream.Recv(); err != io.EOF {
   274  		t.Fatalf("%v.Recv() = %v, want io.EOF", stream, err)
   275  	}
   276  }
   277  
   278  // renameCompressor is a grpc.Compressor wrapper that allows customizing the
   279  // Type() of another compressor.
   280  type renameCompressor struct {
   281  	grpc.Compressor
   282  	name string
   283  }
   284  
   285  func (r *renameCompressor) Type() string { return r.name }
   286  
   287  // renameDecompressor is a grpc.Decompressor wrapper that allows customizing the
   288  // Type() of another Decompressor.
   289  type renameDecompressor struct {
   290  	grpc.Decompressor
   291  	name string
   292  }
   293  
   294  func (r *renameDecompressor) Type() string { return r.name }
   295  
   296  func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) {
   297  	wantGrpcAcceptEncodingCh := make(chan []string, 1)
   298  	defer close(wantGrpcAcceptEncodingCh)
   299  
   300  	compressor := renameCompressor{Compressor: grpc.NewGZIPCompressor(), name: "testgzip"}
   301  	decompressor := renameDecompressor{Decompressor: grpc.NewGZIPDecompressor(), name: "testgzip"}
   302  
   303  	ss := &stubserver.StubServer{
   304  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   305  			md, ok := metadata.FromIncomingContext(ctx)
   306  			if !ok {
   307  				return nil, status.Errorf(codes.Internal, "no metadata in context")
   308  			}
   309  			if got, want := md["grpc-accept-encoding"], <-wantGrpcAcceptEncodingCh; !reflect.DeepEqual(got, want) {
   310  				return nil, status.Errorf(codes.Internal, "got grpc-accept-encoding=%q; want [%q]", got, want)
   311  			}
   312  			return &testpb.Empty{}, nil
   313  		},
   314  	}
   315  	if err := ss.Start([]grpc.ServerOption{grpc.RPCDecompressor(&decompressor)}); err != nil {
   316  		t.Fatalf("Error starting endpoint server: %v", err)
   317  	}
   318  	defer ss.Stop()
   319  
   320  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   321  	defer cancel()
   322  
   323  	wantGrpcAcceptEncodingCh <- []string{"gzip"}
   324  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   325  		t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
   326  	}
   327  
   328  	wantGrpcAcceptEncodingCh <- []string{"gzip"}
   329  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.UseCompressor("gzip")); err != nil {
   330  		t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
   331  	}
   332  
   333  	// Use compressor directly which is not registered via
   334  	// encoding.RegisterCompressor.
   335  	if err := ss.StartClient(grpc.WithCompressor(&compressor)); err != nil {
   336  		t.Fatalf("Error starting client: %v", err)
   337  	}
   338  	wantGrpcAcceptEncodingCh <- []string{"gzip,testgzip"}
   339  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   340  		t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
   341  	}
   342  }
   343  
   344  // wrapCompressor is a wrapper of encoding.Compressor which maintains count of
   345  // Compressor method invokes.
   346  type wrapCompressor struct {
   347  	encoding.Compressor
   348  	compressInvokes int32
   349  }
   350  
   351  func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
   352  	atomic.AddInt32(&wc.compressInvokes, 1)
   353  	return wc.Compressor.Compress(w)
   354  }
   355  
   356  func setupGzipWrapCompressor(t *testing.T) *wrapCompressor {
   357  	oldC := encoding.GetCompressor("gzip")
   358  	c := &wrapCompressor{Compressor: oldC}
   359  	encoding.RegisterCompressor(c)
   360  	t.Cleanup(func() {
   361  		encoding.RegisterCompressor(oldC)
   362  	})
   363  	return c
   364  }
   365  
   366  func (s) TestSetSendCompressorSuccess(t *testing.T) {
   367  	for _, tt := range []struct {
   368  		name                string
   369  		desc                string
   370  		payload             *testpb.Payload
   371  		dialOpts            []grpc.DialOption
   372  		resCompressor       string
   373  		wantCompressInvokes int32
   374  	}{
   375  		{
   376  			name:                "identity_request_and_gzip_response",
   377  			desc:                "request is uncompressed and response is gzip compressed",
   378  			payload:             &testpb.Payload{Body: []byte("payload")},
   379  			resCompressor:       "gzip",
   380  			wantCompressInvokes: 1,
   381  		},
   382  		{
   383  			name:                "identity_request_and_empty_response",
   384  			desc:                "request is uncompressed and response is gzip compressed",
   385  			payload:             nil,
   386  			resCompressor:       "gzip",
   387  			wantCompressInvokes: 0,
   388  		},
   389  		{
   390  			name:          "gzip_request_and_identity_response",
   391  			desc:          "request is gzip compressed and response is uncompressed with identity",
   392  			payload:       &testpb.Payload{Body: []byte("payload")},
   393  			resCompressor: "identity",
   394  			dialOpts: []grpc.DialOption{
   395  				// Use WithCompressor instead of UseCompressor to avoid counting
   396  				// the client's compressor usage.
   397  				grpc.WithCompressor(grpc.NewGZIPCompressor()),
   398  			},
   399  			wantCompressInvokes: 0,
   400  		},
   401  	} {
   402  		t.Run(tt.name, func(t *testing.T) {
   403  			t.Run("unary", func(t *testing.T) {
   404  				testUnarySetSendCompressorSuccess(t, tt.payload, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
   405  			})
   406  
   407  			t.Run("stream", func(t *testing.T) {
   408  				testStreamSetSendCompressorSuccess(t, tt.payload, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
   409  			})
   410  		})
   411  	}
   412  }
   413  
   414  func testUnarySetSendCompressorSuccess(t *testing.T, payload *testpb.Payload, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
   415  	wc := setupGzipWrapCompressor(t)
   416  	ss := &stubserver.StubServer{
   417  		UnaryCallF: func(ctx context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   418  			if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil {
   419  				return nil, err
   420  			}
   421  			return &testpb.SimpleResponse{
   422  				Payload: payload,
   423  			}, nil
   424  		},
   425  	}
   426  	if err := ss.Start(nil, dialOpts...); err != nil {
   427  		t.Fatalf("Error starting endpoint server: %v", err)
   428  	}
   429  	defer ss.Stop()
   430  
   431  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   432  	defer cancel()
   433  
   434  	if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
   435  		t.Fatalf("Unexpected unary call error, got: %v, want: nil", err)
   436  	}
   437  
   438  	compressInvokes := atomic.LoadInt32(&wc.compressInvokes)
   439  	if compressInvokes != wantCompressInvokes {
   440  		t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes)
   441  	}
   442  }
   443  
   444  func testStreamSetSendCompressorSuccess(t *testing.T, payload *testpb.Payload, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
   445  	wc := setupGzipWrapCompressor(t)
   446  	ss := &stubserver.StubServer{
   447  		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
   448  			if _, err := stream.Recv(); err != nil {
   449  				return err
   450  			}
   451  
   452  			if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil {
   453  				return err
   454  			}
   455  
   456  			return stream.Send(&testpb.StreamingOutputCallResponse{
   457  				Payload: payload,
   458  			})
   459  		},
   460  	}
   461  	if err := ss.Start(nil, dialOpts...); err != nil {
   462  		t.Fatalf("Error starting endpoint server: %v", err)
   463  	}
   464  	defer ss.Stop()
   465  
   466  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   467  	defer cancel()
   468  
   469  	s, err := ss.Client.FullDuplexCall(ctx)
   470  	if err != nil {
   471  		t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
   472  	}
   473  
   474  	if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
   475  		t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err)
   476  	}
   477  
   478  	if _, err := s.Recv(); err != nil {
   479  		t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err)
   480  	}
   481  
   482  	compressInvokes := atomic.LoadInt32(&wc.compressInvokes)
   483  	if compressInvokes != wantCompressInvokes {
   484  		t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes)
   485  	}
   486  }
   487  
   488  func (s) TestUnregisteredSetSendCompressorFailure(t *testing.T) {
   489  	resCompressor := "snappy2"
   490  	wantErr := status.Error(codes.Unknown, "unable to set send compressor: compressor not registered \"snappy2\"")
   491  
   492  	t.Run("unary", func(t *testing.T) {
   493  		testUnarySetSendCompressorFailure(t, resCompressor, wantErr)
   494  	})
   495  
   496  	t.Run("stream", func(t *testing.T) {
   497  		testStreamSetSendCompressorFailure(t, resCompressor, wantErr)
   498  	})
   499  }
   500  
   501  func testUnarySetSendCompressorFailure(t *testing.T, resCompressor string, wantErr error) {
   502  	ss := &stubserver.StubServer{
   503  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   504  			if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil {
   505  				return nil, err
   506  			}
   507  			return &testpb.Empty{}, nil
   508  		},
   509  	}
   510  	if err := ss.Start(nil); err != nil {
   511  		t.Fatalf("Error starting endpoint server: %v", err)
   512  	}
   513  	defer ss.Stop()
   514  
   515  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   516  	defer cancel()
   517  
   518  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) {
   519  		t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr)
   520  	}
   521  }
   522  
   523  func testStreamSetSendCompressorFailure(t *testing.T, resCompressor string, wantErr error) {
   524  	ss := &stubserver.StubServer{
   525  		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
   526  			if _, err := stream.Recv(); err != nil {
   527  				return err
   528  			}
   529  
   530  			if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil {
   531  				return err
   532  			}
   533  
   534  			return stream.Send(&testpb.StreamingOutputCallResponse{})
   535  		},
   536  	}
   537  	if err := ss.Start(nil); err != nil {
   538  		t.Fatalf("Error starting endpoint server: %v, want: nil", err)
   539  	}
   540  	defer ss.Stop()
   541  
   542  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   543  	defer cancel()
   544  
   545  	s, err := ss.Client.FullDuplexCall(ctx)
   546  	if err != nil {
   547  		t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
   548  	}
   549  
   550  	if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
   551  		t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err)
   552  	}
   553  
   554  	if _, err := s.Recv(); !equalError(err, wantErr) {
   555  		t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err)
   556  	}
   557  }
   558  
   559  func (s) TestUnarySetSendCompressorAfterHeaderSendFailure(t *testing.T) {
   560  	ss := &stubserver.StubServer{
   561  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   562  			// Send headers early and then set send compressor.
   563  			grpc.SendHeader(ctx, metadata.MD{})
   564  			err := grpc.SetSendCompressor(ctx, "gzip")
   565  			if err == nil {
   566  				t.Error("Wanted set send compressor error")
   567  				return &testpb.Empty{}, nil
   568  			}
   569  			return nil, err
   570  		},
   571  	}
   572  	if err := ss.Start(nil); err != nil {
   573  		t.Fatalf("Error starting endpoint server: %v", err)
   574  	}
   575  	defer ss.Stop()
   576  
   577  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   578  	defer cancel()
   579  
   580  	wantErr := status.Error(codes.Unknown, "transport: set send compressor called after headers sent or stream done")
   581  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) {
   582  		t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr)
   583  	}
   584  }
   585  
   586  func (s) TestStreamSetSendCompressorAfterHeaderSendFailure(t *testing.T) {
   587  	ss := &stubserver.StubServer{
   588  		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
   589  			// Send headers early and then set send compressor.
   590  			grpc.SendHeader(stream.Context(), metadata.MD{})
   591  			err := grpc.SetSendCompressor(stream.Context(), "gzip")
   592  			if err == nil {
   593  				t.Error("Wanted set send compressor error")
   594  			}
   595  			return err
   596  		},
   597  	}
   598  	if err := ss.Start(nil); err != nil {
   599  		t.Fatalf("Error starting endpoint server: %v", err)
   600  	}
   601  	defer ss.Stop()
   602  
   603  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   604  	defer cancel()
   605  
   606  	wantErr := status.Error(codes.Unknown, "transport: set send compressor called after headers sent or stream done")
   607  	s, err := ss.Client.FullDuplexCall(ctx)
   608  	if err != nil {
   609  		t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
   610  	}
   611  
   612  	if _, err := s.Recv(); !equalError(err, wantErr) {
   613  		t.Fatalf("Unexpected full duplex recv error, got: %v, want: %v", err, wantErr)
   614  	}
   615  }
   616  
   617  func (s) TestClientSupportedCompressors(t *testing.T) {
   618  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   619  	defer cancel()
   620  
   621  	for _, tt := range []struct {
   622  		desc string
   623  		ctx  context.Context
   624  		want []string
   625  	}{
   626  		{
   627  			desc: "No additional grpc-accept-encoding header",
   628  			ctx:  ctx,
   629  			want: []string{"gzip"},
   630  		},
   631  		{
   632  			desc: "With additional grpc-accept-encoding header",
   633  			ctx: metadata.AppendToOutgoingContext(ctx,
   634  				"grpc-accept-encoding", "test-compressor-1",
   635  				"grpc-accept-encoding", "test-compressor-2",
   636  			),
   637  			want: []string{"gzip", "test-compressor-1", "test-compressor-2"},
   638  		},
   639  		{
   640  			desc: "With additional empty grpc-accept-encoding header",
   641  			ctx: metadata.AppendToOutgoingContext(ctx,
   642  				"grpc-accept-encoding", "",
   643  			),
   644  			want: []string{"gzip"},
   645  		},
   646  		{
   647  			desc: "With additional grpc-accept-encoding header with spaces between values",
   648  			ctx: metadata.AppendToOutgoingContext(ctx,
   649  				"grpc-accept-encoding", "identity, deflate",
   650  			),
   651  			want: []string{"gzip", "identity", "deflate"},
   652  		},
   653  	} {
   654  		t.Run(tt.desc, func(t *testing.T) {
   655  			ss := &stubserver.StubServer{
   656  				EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   657  					got, err := grpc.ClientSupportedCompressors(ctx)
   658  					if err != nil {
   659  						return nil, err
   660  					}
   661  
   662  					if !reflect.DeepEqual(got, tt.want) {
   663  						t.Errorf("unexpected client compressors got: %v, want: %v", got, tt.want)
   664  					}
   665  
   666  					return &testpb.Empty{}, nil
   667  				},
   668  			}
   669  			if err := ss.Start(nil); err != nil {
   670  				t.Fatalf("Error starting endpoint server: %v, want: nil", err)
   671  			}
   672  			defer ss.Stop()
   673  
   674  			_, err := ss.Client.EmptyCall(tt.ctx, &testpb.Empty{})
   675  			if err != nil {
   676  				t.Fatalf("Unexpected unary call error, got: %v, want: nil", err)
   677  			}
   678  		})
   679  	}
   680  }
   681  
   682  func (s) TestCompressorRegister(t *testing.T) {
   683  	for _, e := range listTestEnv() {
   684  		testCompressorRegister(t, e)
   685  	}
   686  }
   687  
   688  func testCompressorRegister(t *testing.T, e env) {
   689  	te := newTest(t, e)
   690  	te.clientCompression = false
   691  	te.serverCompression = false
   692  	te.clientUseCompression = true
   693  
   694  	te.startServer(&testServer{security: e.security})
   695  	defer te.tearDown()
   696  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   697  
   698  	// Unary call
   699  	const argSize = 271828
   700  	const respSize = 314159
   701  	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
   702  	if err != nil {
   703  		t.Fatal(err)
   704  	}
   705  	req := &testpb.SimpleRequest{
   706  		ResponseType: testpb.PayloadType_COMPRESSABLE,
   707  		ResponseSize: respSize,
   708  		Payload:      payload,
   709  	}
   710  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   711  	defer cancel()
   712  	ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("something", "something"))
   713  	if _, err := tc.UnaryCall(ctx, req); err != nil {
   714  		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
   715  	}
   716  	// Streaming RPC
   717  	stream, err := tc.FullDuplexCall(ctx)
   718  	if err != nil {
   719  		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
   720  	}
   721  	respParam := []*testpb.ResponseParameters{
   722  		{
   723  			Size: 31415,
   724  		},
   725  	}
   726  	payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
   727  	if err != nil {
   728  		t.Fatal(err)
   729  	}
   730  	sreq := &testpb.StreamingOutputCallRequest{
   731  		ResponseType:       testpb.PayloadType_COMPRESSABLE,
   732  		ResponseParameters: respParam,
   733  		Payload:            payload,
   734  	}
   735  	if err := stream.Send(sreq); err != nil {
   736  		t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
   737  	}
   738  	if _, err := stream.Recv(); err != nil {
   739  		t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
   740  	}
   741  }
   742  
   743  type badGzipCompressor struct{}
   744  
   745  func (badGzipCompressor) Do(w io.Writer, p []byte) error {
   746  	buf := &bytes.Buffer{}
   747  	gzw := gzip.NewWriter(buf)
   748  	if _, err := gzw.Write(p); err != nil {
   749  		return err
   750  	}
   751  	err := gzw.Close()
   752  	bs := buf.Bytes()
   753  	if len(bs) >= 6 {
   754  		bs[len(bs)-6] ^= 1 // modify checksum at end by 1 byte
   755  	}
   756  	w.Write(bs)
   757  	return err
   758  }
   759  
   760  func (badGzipCompressor) Type() string {
   761  	return "gzip"
   762  }
   763  
   764  func (s) TestGzipBadChecksum(t *testing.T) {
   765  	ss := &stubserver.StubServer{
   766  		UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   767  			return &testpb.SimpleResponse{}, nil
   768  		},
   769  	}
   770  	if err := ss.Start(nil, grpc.WithCompressor(badGzipCompressor{})); err != nil {
   771  		t.Fatalf("Error starting endpoint server: %v", err)
   772  	}
   773  	defer ss.Stop()
   774  
   775  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   776  	defer cancel()
   777  
   778  	p, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1024))
   779  	if err != nil {
   780  		t.Fatalf("Unexpected error from newPayload: %v", err)
   781  	}
   782  	if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: p}); err == nil ||
   783  		status.Code(err) != codes.Internal ||
   784  		!strings.Contains(status.Convert(err).Message(), gzip.ErrChecksum.Error()) {
   785  		t.Errorf("ss.Client.UnaryCall(_) = _, %v\n\twant: _, status(codes.Internal, contains %q)", err, gzip.ErrChecksum)
   786  	}
   787  }
   788  
   789  // fakeCompressor returns a messages of a configured size, irrespective of the
   790  // input.
   791  type fakeCompressor struct {
   792  	decompressedMessageSize int
   793  }
   794  
   795  func (f *fakeCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
   796  	return nopWriteCloser{w}, nil
   797  }
   798  
   799  func (f *fakeCompressor) Decompress(io.Reader) (io.Reader, error) {
   800  	return bytes.NewReader(make([]byte, f.decompressedMessageSize)), nil
   801  }
   802  
   803  func (f *fakeCompressor) Name() string {
   804  	// Use the name of an existing compressor to avoid interactions with other
   805  	// tests since compressors can't be un-registered.
   806  	return "gzip"
   807  }
   808  
   809  type nopWriteCloser struct {
   810  	io.Writer
   811  }
   812  
   813  func (nopWriteCloser) Close() error {
   814  	return nil
   815  }
   816  
   817  // TestDecompressionExceedsMaxMessageSize uses a fake compressor that produces
   818  // messages of size 100 bytes on decompression. A server is started with the
   819  // max receive message size restricted to 99 bytes. The test verifies that the
   820  // client receives a ResourceExhausted response from the server.
   821  func (s) TestDecompressionExceedsMaxMessageSize(t *testing.T) {
   822  	oldC := encoding.GetCompressor("gzip")
   823  	defer func() {
   824  		encoding.RegisterCompressor(oldC)
   825  	}()
   826  	const messageLen = 100
   827  	encoding.RegisterCompressor(&fakeCompressor{decompressedMessageSize: messageLen})
   828  	ss := &stubserver.StubServer{
   829  		UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   830  			return &testpb.SimpleResponse{}, nil
   831  		},
   832  	}
   833  	if err := ss.Start([]grpc.ServerOption{grpc.MaxRecvMsgSize(messageLen - 1)}); err != nil {
   834  		t.Fatalf("Error starting endpoint server: %v", err)
   835  	}
   836  	defer ss.Stop()
   837  
   838  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   839  	defer cancel()
   840  
   841  	req := &testpb.SimpleRequest{Payload: &testpb.Payload{}}
   842  	_, err := ss.Client.UnaryCall(ctx, req, grpc.UseCompressor("gzip"))
   843  	if got, want := status.Code(err), codes.ResourceExhausted; got != want {
   844  		t.Errorf("Client.UnaryCall(%+v) returned status %v, want %v", req, got, want)
   845  	}
   846  }