google.golang.org/grpc@v1.72.2/experimental/shared_buffer_pool_test.go (about)

     1  /*
     2   *
     3   * Copyright 2023 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package experimental_test
    20  
import (
	"bytes"
	"context"
	"io"
	"sync"
	"testing"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/encoding/gzip"
	"google.golang.org/grpc/experimental"
	"google.golang.org/grpc/internal/grpctest"
	"google.golang.org/grpc/internal/stubserver"

	testgrpc "google.golang.org/grpc/interop/grpc_testing"
	testpb "google.golang.org/grpc/interop/grpc_testing"
)
    37  
    38  type s struct {
    39  	grpctest.Tester
    40  }
    41  
    42  func Test(t *testing.T) {
    43  	grpctest.RunSubTests(t, s{})
    44  }
    45  
    46  const defaultTestTimeout = 10 * time.Second
    47  
    48  func (s) TestRecvBufferPoolStream(t *testing.T) {
    49  	// TODO: How much of this test can be preserved now that buffer reuse happens at
    50  	// the codec and HTTP/2 level?
    51  	t.SkipNow()
    52  	tcs := []struct {
    53  		name     string
    54  		callOpts []grpc.CallOption
    55  	}{
    56  		{
    57  			name: "default",
    58  		},
    59  		{
    60  			name: "useCompressor",
    61  			callOpts: []grpc.CallOption{
    62  				grpc.UseCompressor(gzip.Name),
    63  			},
    64  		},
    65  	}
    66  
    67  	for _, tc := range tcs {
    68  		t.Run(tc.name, func(t *testing.T) {
    69  			const reqCount = 10
    70  
    71  			ss := &stubserver.StubServer{
    72  				FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
    73  					for i := 0; i < reqCount; i++ {
    74  						preparedMsg := &grpc.PreparedMsg{}
    75  						if err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{
    76  							Payload: &testpb.Payload{
    77  								Body: []byte{'0' + uint8(i)},
    78  							},
    79  						}); err != nil {
    80  							return err
    81  						}
    82  						stream.SendMsg(preparedMsg)
    83  					}
    84  					return nil
    85  				},
    86  			}
    87  
    88  			pool := &checkBufferPool{}
    89  			sopts := []grpc.ServerOption{experimental.BufferPool(pool)}
    90  			dopts := []grpc.DialOption{experimental.WithBufferPool(pool)}
    91  			if err := ss.Start(sopts, dopts...); err != nil {
    92  				t.Fatalf("Error starting endpoint server: %v", err)
    93  			}
    94  			defer ss.Stop()
    95  
    96  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    97  			defer cancel()
    98  
    99  			stream, err := ss.Client.FullDuplexCall(ctx, tc.callOpts...)
   100  			if err != nil {
   101  				t.Fatalf("ss.Client.FullDuplexCall failed: %v", err)
   102  			}
   103  
   104  			var ngot int
   105  			var buf bytes.Buffer
   106  			for {
   107  				reply, err := stream.Recv()
   108  				if err == io.EOF {
   109  					break
   110  				}
   111  				if err != nil {
   112  					t.Fatal(err)
   113  				}
   114  				ngot++
   115  				if buf.Len() > 0 {
   116  					buf.WriteByte(',')
   117  				}
   118  				buf.Write(reply.GetPayload().GetBody())
   119  			}
   120  			if want := 10; ngot != want {
   121  				t.Fatalf("Got %d replies, want %d", ngot, want)
   122  			}
   123  			if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
   124  				t.Fatalf("Got replies %q; want %q", got, want)
   125  			}
   126  
   127  			if len(pool.puts) != reqCount {
   128  				t.Fatalf("Expected 10 buffers to be returned to the pool, got %d", len(pool.puts))
   129  			}
   130  		})
   131  	}
   132  }
   133  
   134  func (s) TestRecvBufferPoolUnary(t *testing.T) {
   135  	// TODO: See above
   136  	t.SkipNow()
   137  	tcs := []struct {
   138  		name     string
   139  		callOpts []grpc.CallOption
   140  	}{
   141  		{
   142  			name: "default",
   143  		},
   144  		{
   145  			name: "useCompressor",
   146  			callOpts: []grpc.CallOption{
   147  				grpc.UseCompressor(gzip.Name),
   148  			},
   149  		},
   150  	}
   151  
   152  	for _, tc := range tcs {
   153  		t.Run(tc.name, func(t *testing.T) {
   154  			const largeSize = 1024
   155  
   156  			ss := &stubserver.StubServer{
   157  				UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   158  					return &testpb.SimpleResponse{
   159  						Payload: &testpb.Payload{
   160  							Body: make([]byte, largeSize),
   161  						},
   162  					}, nil
   163  				},
   164  			}
   165  
   166  			pool := &checkBufferPool{}
   167  			sopts := []grpc.ServerOption{experimental.BufferPool(pool)}
   168  			dopts := []grpc.DialOption{experimental.WithBufferPool(pool)}
   169  			if err := ss.Start(sopts, dopts...); err != nil {
   170  				t.Fatalf("Error starting endpoint server: %v", err)
   171  			}
   172  			defer ss.Stop()
   173  
   174  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   175  			defer cancel()
   176  
   177  			const reqCount = 10
   178  			for i := 0; i < reqCount; i++ {
   179  				if _, err := ss.Client.UnaryCall(
   180  					ctx,
   181  					&testpb.SimpleRequest{
   182  						Payload: &testpb.Payload{
   183  							Body: make([]byte, largeSize),
   184  						},
   185  					},
   186  					tc.callOpts...,
   187  				); err != nil {
   188  					t.Fatalf("ss.Client.UnaryCall failed: %v", err)
   189  				}
   190  			}
   191  
   192  			const bufferCount = reqCount * 2 // req + resp
   193  			if len(pool.puts) != bufferCount {
   194  				t.Fatalf("Expected %d buffers to be returned to the pool, got %d", bufferCount, len(pool.puts))
   195  			}
   196  		})
   197  	}
   198  }
   199  
   200  type checkBufferPool struct {
   201  	puts [][]byte
   202  }
   203  
   204  func (p *checkBufferPool) Get(size int) *[]byte {
   205  	b := make([]byte, size)
   206  	return &b
   207  }
   208  
   209  func (p *checkBufferPool) Put(bs *[]byte) {
   210  	p.puts = append(p.puts, *bs)
   211  }