github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/grpc/stream/stream_test.go (about)

     1  // Copyright 2023 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stream
    16  
    17  import (
    18  	"context"
    19  	"net"
    20  	"sync"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/stretchr/testify/require"
    26  )
    27  
    28  type mockStream struct {
    29  	ctx  context.Context
    30  	conn net.Conn
    31  }
    32  
    33  func newMockStream(ctx context.Context, conn net.Conn) *mockStream {
    34  	return &mockStream{
    35  		ctx:  ctx,
    36  		conn: conn,
    37  	}
    38  }
    39  
    40  func (m *mockStream) Context() context.Context {
    41  	return m.ctx
    42  }
    43  
    44  func (m *mockStream) Send(b []byte) error {
    45  	_, err := m.conn.Write(b)
    46  	return err
    47  }
    48  
    49  func (m *mockStream) Recv() ([]byte, error) {
    50  	b := make([]byte, 2*MaxChunkSize)
    51  	n, err := m.conn.Read(b)
    52  	return b[:n], err
    53  }
    54  
    55  func newStreamPipe(t *testing.T) (*ReadWriter, net.Conn) {
    56  	local, remote := net.Pipe()
    57  	stream := newMockStream(context.Background(), remote)
    58  
    59  	timeout := time.Now().Add(time.Second * 5)
    60  
    61  	require.NoError(t, local.SetReadDeadline(timeout))
    62  	require.NoError(t, local.SetWriteDeadline(timeout))
    63  	require.NoError(t, remote.SetReadDeadline(timeout))
    64  	require.NoError(t, remote.SetWriteDeadline(timeout))
    65  
    66  	streamConn, err := NewReadWriter(stream)
    67  	require.NoError(t, err)
    68  
    69  	return streamConn, local
    70  }
    71  
    72  func TestReadWriter_Write(t *testing.T) {
    73  	streamConn, local := newStreamPipe(t)
    74  	wg := &sync.WaitGroup{}
    75  	wg.Add(2)
    76  
    77  	data := []byte("hello world!")
    78  	go func() {
    79  		defer wg.Done()
    80  		n, err := streamConn.Write(data)
    81  		assert.NoError(t, err)
    82  		assert.Len(t, data, n)
    83  	}()
    84  	go func() {
    85  		defer wg.Done()
    86  		b := make([]byte, 2*MaxChunkSize)
    87  		n, err := local.Read(b)
    88  		assert.NoError(t, err)
    89  		assert.Len(t, data, n)
    90  		assert.Equal(t, data, b[:n])
    91  	}()
    92  
    93  	wg.Wait()
    94  }
    95  
    96  func TestReadWriter_WriteChunk(t *testing.T) {
    97  	streamConn, local := newStreamPipe(t)
    98  	wg := &sync.WaitGroup{}
    99  	wg.Add(2)
   100  
   101  	data := make([]byte, MaxChunkSize+1)
   102  	go func() {
   103  		defer wg.Done()
   104  		n, err := streamConn.Write(data)
   105  		assert.NoError(t, err)
   106  		assert.Len(t, data, n)
   107  	}()
   108  	go func() {
   109  		defer wg.Done()
   110  		b := make([]byte, 2*MaxChunkSize)
   111  		n, err := local.Read(b)
   112  		assert.NoError(t, err)
   113  		assert.Equal(t, MaxChunkSize, n)
   114  		assert.Equal(t, data[:n], b[:n])
   115  
   116  		n, err = local.Read(b)
   117  		assert.NoError(t, err)
   118  		assert.Equal(t, 1, n)
   119  		assert.Equal(t, data[:n], b[:n])
   120  	}()
   121  
   122  	wg.Wait()
   123  }
   124  
   125  func TestReadWriter_Read(t *testing.T) {
   126  	streamConn, local := newStreamPipe(t)
   127  	wg := &sync.WaitGroup{}
   128  	wg.Add(2)
   129  
   130  	data := make([]byte, MaxChunkSize+1)
   131  	go func() {
   132  		b := make([]byte, 2*MaxChunkSize)
   133  		defer wg.Done()
   134  		n, err := streamConn.Read(b)
   135  		assert.NoError(t, err)
   136  		assert.Len(t, data, n)
   137  		assert.Equal(t, data, b[:n])
   138  	}()
   139  	go func() {
   140  		defer wg.Done()
   141  		n, err := local.Write(data)
   142  		assert.NoError(t, err)
   143  		assert.Len(t, data, n)
   144  	}()
   145  
   146  	wg.Wait()
   147  }