google.golang.org/grpc@v1.62.1/test/metadata_test.go (about)

     1  /*
     2   *
     3   * Copyright 2022 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package test
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"io"
    25  	"reflect"
    26  	"strings"
    27  	"testing"
    28  
    29  	"google.golang.org/grpc/codes"
    30  	"google.golang.org/grpc/internal/grpctest"
    31  	"google.golang.org/grpc/internal/stubserver"
    32  	"google.golang.org/grpc/metadata"
    33  	"google.golang.org/grpc/status"
    34  
    35  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    36  	testpb "google.golang.org/grpc/interop/grpc_testing"
    37  )
    38  
    39  func (s) TestInvalidMetadata(t *testing.T) {
    40  	grpctest.TLogger.ExpectErrorN("stream: failed to validate md when setting trailer", 5)
    41  
    42  	tests := []struct {
    43  		name     string
    44  		md       metadata.MD
    45  		appendMD []string
    46  		want     error
    47  		recv     error
    48  	}{
    49  		{
    50  			name: "invalid key",
    51  			md:   map[string][]string{string(rune(0x19)): {"testVal"}},
    52  			want: status.Error(codes.Internal, "header key \"\\x19\" contains illegal characters not in [0-9a-z-_.]"),
    53  			recv: status.Error(codes.Internal, "invalid header field"),
    54  		},
    55  		{
    56  			name: "invalid value",
    57  			md:   map[string][]string{"test": {string(rune(0x19))}},
    58  			want: status.Error(codes.Internal, "header key \"test\" contains value with non-printable ASCII characters"),
    59  			recv: status.Error(codes.Internal, "invalid header field"),
    60  		},
    61  		{
    62  			name:     "invalid appended value",
    63  			md:       map[string][]string{"test": {"test"}},
    64  			appendMD: []string{"/", "value"},
    65  			want:     status.Error(codes.Internal, "header key \"/\" contains illegal characters not in [0-9a-z-_.]"),
    66  			recv:     status.Error(codes.Internal, "invalid header field"),
    67  		},
    68  		{
    69  			name:     "empty appended key",
    70  			md:       map[string][]string{"test": {"test"}},
    71  			appendMD: []string{"", "value"},
    72  			want:     status.Error(codes.Internal, "there is an empty key in the header"),
    73  			recv:     status.Error(codes.Internal, "invalid header field"),
    74  		},
    75  		{
    76  			name: "empty key",
    77  			md:   map[string][]string{"": {"test"}},
    78  			want: status.Error(codes.Internal, "there is an empty key in the header"),
    79  			recv: status.Error(codes.Internal, "invalid header field"),
    80  		},
    81  		{
    82  			name: "-bin key with arbitrary value",
    83  			md:   map[string][]string{"test-bin": {string(rune(0x19))}},
    84  			want: nil,
    85  			recv: io.EOF,
    86  		},
    87  		{
    88  			name: "valid key and value",
    89  			md:   map[string][]string{"test": {"value"}},
    90  			want: nil,
    91  			recv: io.EOF,
    92  		},
    93  	}
    94  
    95  	testNum := 0
    96  	ss := &stubserver.StubServer{
    97  		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
    98  			return &testpb.Empty{}, nil
    99  		},
   100  		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
   101  			_, err := stream.Recv()
   102  			if err != nil {
   103  				return err
   104  			}
   105  			test := tests[testNum]
   106  			testNum++
   107  			// merge original md and added md.
   108  			md := metadata.Join(test.md, metadata.Pairs(test.appendMD...))
   109  
   110  			if err := stream.SetHeader(md); !reflect.DeepEqual(test.want, err) {
   111  				return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", md, err, test.want)
   112  			}
   113  			if err := stream.SendHeader(md); !reflect.DeepEqual(test.want, err) {
   114  				return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", md, err, test.want)
   115  			}
   116  			stream.SetTrailer(md)
   117  			return nil
   118  		},
   119  	}
   120  	if err := ss.Start(nil); err != nil {
   121  		t.Fatalf("Error starting ss endpoint server: %v", err)
   122  	}
   123  	defer ss.Stop()
   124  
   125  	for _, test := range tests {
   126  		t.Run("unary "+test.name, func(t *testing.T) {
   127  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   128  			defer cancel()
   129  			ctx = metadata.NewOutgoingContext(ctx, test.md)
   130  			ctx = metadata.AppendToOutgoingContext(ctx, test.appendMD...)
   131  			if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(test.want, err) {
   132  				t.Errorf("call ss.Client.EmptyCall() validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want)
   133  			}
   134  		})
   135  	}
   136  
   137  	// call the stream server's api to drive the server-side unit testing
   138  	for _, test := range tests {
   139  		t.Run("streaming "+test.name, func(t *testing.T) {
   140  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   141  			defer cancel()
   142  			stream, err := ss.Client.FullDuplexCall(ctx)
   143  			if err != nil {
   144  				t.Errorf("call ss.Client.FullDuplexCall got err :%v", err)
   145  				return
   146  			}
   147  			if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
   148  				t.Errorf("call ss.Client stream Send(nil) will success but got err :%v", err)
   149  			}
   150  			if _, err := stream.Recv(); status.Code(err) != status.Code(test.recv) || !strings.Contains(err.Error(), test.recv.Error()) {
   151  				t.Errorf("stream.Recv() = _, get err :%v, want err :%v", err, test.recv)
   152  			}
   153  		})
   154  	}
   155  }