google.golang.org/grpc@v1.62.1/experimental/shared_buffer_pool_test.go (about)

     1  /*
     2   *
     3   * Copyright 2023 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package experimental_test
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"io"
    25  	"testing"
    26  	"time"
    27  
    28  	"google.golang.org/grpc"
    29  	"google.golang.org/grpc/encoding/gzip"
    30  	"google.golang.org/grpc/experimental"
    31  	"google.golang.org/grpc/internal/grpctest"
    32  	"google.golang.org/grpc/internal/stubserver"
    33  
    34  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    35  )
    36  
    37  type s struct {
    38  	grpctest.Tester
    39  }
    40  
    41  func Test(t *testing.T) {
    42  	grpctest.RunSubTests(t, s{})
    43  }
    44  
    45  const defaultTestTimeout = 10 * time.Second
    46  
    47  func (s) TestRecvBufferPoolStream(t *testing.T) {
    48  	tcs := []struct {
    49  		name     string
    50  		callOpts []grpc.CallOption
    51  	}{
    52  		{
    53  			name: "default",
    54  		},
    55  		{
    56  			name: "useCompressor",
    57  			callOpts: []grpc.CallOption{
    58  				grpc.UseCompressor(gzip.Name),
    59  			},
    60  		},
    61  	}
    62  
    63  	for _, tc := range tcs {
    64  		t.Run(tc.name, func(t *testing.T) {
    65  			const reqCount = 10
    66  
    67  			ss := &stubserver.StubServer{
    68  				FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
    69  					for i := 0; i < reqCount; i++ {
    70  						preparedMsg := &grpc.PreparedMsg{}
    71  						if err := preparedMsg.Encode(stream, &testgrpc.StreamingOutputCallResponse{
    72  							Payload: &testgrpc.Payload{
    73  								Body: []byte{'0' + uint8(i)},
    74  							},
    75  						}); err != nil {
    76  							return err
    77  						}
    78  						stream.SendMsg(preparedMsg)
    79  					}
    80  					return nil
    81  				},
    82  			}
    83  
    84  			pool := &checkBufferPool{}
    85  			sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
    86  			dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
    87  			if err := ss.Start(sopts, dopts...); err != nil {
    88  				t.Fatalf("Error starting endpoint server: %v", err)
    89  			}
    90  			defer ss.Stop()
    91  
    92  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    93  			defer cancel()
    94  
    95  			stream, err := ss.Client.FullDuplexCall(ctx, tc.callOpts...)
    96  			if err != nil {
    97  				t.Fatalf("ss.Client.FullDuplexCall failed: %v", err)
    98  			}
    99  
   100  			var ngot int
   101  			var buf bytes.Buffer
   102  			for {
   103  				reply, err := stream.Recv()
   104  				if err == io.EOF {
   105  					break
   106  				}
   107  				if err != nil {
   108  					t.Fatal(err)
   109  				}
   110  				ngot++
   111  				if buf.Len() > 0 {
   112  					buf.WriteByte(',')
   113  				}
   114  				buf.Write(reply.GetPayload().GetBody())
   115  			}
   116  			if want := 10; ngot != want {
   117  				t.Fatalf("Got %d replies, want %d", ngot, want)
   118  			}
   119  			if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
   120  				t.Fatalf("Got replies %q; want %q", got, want)
   121  			}
   122  
   123  			if len(pool.puts) != reqCount {
   124  				t.Fatalf("Expected 10 buffers to be returned to the pool, got %d", len(pool.puts))
   125  			}
   126  		})
   127  	}
   128  }
   129  
   130  func (s) TestRecvBufferPoolUnary(t *testing.T) {
   131  	tcs := []struct {
   132  		name     string
   133  		callOpts []grpc.CallOption
   134  	}{
   135  		{
   136  			name: "default",
   137  		},
   138  		{
   139  			name: "useCompressor",
   140  			callOpts: []grpc.CallOption{
   141  				grpc.UseCompressor(gzip.Name),
   142  			},
   143  		},
   144  	}
   145  
   146  	for _, tc := range tcs {
   147  		t.Run(tc.name, func(t *testing.T) {
   148  			const largeSize = 1024
   149  
   150  			ss := &stubserver.StubServer{
   151  				UnaryCallF: func(ctx context.Context, in *testgrpc.SimpleRequest) (*testgrpc.SimpleResponse, error) {
   152  					return &testgrpc.SimpleResponse{
   153  						Payload: &testgrpc.Payload{
   154  							Body: make([]byte, largeSize),
   155  						},
   156  					}, nil
   157  				},
   158  			}
   159  
   160  			pool := &checkBufferPool{}
   161  			sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
   162  			dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
   163  			if err := ss.Start(sopts, dopts...); err != nil {
   164  				t.Fatalf("Error starting endpoint server: %v", err)
   165  			}
   166  			defer ss.Stop()
   167  
   168  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   169  			defer cancel()
   170  
   171  			const reqCount = 10
   172  			for i := 0; i < reqCount; i++ {
   173  				if _, err := ss.Client.UnaryCall(
   174  					ctx,
   175  					&testgrpc.SimpleRequest{
   176  						Payload: &testgrpc.Payload{
   177  							Body: make([]byte, largeSize),
   178  						},
   179  					},
   180  					tc.callOpts...,
   181  				); err != nil {
   182  					t.Fatalf("ss.Client.UnaryCall failed: %v", err)
   183  				}
   184  			}
   185  
   186  			const bufferCount = reqCount * 2 // req + resp
   187  			if len(pool.puts) != bufferCount {
   188  				t.Fatalf("Expected %d buffers to be returned to the pool, got %d", bufferCount, len(pool.puts))
   189  			}
   190  		})
   191  	}
   192  }
   193  
   194  type checkBufferPool struct {
   195  	puts [][]byte
   196  }
   197  
   198  func (p *checkBufferPool) Get(size int) []byte {
   199  	return make([]byte, size)
   200  }
   201  
   202  func (p *checkBufferPool) Put(bs *[]byte) {
   203  	p.puts = append(p.puts, *bs)
   204  }