google.golang.org/grpc@v1.62.1/encoding/encoding_test.go (about)

     1  /*
     2   *
     3   * Copyright 2023 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package encoding_test
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"strings"
    26  	"sync/atomic"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/google/go-cmp/cmp"
    31  	"google.golang.org/grpc"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/credentials/insecure"
    34  	"google.golang.org/grpc/encoding"
    35  	"google.golang.org/grpc/encoding/proto"
    36  	"google.golang.org/grpc/internal/grpctest"
    37  	"google.golang.org/grpc/internal/grpcutil"
    38  	"google.golang.org/grpc/internal/stubserver"
    39  	"google.golang.org/grpc/metadata"
    40  	"google.golang.org/grpc/status"
    41  
    42  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    43  	testpb "google.golang.org/grpc/interop/grpc_testing"
    44  )
    45  
    46  const defaultTestTimeout = 10 * time.Second
    47  
    48  type s struct {
    49  	grpctest.Tester
    50  }
    51  
    52  func Test(t *testing.T) {
    53  	grpctest.RunSubTests(t, s{})
    54  }
    55  
    56  type mockNamedCompressor struct {
    57  	encoding.Compressor
    58  }
    59  
    60  func (mockNamedCompressor) Name() string {
    61  	return "mock-compressor"
    62  }
    63  
    64  // Tests the case where a compressor with the same name is registered multiple
    65  // times. Test verifies the following:
    66  //   - the most recent registration is the one which is active
    67  //   - grpcutil.RegisteredCompressorNames contains a single instance of the
    68  //     previously registered compressor's name
    69  func (s) TestDuplicateCompressorRegister(t *testing.T) {
    70  	encoding.RegisterCompressor(&mockNamedCompressor{})
    71  
    72  	// Register another instance of the same compressor.
    73  	mc := &mockNamedCompressor{}
    74  	encoding.RegisterCompressor(mc)
    75  	if got := encoding.GetCompressor("mock-compressor"); got != mc {
    76  		t.Fatalf("Unexpected compressor, got: %+v, want:%+v", got, mc)
    77  	}
    78  
    79  	wantNames := []string{"mock-compressor"}
    80  	if !cmp.Equal(wantNames, grpcutil.RegisteredCompressorNames) {
    81  		t.Fatalf("Unexpected compressor names, got: %+v, want:%+v", grpcutil.RegisteredCompressorNames, wantNames)
    82  	}
    83  }
    84  
    85  // errProtoCodec wraps the proto codec and delegates to it if it is configured
    86  // to return a nil error. Else, it returns the configured error.
    87  type errProtoCodec struct {
    88  	name        string
    89  	encodingErr error
    90  	decodingErr error
    91  }
    92  
    93  func (c *errProtoCodec) Marshal(v any) ([]byte, error) {
    94  	if c.encodingErr != nil {
    95  		return nil, c.encodingErr
    96  	}
    97  	return encoding.GetCodec(proto.Name).Marshal(v)
    98  }
    99  
   100  func (c *errProtoCodec) Unmarshal(data []byte, v any) error {
   101  	if c.decodingErr != nil {
   102  		return c.decodingErr
   103  	}
   104  	return encoding.GetCodec(proto.Name).Unmarshal(data, v)
   105  }
   106  
   107  func (c *errProtoCodec) Name() string {
   108  	return c.name
   109  }
   110  
   111  // Tests the case where encoding fails on the server. Verifies that there is
   112  // no panic and that the encoding error is propagated to the client.
   113  func (s) TestEncodeDoesntPanicOnServer(t *testing.T) {
   114  	grpctest.TLogger.ExpectError("grpc: server failed to encode response")
   115  
   116  	// Create an codec that errors when encoding messages.
   117  	encodingErr := errors.New("encoding failed")
   118  	ec := &errProtoCodec{name: t.Name(), encodingErr: encodingErr}
   119  
   120  	// Start a server with the above codec.
   121  	backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(ec))
   122  	defer backend.Stop()
   123  
   124  	// Create a channel to the above server.
   125  	cc, err := grpc.Dial(backend.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
   126  	if err != nil {
   127  		t.Fatalf("Failed to dial test backend at %q: %v", backend.Address, err)
   128  	}
   129  	defer cc.Close()
   130  
   131  	// Make an RPC and expect it to fail. Since we do not specify any codec
   132  	// here, the proto codec will get automatically used.
   133  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   134  	defer cancel()
   135  	client := testgrpc.NewTestServiceClient(cc)
   136  	_, err = client.EmptyCall(ctx, &testpb.Empty{})
   137  	if err == nil || !strings.Contains(err.Error(), encodingErr.Error()) {
   138  		t.Fatalf("RPC failed with error: %v, want: %v", err, encodingErr)
   139  	}
   140  
   141  	// Configure the codec on the server to not return errors anymore and expect
   142  	// the RPC to succeed.
   143  	ec.encodingErr = nil
   144  	if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   145  		t.Fatalf("RPC failed with error: %v", err)
   146  	}
   147  }
   148  
   149  // Tests the case where decoding fails on the server. Verifies that there is
   150  // no panic and that the decoding error is propagated to the client.
   151  func (s) TestDecodeDoesntPanicOnServer(t *testing.T) {
   152  	// Create an codec that errors when decoding messages.
   153  	decodingErr := errors.New("decoding failed")
   154  	ec := &errProtoCodec{name: t.Name(), decodingErr: decodingErr}
   155  
   156  	// Start a server with the above codec.
   157  	backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(ec))
   158  	defer backend.Stop()
   159  
   160  	// Create a channel to the above server. Since we do not specify any codec
   161  	// here, the proto codec will get automatically used.
   162  	cc, err := grpc.Dial(backend.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
   163  	if err != nil {
   164  		t.Fatalf("Failed to dial test backend at %q: %v", backend.Address, err)
   165  	}
   166  	defer cc.Close()
   167  
   168  	// Make an RPC and expect it to fail. Since we do not specify any codec
   169  	// here, the proto codec will get automatically used.
   170  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   171  	defer cancel()
   172  	client := testgrpc.NewTestServiceClient(cc)
   173  	_, err = client.EmptyCall(ctx, &testpb.Empty{})
   174  	if err == nil || !strings.Contains(err.Error(), decodingErr.Error()) || !strings.Contains(err.Error(), "grpc: error unmarshalling request") {
   175  		t.Fatalf("RPC failed with error: %v, want: %v", err, decodingErr)
   176  	}
   177  
   178  	// Configure the codec on the server to not return errors anymore and expect
   179  	// the RPC to succeed.
   180  	ec.decodingErr = nil
   181  	if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   182  		t.Fatalf("RPC failed with error: %v", err)
   183  	}
   184  }
   185  
   186  // Tests the case where encoding fails on the client . Verifies that there is
   187  // no panic and that the encoding error is propagated to the RPC caller.
   188  func (s) TestEncodeDoesntPanicOnClient(t *testing.T) {
   189  	// Start a server and since we do not specify any codec here, the proto
   190  	// codec will get automatically used.
   191  	backend := stubserver.StartTestService(t, nil)
   192  	defer backend.Stop()
   193  
   194  	// Create an codec that errors when encoding messages.
   195  	encodingErr := errors.New("encoding failed")
   196  	ec := &errProtoCodec{name: t.Name(), encodingErr: encodingErr}
   197  
   198  	// Create a channel to the above server.
   199  	cc, err := grpc.Dial(backend.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
   200  	if err != nil {
   201  		t.Fatalf("Failed to dial test backend at %q: %v", backend.Address, err)
   202  	}
   203  	defer cc.Close()
   204  
   205  	// Make an RPC with the erroring codec and expect it to fail.
   206  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   207  	defer cancel()
   208  	client := testgrpc.NewTestServiceClient(cc)
   209  	_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec))
   210  	if err == nil || !strings.Contains(err.Error(), encodingErr.Error()) {
   211  		t.Fatalf("RPC failed with error: %v, want: %v", err, encodingErr)
   212  	}
   213  
   214  	// Configure the codec on the client to not return errors anymore and expect
   215  	// the RPC to succeed.
   216  	ec.encodingErr = nil
   217  	if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec)); err != nil {
   218  		t.Fatalf("RPC failed with error: %v", err)
   219  	}
   220  }
   221  
   222  // Tests the case where decoding fails on the server. Verifies that there is
   223  // no panic and that the decoding error is propagated to the RPC caller.
   224  func (s) TestDecodeDoesntPanicOnClient(t *testing.T) {
   225  	// Start a server and since we do not specify any codec here, the proto
   226  	// codec will get automatically used.
   227  	backend := stubserver.StartTestService(t, nil)
   228  	defer backend.Stop()
   229  
   230  	// Create an codec that errors when decoding messages.
   231  	decodingErr := errors.New("decoding failed")
   232  	ec := &errProtoCodec{name: t.Name(), decodingErr: decodingErr}
   233  
   234  	// Create a channel to the above server.
   235  	cc, err := grpc.Dial(backend.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
   236  	if err != nil {
   237  		t.Fatalf("Failed to dial test backend at %q: %v", backend.Address, err)
   238  	}
   239  	defer cc.Close()
   240  
   241  	// Make an RPC with the erroring codec and expect it to fail.
   242  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   243  	defer cancel()
   244  	client := testgrpc.NewTestServiceClient(cc)
   245  	_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec))
   246  	if err == nil || !strings.Contains(err.Error(), decodingErr.Error()) {
   247  		t.Fatalf("RPC failed with error: %v, want: %v", err, decodingErr)
   248  	}
   249  
   250  	// Configure the codec on the client to not return errors anymore and expect
   251  	// the RPC to succeed.
   252  	ec.decodingErr = nil
   253  	if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec)); err != nil {
   254  		t.Fatalf("RPC failed with error: %v", err)
   255  	}
   256  }
   257  
   258  // countingProtoCodec wraps the proto codec and counts the number of times
   259  // Marshal and Unmarshal are called.
   260  type countingProtoCodec struct {
   261  	name string
   262  
   263  	// The following fields are accessed atomically.
   264  	marshalCount   int32
   265  	unmarshalCount int32
   266  }
   267  
   268  func (p *countingProtoCodec) Marshal(v any) ([]byte, error) {
   269  	atomic.AddInt32(&p.marshalCount, 1)
   270  	return encoding.GetCodec(proto.Name).Marshal(v)
   271  }
   272  
   273  func (p *countingProtoCodec) Unmarshal(data []byte, v any) error {
   274  	atomic.AddInt32(&p.unmarshalCount, 1)
   275  	return encoding.GetCodec(proto.Name).Unmarshal(data, v)
   276  }
   277  
   278  func (p *countingProtoCodec) Name() string {
   279  	return p.name
   280  }
   281  
   282  // Tests the case where ForceServerCodec option is used on the server. Verifies
   283  // that encoding and decoding happen once per RPC.
   284  func (s) TestForceServerCodec(t *testing.T) {
   285  	// Create an server with the counting proto codec.
   286  	codec := &countingProtoCodec{name: t.Name()}
   287  	backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(codec))
   288  	defer backend.Stop()
   289  
   290  	// Create a channel to the above server.
   291  	cc, err := grpc.Dial(backend.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
   292  	if err != nil {
   293  		t.Fatalf("Failed to dial test backend at %q: %v", backend.Address, err)
   294  	}
   295  	defer cc.Close()
   296  
   297  	// Make an RPC and expect it to fail. Since we do not specify any codec
   298  	// here, the proto codec will get automatically used.
   299  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   300  	defer cancel()
   301  	client := testgrpc.NewTestServiceClient(cc)
   302  	if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   303  		t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
   304  	}
   305  
   306  	unmarshalCount := atomic.LoadInt32(&codec.unmarshalCount)
   307  	const wantUnmarshalCount = 1
   308  	if unmarshalCount != wantUnmarshalCount {
   309  		t.Fatalf("Unmarshal Count = %d; want %d", unmarshalCount, wantUnmarshalCount)
   310  	}
   311  	marshalCount := atomic.LoadInt32(&codec.marshalCount)
   312  	const wantMarshalCount = 1
   313  	if marshalCount != wantMarshalCount {
   314  		t.Fatalf("MarshalCount = %d; want %d", marshalCount, wantMarshalCount)
   315  	}
   316  }
   317  
   318  // renameProtoCodec wraps the proto codec and allows customizing the Name().
   319  type renameProtoCodec struct {
   320  	encoding.Codec
   321  	name string
   322  }
   323  
   324  func (r *renameProtoCodec) Name() string { return r.name }
   325  
   326  // TestForceCodecName confirms that the ForceCodec call option sets the subtype
   327  // in the content-type header according to the Name() of the codec provided.
   328  // Verifies that the name is converted to lowercase before transmitting.
   329  func (s) TestForceCodecName(t *testing.T) {
   330  	wantContentTypeCh := make(chan []string, 1)
   331  	defer close(wantContentTypeCh)
   332  
   333  	// Create a test service backend that pushes the received content-type on a
   334  	// channel for the test to inspect.
   335  	ss := &stubserver.StubServer{
   336  		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
   337  			md, ok := metadata.FromIncomingContext(ctx)
   338  			if !ok {
   339  				return nil, status.Errorf(codes.Internal, "no metadata in context")
   340  			}
   341  			if got, want := md["content-type"], <-wantContentTypeCh; !cmp.Equal(got, want) {
   342  				return nil, status.Errorf(codes.Internal, "got content-type=%q; want [%q]", got, want)
   343  			}
   344  			return &testpb.Empty{}, nil
   345  		},
   346  	}
   347  	// Since we don't specify a codec as a server option, it will end up
   348  	// automatically using the proto codec.
   349  	if err := ss.Start(nil); err != nil {
   350  		t.Fatalf("Error starting endpoint server: %v", err)
   351  	}
   352  	defer ss.Stop()
   353  
   354  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   355  	defer cancel()
   356  
   357  	// Force the use of the custom codec on the client with the ForceCodec call
   358  	// option. Confirm the name is converted to lowercase before transmitting.
   359  	codec := &renameProtoCodec{Codec: encoding.GetCodec(proto.Name), name: t.Name()}
   360  	wantContentTypeCh <- []string{fmt.Sprintf("application/grpc+%s", strings.ToLower(t.Name()))}
   361  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil {
   362  		t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err)
   363  	}
   364  }