google.golang.org/grpc@v1.72.2/interop/test_utils.go (about)

     1  /*
     2   *
     3   * Copyright 2014 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package interop contains functions used by interop client/server.
    20  //
    21  // See interop test case descriptions [here].
    22  //
    23  // [here]: https://github.com/grpc/grpc/blob/master/doc/interop-test-descriptions.md
    24  package interop
    25  
    26  import (
    27  	"context"
    28  	"fmt"
    29  	"io"
    30  	"os"
    31  	"strings"
    32  	"sync"
    33  	"time"
    34  
    35  	"golang.org/x/oauth2"
    36  	"golang.org/x/oauth2/google"
    37  	"google.golang.org/grpc"
    38  	"google.golang.org/grpc/codes"
    39  	"google.golang.org/grpc/grpclog"
    40  	"google.golang.org/grpc/metadata"
    41  	"google.golang.org/grpc/orca"
    42  	"google.golang.org/grpc/status"
    43  	"google.golang.org/protobuf/proto"
    44  
    45  	v3orcapb "github.com/cncf/xds/go/xds/data/orca/v3"
    46  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    47  	testpb "google.golang.org/grpc/interop/grpc_testing"
    48  )
    49  
    50  var (
    51  	reqSizes            = []int{27182, 8, 1828, 45904}
    52  	respSizes           = []int{31415, 9, 2653, 58979}
    53  	largeReqSize        = 271828
    54  	largeRespSize       = 314159
    55  	initialMetadataKey  = "x-grpc-test-echo-initial"
    56  	trailingMetadataKey = "x-grpc-test-echo-trailing-bin"
    57  
    58  	logger = grpclog.Component("interop")
    59  )
    60  
    61  // ClientNewPayload returns a payload of the given type and size.
    62  func ClientNewPayload(t testpb.PayloadType, size int) *testpb.Payload {
    63  	if size < 0 {
    64  		logger.Fatalf("Requested a response with invalid length %d", size)
    65  	}
    66  	body := make([]byte, size)
    67  	switch t {
    68  	case testpb.PayloadType_COMPRESSABLE:
    69  	default:
    70  		logger.Fatalf("Unsupported payload type: %d", t)
    71  	}
    72  	return &testpb.Payload{
    73  		Type: t,
    74  		Body: body,
    75  	}
    76  }
    77  
    78  // DoEmptyUnaryCall performs a unary RPC with empty request and response messages.
    79  func DoEmptyUnaryCall(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) {
    80  	reply, err := tc.EmptyCall(ctx, &testpb.Empty{}, args...)
    81  	if err != nil {
    82  		logger.Fatal("/TestService/EmptyCall RPC failed: ", err)
    83  	}
    84  	if !proto.Equal(&testpb.Empty{}, reply) {
    85  		logger.Fatalf("/TestService/EmptyCall receives %v, want %v", reply, testpb.Empty{})
    86  	}
    87  }
    88  
    89  // DoLargeUnaryCall performs a unary RPC with large payload in the request and response.
    90  func DoLargeUnaryCall(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) {
    91  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
    92  	req := &testpb.SimpleRequest{
    93  		ResponseType: testpb.PayloadType_COMPRESSABLE,
    94  		ResponseSize: int32(largeRespSize),
    95  		Payload:      pl,
    96  	}
    97  	reply, err := tc.UnaryCall(ctx, req, args...)
    98  	if err != nil {
    99  		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
   100  	}
   101  	t := reply.GetPayload().GetType()
   102  	s := len(reply.GetPayload().GetBody())
   103  	if t != testpb.PayloadType_COMPRESSABLE || s != largeRespSize {
   104  		logger.Fatalf("Got the reply with type %d len %d; want %d, %d", t, s, testpb.PayloadType_COMPRESSABLE, largeRespSize)
   105  	}
   106  }
   107  
   108  // DoClientStreaming performs a client streaming RPC.
   109  func DoClientStreaming(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) {
   110  	stream, err := tc.StreamingInputCall(ctx, args...)
   111  	if err != nil {
   112  		logger.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err)
   113  	}
   114  	var sum int
   115  	for _, s := range reqSizes {
   116  		pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, s)
   117  		req := &testpb.StreamingInputCallRequest{
   118  			Payload: pl,
   119  		}
   120  		if err := stream.Send(req); err != nil {
   121  			logger.Fatalf("%v has error %v while sending %v", stream, err, req)
   122  		}
   123  		sum += s
   124  	}
   125  	reply, err := stream.CloseAndRecv()
   126  	if err != nil {
   127  		logger.Fatalf("%v.CloseAndRecv() got error %v, want %v", stream, err, nil)
   128  	}
   129  	if reply.GetAggregatedPayloadSize() != int32(sum) {
   130  		logger.Fatalf("%v.CloseAndRecv().GetAggregatePayloadSize() = %v; want %v", stream, reply.GetAggregatedPayloadSize(), sum)
   131  	}
   132  }
   133  
   134  // DoServerStreaming performs a server streaming RPC.
   135  func DoServerStreaming(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) {
   136  	respParam := make([]*testpb.ResponseParameters, len(respSizes))
   137  	for i, s := range respSizes {
   138  		respParam[i] = &testpb.ResponseParameters{
   139  			Size: int32(s),
   140  		}
   141  	}
   142  	req := &testpb.StreamingOutputCallRequest{
   143  		ResponseType:       testpb.PayloadType_COMPRESSABLE,
   144  		ResponseParameters: respParam,
   145  	}
   146  	stream, err := tc.StreamingOutputCall(ctx, req, args...)
   147  	if err != nil {
   148  		logger.Fatalf("%v.StreamingOutputCall(_) = _, %v", tc, err)
   149  	}
   150  	var rpcStatus error
   151  	var respCnt int
   152  	var index int
   153  	for {
   154  		reply, err := stream.Recv()
   155  		if err != nil {
   156  			rpcStatus = err
   157  			break
   158  		}
   159  		t := reply.GetPayload().GetType()
   160  		if t != testpb.PayloadType_COMPRESSABLE {
   161  			logger.Fatalf("Got the reply of type %d, want %d", t, testpb.PayloadType_COMPRESSABLE)
   162  		}
   163  		size := len(reply.GetPayload().GetBody())
   164  		if size != respSizes[index] {
   165  			logger.Fatalf("Got reply body of length %d, want %d", size, respSizes[index])
   166  		}
   167  		index++
   168  		respCnt++
   169  	}
   170  	if rpcStatus != io.EOF {
   171  		logger.Fatalf("Failed to finish the server streaming rpc: %v", rpcStatus)
   172  	}
   173  	if respCnt != len(respSizes) {
   174  		logger.Fatalf("Got %d reply, want %d", len(respSizes), respCnt)
   175  	}
   176  }
   177  
   178  // DoPingPong performs ping-pong style bi-directional streaming RPC.
   179  func DoPingPong(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) {
   180  	stream, err := tc.FullDuplexCall(ctx, args...)
   181  	if err != nil {
   182  		logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err)
   183  	}
   184  	var index int
   185  	for index < len(reqSizes) {
   186  		respParam := []*testpb.ResponseParameters{
   187  			{
   188  				Size: int32(respSizes[index]),
   189  			},
   190  		}
   191  		pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, reqSizes[index])
   192  		req := &testpb.StreamingOutputCallRequest{
   193  			ResponseType:       testpb.PayloadType_COMPRESSABLE,
   194  			ResponseParameters: respParam,
   195  			Payload:            pl,
   196  		}
   197  		if err := stream.Send(req); err != nil {
   198  			logger.Fatalf("%v has error %v while sending %v", stream, err, req)
   199  		}
   200  		reply, err := stream.Recv()
   201  		if err != nil {
   202  			logger.Fatalf("%v.Recv() = %v", stream, err)
   203  		}
   204  		t := reply.GetPayload().GetType()
   205  		if t != testpb.PayloadType_COMPRESSABLE {
   206  			logger.Fatalf("Got the reply of type %d, want %d", t, testpb.PayloadType_COMPRESSABLE)
   207  		}
   208  		size := len(reply.GetPayload().GetBody())
   209  		if size != respSizes[index] {
   210  			logger.Fatalf("Got reply body of length %d, want %d", size, respSizes[index])
   211  		}
   212  		index++
   213  	}
   214  	if err := stream.CloseSend(); err != nil {
   215  		logger.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
   216  	}
   217  	if _, err := stream.Recv(); err != io.EOF {
   218  		logger.Fatalf("%v failed to complele the ping pong test: %v", stream, err)
   219  	}
   220  }
   221  
   222  // DoEmptyStream sets up a bi-directional streaming with zero message.
   223  func DoEmptyStream(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) {
   224  	stream, err := tc.FullDuplexCall(ctx, args...)
   225  	if err != nil {
   226  		logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err)
   227  	}
   228  	if err := stream.CloseSend(); err != nil {
   229  		logger.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
   230  	}
   231  	if _, err := stream.Recv(); err != io.EOF {
   232  		logger.Fatalf("%v failed to complete the empty stream test: %v", stream, err)
   233  	}
   234  }
   235  
   236  // DoTimeoutOnSleepingServer performs an RPC on a sleep server which causes RPC timeout.
   237  func DoTimeoutOnSleepingServer(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) {
   238  	ctx, cancel := context.WithTimeout(ctx, 1*time.Millisecond)
   239  	defer cancel()
   240  	stream, err := tc.FullDuplexCall(ctx, args...)
   241  	if err != nil {
   242  		if status.Code(err) == codes.DeadlineExceeded {
   243  			return
   244  		}
   245  		logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err)
   246  	}
   247  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, 27182)
   248  	req := &testpb.StreamingOutputCallRequest{
   249  		ResponseType: testpb.PayloadType_COMPRESSABLE,
   250  		Payload:      pl,
   251  	}
   252  	if err := stream.Send(req); err != nil && err != io.EOF {
   253  		logger.Fatalf("%v.Send(_) = %v", stream, err)
   254  	}
   255  	if _, err := stream.Recv(); status.Code(err) != codes.DeadlineExceeded {
   256  		logger.Fatalf("%v.Recv() = _, %v, want error code %d", stream, err, codes.DeadlineExceeded)
   257  	}
   258  }
   259  
   260  // DoComputeEngineCreds performs a unary RPC with compute engine auth.
   261  func DoComputeEngineCreds(ctx context.Context, tc testgrpc.TestServiceClient, serviceAccount, oauthScope string) {
   262  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
   263  	req := &testpb.SimpleRequest{
   264  		ResponseType:   testpb.PayloadType_COMPRESSABLE,
   265  		ResponseSize:   int32(largeRespSize),
   266  		Payload:        pl,
   267  		FillUsername:   true,
   268  		FillOauthScope: true,
   269  	}
   270  	reply, err := tc.UnaryCall(ctx, req)
   271  	if err != nil {
   272  		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
   273  	}
   274  	user := reply.GetUsername()
   275  	scope := reply.GetOauthScope()
   276  	if user != serviceAccount {
   277  		logger.Fatalf("Got user name %q, want %q.", user, serviceAccount)
   278  	}
   279  	if !strings.Contains(oauthScope, scope) {
   280  		logger.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, oauthScope)
   281  	}
   282  }
   283  
   284  func getServiceAccountJSONKey(keyFile string) []byte {
   285  	jsonKey, err := os.ReadFile(keyFile)
   286  	if err != nil {
   287  		logger.Fatalf("Failed to read the service account key file: %v", err)
   288  	}
   289  	return jsonKey
   290  }
   291  
   292  // DoServiceAccountCreds performs a unary RPC with service account auth.
   293  func DoServiceAccountCreds(ctx context.Context, tc testgrpc.TestServiceClient, serviceAccountKeyFile, oauthScope string) {
   294  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
   295  	req := &testpb.SimpleRequest{
   296  		ResponseType:   testpb.PayloadType_COMPRESSABLE,
   297  		ResponseSize:   int32(largeRespSize),
   298  		Payload:        pl,
   299  		FillUsername:   true,
   300  		FillOauthScope: true,
   301  	}
   302  	reply, err := tc.UnaryCall(ctx, req)
   303  	if err != nil {
   304  		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
   305  	}
   306  	jsonKey := getServiceAccountJSONKey(serviceAccountKeyFile)
   307  	user := reply.GetUsername()
   308  	scope := reply.GetOauthScope()
   309  	if !strings.Contains(string(jsonKey), user) {
   310  		logger.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey)
   311  	}
   312  	if !strings.Contains(oauthScope, scope) {
   313  		logger.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, oauthScope)
   314  	}
   315  }
   316  
   317  // DoJWTTokenCreds performs a unary RPC with JWT token auth.
   318  func DoJWTTokenCreds(ctx context.Context, tc testgrpc.TestServiceClient, serviceAccountKeyFile string) {
   319  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
   320  	req := &testpb.SimpleRequest{
   321  		ResponseType: testpb.PayloadType_COMPRESSABLE,
   322  		ResponseSize: int32(largeRespSize),
   323  		Payload:      pl,
   324  		FillUsername: true,
   325  	}
   326  	reply, err := tc.UnaryCall(ctx, req)
   327  	if err != nil {
   328  		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
   329  	}
   330  	jsonKey := getServiceAccountJSONKey(serviceAccountKeyFile)
   331  	user := reply.GetUsername()
   332  	if !strings.Contains(string(jsonKey), user) {
   333  		logger.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey)
   334  	}
   335  }
   336  
   337  // GetToken obtains an OAUTH token from the input.
   338  func GetToken(ctx context.Context, serviceAccountKeyFile string, oauthScope string) *oauth2.Token {
   339  	jsonKey := getServiceAccountJSONKey(serviceAccountKeyFile)
   340  	config, err := google.JWTConfigFromJSON(jsonKey, oauthScope)
   341  	if err != nil {
   342  		logger.Fatalf("Failed to get the config: %v", err)
   343  	}
   344  	token, err := config.TokenSource(ctx).Token()
   345  	if err != nil {
   346  		logger.Fatalf("Failed to get the token: %v", err)
   347  	}
   348  	return token
   349  }
   350  
   351  // DoOauth2TokenCreds performs a unary RPC with OAUTH2 token auth.
   352  func DoOauth2TokenCreds(ctx context.Context, tc testgrpc.TestServiceClient, serviceAccountKeyFile, oauthScope string) {
   353  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
   354  	req := &testpb.SimpleRequest{
   355  		ResponseType:   testpb.PayloadType_COMPRESSABLE,
   356  		ResponseSize:   int32(largeRespSize),
   357  		Payload:        pl,
   358  		FillUsername:   true,
   359  		FillOauthScope: true,
   360  	}
   361  	reply, err := tc.UnaryCall(ctx, req)
   362  	if err != nil {
   363  		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
   364  	}
   365  	jsonKey := getServiceAccountJSONKey(serviceAccountKeyFile)
   366  	user := reply.GetUsername()
   367  	scope := reply.GetOauthScope()
   368  	if !strings.Contains(string(jsonKey), user) {
   369  		logger.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey)
   370  	}
   371  	if !strings.Contains(oauthScope, scope) {
   372  		logger.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, oauthScope)
   373  	}
   374  }
   375  
   376  // DoPerRPCCreds performs a unary RPC with per RPC OAUTH2 token.
   377  func DoPerRPCCreds(ctx context.Context, tc testgrpc.TestServiceClient, serviceAccountKeyFile, oauthScope string) {
   378  	jsonKey := getServiceAccountJSONKey(serviceAccountKeyFile)
   379  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
   380  	req := &testpb.SimpleRequest{
   381  		ResponseType:   testpb.PayloadType_COMPRESSABLE,
   382  		ResponseSize:   int32(largeRespSize),
   383  		Payload:        pl,
   384  		FillUsername:   true,
   385  		FillOauthScope: true,
   386  	}
   387  	token := GetToken(ctx, serviceAccountKeyFile, oauthScope)
   388  	kv := map[string]string{"authorization": token.Type() + " " + token.AccessToken}
   389  	ctx = metadata.NewOutgoingContext(ctx, metadata.MD{"authorization": []string{kv["authorization"]}})
   390  	reply, err := tc.UnaryCall(ctx, req)
   391  	if err != nil {
   392  		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
   393  	}
   394  	user := reply.GetUsername()
   395  	scope := reply.GetOauthScope()
   396  	if !strings.Contains(string(jsonKey), user) {
   397  		logger.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey)
   398  	}
   399  	if !strings.Contains(oauthScope, scope) {
   400  		logger.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, oauthScope)
   401  	}
   402  }
   403  
   404  // DoGoogleDefaultCredentials performs a unary RPC with google default credentials
   405  func DoGoogleDefaultCredentials(ctx context.Context, tc testgrpc.TestServiceClient, defaultServiceAccount string) {
   406  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
   407  	req := &testpb.SimpleRequest{
   408  		ResponseType:   testpb.PayloadType_COMPRESSABLE,
   409  		ResponseSize:   int32(largeRespSize),
   410  		Payload:        pl,
   411  		FillUsername:   true,
   412  		FillOauthScope: true,
   413  	}
   414  	reply, err := tc.UnaryCall(ctx, req)
   415  	if err != nil {
   416  		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
   417  	}
   418  	if reply.GetUsername() != defaultServiceAccount {
   419  		logger.Fatalf("Got user name %q; wanted %q. ", reply.GetUsername(), defaultServiceAccount)
   420  	}
   421  }
   422  
   423  // DoComputeEngineChannelCredentials performs a unary RPC with compute engine channel credentials
   424  func DoComputeEngineChannelCredentials(ctx context.Context, tc testgrpc.TestServiceClient, defaultServiceAccount string) {
   425  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
   426  	req := &testpb.SimpleRequest{
   427  		ResponseType:   testpb.PayloadType_COMPRESSABLE,
   428  		ResponseSize:   int32(largeRespSize),
   429  		Payload:        pl,
   430  		FillUsername:   true,
   431  		FillOauthScope: true,
   432  	}
   433  	reply, err := tc.UnaryCall(ctx, req)
   434  	if err != nil {
   435  		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
   436  	}
   437  	if reply.GetUsername() != defaultServiceAccount {
   438  		logger.Fatalf("Got user name %q; wanted %q. ", reply.GetUsername(), defaultServiceAccount)
   439  	}
   440  }
   441  
   442  var testMetadata = metadata.MD{
   443  	"key1": []string{"value1"},
   444  	"key2": []string{"value2"},
   445  }
   446  
   447  // DoCancelAfterBegin cancels the RPC after metadata has been sent but before payloads are sent.
   448  func DoCancelAfterBegin(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) {
   449  	ctx, cancel := context.WithCancel(metadata.NewOutgoingContext(ctx, testMetadata))
   450  	stream, err := tc.StreamingInputCall(ctx, args...)
   451  	if err != nil {
   452  		logger.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err)
   453  	}
   454  	cancel()
   455  	_, err = stream.CloseAndRecv()
   456  	if status.Code(err) != codes.Canceled {
   457  		logger.Fatalf("%v.CloseAndRecv() got error code %d, want %d", stream, status.Code(err), codes.Canceled)
   458  	}
   459  }
   460  
   461  // DoCancelAfterFirstResponse cancels the RPC after receiving the first message from the server.
   462  func DoCancelAfterFirstResponse(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) {
   463  	ctx, cancel := context.WithCancel(ctx)
   464  	stream, err := tc.FullDuplexCall(ctx, args...)
   465  	if err != nil {
   466  		logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err)
   467  	}
   468  	respParam := []*testpb.ResponseParameters{
   469  		{
   470  			Size: 31415,
   471  		},
   472  	}
   473  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, 27182)
   474  	req := &testpb.StreamingOutputCallRequest{
   475  		ResponseType:       testpb.PayloadType_COMPRESSABLE,
   476  		ResponseParameters: respParam,
   477  		Payload:            pl,
   478  	}
   479  	if err := stream.Send(req); err != nil {
   480  		logger.Fatalf("%v has error %v while sending %v", stream, err, req)
   481  	}
   482  	if _, err := stream.Recv(); err != nil {
   483  		logger.Fatalf("%v.Recv() = %v", stream, err)
   484  	}
   485  	cancel()
   486  	if _, err := stream.Recv(); status.Code(err) != codes.Canceled {
   487  		logger.Fatalf("%v compleled with error code %d, want %d", stream, status.Code(err), codes.Canceled)
   488  	}
   489  }
   490  
   491  var (
   492  	initialMetadataValue  = "test_initial_metadata_value"
   493  	trailingMetadataValue = "\x0a\x0b\x0a\x0b\x0a\x0b"
   494  	customMetadata        = metadata.Pairs(
   495  		initialMetadataKey, initialMetadataValue,
   496  		trailingMetadataKey, trailingMetadataValue,
   497  	)
   498  )
   499  
   500  func validateMetadata(header, trailer metadata.MD) {
   501  	if len(header[initialMetadataKey]) != 1 {
   502  		logger.Fatalf("Expected exactly one header from server. Received %d", len(header[initialMetadataKey]))
   503  	}
   504  	if header[initialMetadataKey][0] != initialMetadataValue {
   505  		logger.Fatalf("Got header %s; want %s", header[initialMetadataKey][0], initialMetadataValue)
   506  	}
   507  	if len(trailer[trailingMetadataKey]) != 1 {
   508  		logger.Fatalf("Expected exactly one trailer from server. Received %d", len(trailer[trailingMetadataKey]))
   509  	}
   510  	if trailer[trailingMetadataKey][0] != trailingMetadataValue {
   511  		logger.Fatalf("Got trailer %s; want %s", trailer[trailingMetadataKey][0], trailingMetadataValue)
   512  	}
   513  }
   514  
   515  // DoCustomMetadata checks that metadata is echoed back to the client.
   516  func DoCustomMetadata(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) {
   517  	// Testing with UnaryCall.
   518  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, 1)
   519  	req := &testpb.SimpleRequest{
   520  		ResponseType: testpb.PayloadType_COMPRESSABLE,
   521  		ResponseSize: int32(1),
   522  		Payload:      pl,
   523  	}
   524  	ctx = metadata.NewOutgoingContext(ctx, customMetadata)
   525  	var header, trailer metadata.MD
   526  	args = append(args, grpc.Header(&header), grpc.Trailer(&trailer))
   527  	reply, err := tc.UnaryCall(
   528  		ctx,
   529  		req,
   530  		args...,
   531  	)
   532  	if err != nil {
   533  		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
   534  	}
   535  	t := reply.GetPayload().GetType()
   536  	s := len(reply.GetPayload().GetBody())
   537  	if t != testpb.PayloadType_COMPRESSABLE || s != 1 {
   538  		logger.Fatalf("Got the reply with type %d len %d; want %d, %d", t, s, testpb.PayloadType_COMPRESSABLE, 1)
   539  	}
   540  	validateMetadata(header, trailer)
   541  
   542  	// Testing with FullDuplex.
   543  	stream, err := tc.FullDuplexCall(ctx, args...)
   544  	if err != nil {
   545  		logger.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
   546  	}
   547  	respParam := []*testpb.ResponseParameters{
   548  		{
   549  			Size: 1,
   550  		},
   551  	}
   552  	streamReq := &testpb.StreamingOutputCallRequest{
   553  		ResponseType:       testpb.PayloadType_COMPRESSABLE,
   554  		ResponseParameters: respParam,
   555  		Payload:            pl,
   556  	}
   557  	if err := stream.Send(streamReq); err != nil {
   558  		logger.Fatalf("%v has error %v while sending %v", stream, err, streamReq)
   559  	}
   560  	streamHeader, err := stream.Header()
   561  	if err != nil {
   562  		logger.Fatalf("%v.Header() = %v", stream, err)
   563  	}
   564  	if _, err := stream.Recv(); err != nil {
   565  		logger.Fatalf("%v.Recv() = %v", stream, err)
   566  	}
   567  	if err := stream.CloseSend(); err != nil {
   568  		logger.Fatalf("%v.CloseSend() = %v, want <nil>", stream, err)
   569  	}
   570  	if _, err := stream.Recv(); err != io.EOF {
   571  		logger.Fatalf("%v failed to complete the custom metadata test: %v", stream, err)
   572  	}
   573  	streamTrailer := stream.Trailer()
   574  	validateMetadata(streamHeader, streamTrailer)
   575  }
   576  
   577  // DoStatusCodeAndMessage checks that the status code is propagated back to the client.
   578  func DoStatusCodeAndMessage(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) {
   579  	var code int32 = 2
   580  	msg := "test status message"
   581  	expectedErr := status.Error(codes.Code(code), msg)
   582  	respStatus := &testpb.EchoStatus{
   583  		Code:    code,
   584  		Message: msg,
   585  	}
   586  	// Test UnaryCall.
   587  	req := &testpb.SimpleRequest{
   588  		ResponseStatus: respStatus,
   589  	}
   590  	if _, err := tc.UnaryCall(ctx, req, args...); err == nil || err.Error() != expectedErr.Error() {
   591  		logger.Fatalf("%v.UnaryCall(_, %v) = _, %v, want _, %v", tc, req, err, expectedErr)
   592  	}
   593  	// Test FullDuplexCall.
   594  	stream, err := tc.FullDuplexCall(ctx, args...)
   595  	if err != nil {
   596  		logger.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
   597  	}
   598  	streamReq := &testpb.StreamingOutputCallRequest{
   599  		ResponseStatus: respStatus,
   600  	}
   601  	if err := stream.Send(streamReq); err != nil {
   602  		logger.Fatalf("%v has error %v while sending %v, want <nil>", stream, err, streamReq)
   603  	}
   604  	if err := stream.CloseSend(); err != nil {
   605  		logger.Fatalf("%v.CloseSend() = %v, want <nil>", stream, err)
   606  	}
   607  	if _, err = stream.Recv(); err.Error() != expectedErr.Error() {
   608  		logger.Fatalf("%v.Recv() returned error %v, want %v", stream, err, expectedErr)
   609  	}
   610  }
   611  
   612  // DoSpecialStatusMessage verifies Unicode and whitespace is correctly processed
   613  // in status message.
   614  func DoSpecialStatusMessage(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) {
   615  	const (
   616  		code int32  = 2
   617  		msg  string = "\t\ntest with whitespace\r\nand Unicode BMP ☺ and non-BMP 😈\t\n"
   618  	)
   619  	expectedErr := status.Error(codes.Code(code), msg)
   620  	req := &testpb.SimpleRequest{
   621  		ResponseStatus: &testpb.EchoStatus{
   622  			Code:    code,
   623  			Message: msg,
   624  		},
   625  	}
   626  	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
   627  	defer cancel()
   628  	if _, err := tc.UnaryCall(ctx, req, args...); err == nil || err.Error() != expectedErr.Error() {
   629  		logger.Fatalf("%v.UnaryCall(_, %v) = _, %v, want _, %v", tc, req, err, expectedErr)
   630  	}
   631  }
   632  
   633  // DoUnimplementedService attempts to call a method from an unimplemented service.
   634  func DoUnimplementedService(ctx context.Context, tc testgrpc.UnimplementedServiceClient) {
   635  	_, err := tc.UnimplementedCall(ctx, &testpb.Empty{})
   636  	if status.Code(err) != codes.Unimplemented {
   637  		logger.Fatalf("%v.UnimplementedCall() = _, %v, want _, %v", tc, status.Code(err), codes.Unimplemented)
   638  	}
   639  }
   640  
   641  // DoUnimplementedMethod attempts to call an unimplemented method.
   642  func DoUnimplementedMethod(ctx context.Context, cc *grpc.ClientConn) {
   643  	var req, reply proto.Message
   644  	if err := cc.Invoke(ctx, "/grpc.testing.TestService/UnimplementedCall", req, reply); err == nil || status.Code(err) != codes.Unimplemented {
   645  		logger.Fatalf("ClientConn.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented)
   646  	}
   647  }
   648  
   649  // DoPickFirstUnary runs multiple RPCs (rpcCount) and checks that all requests
   650  // are sent to the same backend.
   651  func DoPickFirstUnary(ctx context.Context, tc testgrpc.TestServiceClient) {
   652  	const rpcCount = 100
   653  
   654  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, 1)
   655  	req := &testpb.SimpleRequest{
   656  		ResponseType: testpb.PayloadType_COMPRESSABLE,
   657  		ResponseSize: int32(1),
   658  		Payload:      pl,
   659  		FillServerId: true,
   660  	}
   661  	// TODO(mohanli): Revert the timeout back to 10s once TD migrates to xdstp.
   662  	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
   663  	defer cancel()
   664  	var serverID string
   665  	for i := 0; i < rpcCount; i++ {
   666  		resp, err := tc.UnaryCall(ctx, req)
   667  		if err != nil {
   668  			logger.Fatalf("iteration %d, failed to do UnaryCall: %v", i, err)
   669  		}
   670  		id := resp.ServerId
   671  		if id == "" {
   672  			logger.Fatalf("iteration %d, got empty server ID", i)
   673  		}
   674  		if i == 0 {
   675  			serverID = id
   676  			continue
   677  		}
   678  		if serverID != id {
   679  			logger.Fatalf("iteration %d, got different server ids: %q vs %q", i, serverID, id)
   680  		}
   681  	}
   682  }
   683  
   684  type testServer struct {
   685  	testgrpc.UnimplementedTestServiceServer
   686  
   687  	orcaMu          sync.Mutex
   688  	metricsRecorder orca.ServerMetricsRecorder
   689  }
   690  
   691  // NewTestServerOptions contains options that control the behavior of the test
   692  // server returned by NewTestServer.
   693  type NewTestServerOptions struct {
   694  	MetricsRecorder orca.ServerMetricsRecorder
   695  }
   696  
   697  // NewTestServer creates a test server for test service.  opts carries optional
   698  // settings and does not need to be provided.  If multiple opts are provided,
   699  // only the first one is used.
   700  func NewTestServer(opts ...NewTestServerOptions) testgrpc.TestServiceServer {
   701  	if len(opts) > 0 {
   702  		return &testServer{metricsRecorder: opts[0].MetricsRecorder}
   703  	}
   704  	return &testServer{}
   705  }
   706  
   707  func (s *testServer) EmptyCall(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   708  	return new(testpb.Empty), nil
   709  }
   710  
   711  func serverNewPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error) {
   712  	if size < 0 {
   713  		return nil, fmt.Errorf("requested a response with invalid length %d", size)
   714  	}
   715  	body := make([]byte, size)
   716  	switch t {
   717  	case testpb.PayloadType_COMPRESSABLE:
   718  	default:
   719  		return nil, fmt.Errorf("unsupported payload type: %d", t)
   720  	}
   721  	return &testpb.Payload{
   722  		Type: t,
   723  		Body: body,
   724  	}, nil
   725  }
   726  
   727  func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   728  	st := in.GetResponseStatus()
   729  	if md, ok := metadata.FromIncomingContext(ctx); ok {
   730  		if initialMetadata, ok := md[initialMetadataKey]; ok {
   731  			header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
   732  			grpc.SendHeader(ctx, header)
   733  		}
   734  		if trailingMetadata, ok := md[trailingMetadataKey]; ok {
   735  			trailer := metadata.Pairs(trailingMetadataKey, trailingMetadata[0])
   736  			grpc.SetTrailer(ctx, trailer)
   737  		}
   738  	}
   739  	if st != nil && st.Code != 0 {
   740  		return nil, status.Error(codes.Code(st.Code), st.Message)
   741  	}
   742  	pl, err := serverNewPayload(in.GetResponseType(), in.GetResponseSize())
   743  	if err != nil {
   744  		return nil, err
   745  	}
   746  	if r, orcaData := orca.CallMetricsRecorderFromContext(ctx), in.GetOrcaPerQueryReport(); r != nil && orcaData != nil {
   747  		// Transfer the request's per-Call ORCA data to the call metrics
   748  		// recorder in the context, if present.
   749  		setORCAMetrics(r, orcaData)
   750  	}
   751  	return &testpb.SimpleResponse{
   752  		Payload: pl,
   753  	}, nil
   754  }
   755  
   756  func setORCAMetrics(r orca.ServerMetricsRecorder, orcaData *testpb.TestOrcaReport) {
   757  	r.SetCPUUtilization(orcaData.CpuUtilization)
   758  	r.SetMemoryUtilization(orcaData.MemoryUtilization)
   759  	if rq, ok := r.(orca.CallMetricsRecorder); ok {
   760  		for k, v := range orcaData.RequestCost {
   761  			rq.SetRequestCost(k, v)
   762  		}
   763  	}
   764  	for k, v := range orcaData.Utilization {
   765  		r.SetNamedUtilization(k, v)
   766  	}
   767  }
   768  
   769  func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error {
   770  	cs := args.GetResponseParameters()
   771  	for _, c := range cs {
   772  		if us := c.GetIntervalUs(); us > 0 {
   773  			time.Sleep(time.Duration(us) * time.Microsecond)
   774  		}
   775  		pl, err := serverNewPayload(args.GetResponseType(), c.GetSize())
   776  		if err != nil {
   777  			return err
   778  		}
   779  		if err := stream.Send(&testpb.StreamingOutputCallResponse{
   780  			Payload: pl,
   781  		}); err != nil {
   782  			return err
   783  		}
   784  	}
   785  	return nil
   786  }
   787  
   788  func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error {
   789  	var sum int
   790  	for {
   791  		in, err := stream.Recv()
   792  		if err == io.EOF {
   793  			return stream.SendAndClose(&testpb.StreamingInputCallResponse{
   794  				AggregatedPayloadSize: int32(sum),
   795  			})
   796  		}
   797  		if err != nil {
   798  			return err
   799  		}
   800  		p := in.GetPayload().GetBody()
   801  		sum += len(p)
   802  	}
   803  }
   804  
// FullDuplexCall implements the full-duplex bidi streaming RPC.
//
// If the incoming metadata carries initialMetadataKey or trailingMetadataKey,
// the first value of each is echoed back as a response header or trailer.
// Then, for every received request, it:
//   - fails the stream with the requested status if that status has a
//     non-zero code;
//   - copies any OOB ORCA report into s.metricsRecorder (when both are
//     non-nil);
//   - sends one payload per response parameter, sleeping for each
//     parameter's interval first when it is positive.
//
// The handler returns nil once the client half-closes (io.EOF). Other receive,
// payload-building and send errors are returned.
func (s *testServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallServer) error {
	if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
		if initialMetadata, ok := md[initialMetadataKey]; ok {
			header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
			stream.SendHeader(header)
		}
		if trailingMetadata, ok := md[trailingMetadataKey]; ok {
			trailer := metadata.Pairs(trailingMetadataKey, trailingMetadata[0])
			stream.SetTrailer(trailer)
		}
	}
	// Tracks whether this stream has already acquired s.orcaMu, so the lock is
	// taken (and its Unlock deferred) at most once per stream.
	hasORCALock := false
	for {
		in, err := stream.Recv()
		if err == io.EOF {
			// read done.
			return nil
		}
		if err != nil {
			return err
		}
		st := in.GetResponseStatus()
		if st != nil && st.Code != 0 {
			return status.Error(codes.Code(st.Code), st.Message)
		}

		if r, orcaData := s.metricsRecorder, in.GetOrcaOobReport(); r != nil && orcaData != nil {
			// Transfer the request's OOB ORCA data to the server metrics recorder
			// in the server, if present.
			if !hasORCALock {
				// The deferred Unlock runs when this handler returns, not at the
				// end of the loop iteration. Once a stream sends OOB data it
				// therefore holds s.orcaMu for the rest of its life. This
				// presumably serializes streams that write the shared server
				// recorder so concurrent tests don't clobber each other's values.
				// NOTE(review): intent inferred, confirm with test authors.
				s.orcaMu.Lock()
				defer s.orcaMu.Unlock()
				hasORCALock = true
			}
			setORCAMetrics(r, orcaData)
		}

		cs := in.GetResponseParameters()
		for _, c := range cs {
			if us := c.GetIntervalUs(); us > 0 {
				time.Sleep(time.Duration(us) * time.Microsecond)
			}
			pl, err := serverNewPayload(in.GetResponseType(), c.GetSize())
			if err != nil {
				return err
			}
			if err := stream.Send(&testpb.StreamingOutputCallResponse{
				Payload: pl,
			}); err != nil {
				return err
			}
		}
	}
}
   859  
   860  func (s *testServer) HalfDuplexCall(stream testgrpc.TestService_HalfDuplexCallServer) error {
   861  	var msgBuf []*testpb.StreamingOutputCallRequest
   862  	for {
   863  		in, err := stream.Recv()
   864  		if err == io.EOF {
   865  			// read done.
   866  			break
   867  		}
   868  		if err != nil {
   869  			return err
   870  		}
   871  		msgBuf = append(msgBuf, in)
   872  	}
   873  	for _, m := range msgBuf {
   874  		cs := m.GetResponseParameters()
   875  		for _, c := range cs {
   876  			if us := c.GetIntervalUs(); us > 0 {
   877  				time.Sleep(time.Duration(us) * time.Microsecond)
   878  			}
   879  			pl, err := serverNewPayload(m.GetResponseType(), c.GetSize())
   880  			if err != nil {
   881  				return err
   882  			}
   883  			if err := stream.Send(&testpb.StreamingOutputCallResponse{
   884  				Payload: pl,
   885  			}); err != nil {
   886  				return err
   887  			}
   888  		}
   889  	}
   890  	return nil
   891  }
   892  
   893  // DoORCAPerRPCTest performs a unary RPC that enables ORCA per-call reporting
   894  // and verifies the load report sent back to the LB policy's Done callback.
   895  func DoORCAPerRPCTest(ctx context.Context, tc testgrpc.TestServiceClient) {
   896  	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
   897  	defer cancel()
   898  	orcaRes := &v3orcapb.OrcaLoadReport{}
   899  	_, err := tc.UnaryCall(contextWithORCAResult(ctx, &orcaRes), &testpb.SimpleRequest{
   900  		OrcaPerQueryReport: &testpb.TestOrcaReport{
   901  			CpuUtilization:    0.8210,
   902  			MemoryUtilization: 0.5847,
   903  			RequestCost:       map[string]float64{"cost": 3456.32},
   904  			Utilization:       map[string]float64{"util": 0.30499},
   905  		},
   906  	})
   907  	if err != nil {
   908  		logger.Fatalf("/TestService/UnaryCall RPC failed: ", err)
   909  	}
   910  	want := &v3orcapb.OrcaLoadReport{
   911  		CpuUtilization: 0.8210,
   912  		MemUtilization: 0.5847,
   913  		RequestCost:    map[string]float64{"cost": 3456.32},
   914  		Utilization:    map[string]float64{"util": 0.30499},
   915  	}
   916  	if !proto.Equal(orcaRes, want) {
   917  		logger.Fatalf("/TestService/UnaryCall RPC received ORCA load report %+v; want %+v", orcaRes, want)
   918  	}
   919  }
   920  
   921  // DoORCAOOBTest performs a streaming RPC that enables ORCA OOB reporting and
   922  // verifies the load report sent to the LB policy's OOB listener.
   923  func DoORCAOOBTest(ctx context.Context, tc testgrpc.TestServiceClient) {
   924  	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
   925  	defer cancel()
   926  	stream, err := tc.FullDuplexCall(ctx)
   927  	if err != nil {
   928  		logger.Fatalf("/TestService/FullDuplexCall received error starting stream: %v", err)
   929  	}
   930  	err = stream.Send(&testpb.StreamingOutputCallRequest{
   931  		OrcaOobReport: &testpb.TestOrcaReport{
   932  			CpuUtilization:    0.8210,
   933  			MemoryUtilization: 0.5847,
   934  			Utilization:       map[string]float64{"util": 0.30499},
   935  		},
   936  		ResponseParameters: []*testpb.ResponseParameters{{Size: 1}},
   937  	})
   938  	if err != nil {
   939  		logger.Fatalf("/TestService/FullDuplexCall received error sending: %v", err)
   940  	}
   941  	_, err = stream.Recv()
   942  	if err != nil {
   943  		logger.Fatalf("/TestService/FullDuplexCall received error receiving: %v", err)
   944  	}
   945  
   946  	want := &v3orcapb.OrcaLoadReport{
   947  		CpuUtilization: 0.8210,
   948  		MemUtilization: 0.5847,
   949  		Utilization:    map[string]float64{"util": 0.30499},
   950  	}
   951  	checkORCAMetrics(ctx, tc, want)
   952  
   953  	err = stream.Send(&testpb.StreamingOutputCallRequest{
   954  		OrcaOobReport: &testpb.TestOrcaReport{
   955  			CpuUtilization:    0.29309,
   956  			MemoryUtilization: 0.2,
   957  			Utilization:       map[string]float64{"util": 0.2039},
   958  		},
   959  		ResponseParameters: []*testpb.ResponseParameters{{Size: 1}},
   960  	})
   961  	if err != nil {
   962  		logger.Fatalf("/TestService/FullDuplexCall received error sending: %v", err)
   963  	}
   964  	_, err = stream.Recv()
   965  	if err != nil {
   966  		logger.Fatalf("/TestService/FullDuplexCall received error receiving: %v", err)
   967  	}
   968  
   969  	want = &v3orcapb.OrcaLoadReport{
   970  		CpuUtilization: 0.29309,
   971  		MemUtilization: 0.2,
   972  		Utilization:    map[string]float64{"util": 0.2039},
   973  	}
   974  	checkORCAMetrics(ctx, tc, want)
   975  }
   976  
   977  func checkORCAMetrics(ctx context.Context, tc testgrpc.TestServiceClient, want *v3orcapb.OrcaLoadReport) {
   978  	for ctx.Err() == nil {
   979  		orcaRes := &v3orcapb.OrcaLoadReport{}
   980  		if _, err := tc.UnaryCall(contextWithORCAResult(ctx, &orcaRes), &testpb.SimpleRequest{}); err != nil {
   981  			logger.Fatalf("/TestService/UnaryCall RPC failed: ", err)
   982  		}
   983  		if proto.Equal(orcaRes, want) {
   984  			return
   985  		}
   986  		logger.Infof("/TestService/UnaryCall RPC received ORCA load report %+v; want %+v", orcaRes, want)
   987  		time.Sleep(time.Second)
   988  	}
   989  	logger.Fatalf("timed out waiting for expected ORCA load report")
   990  }