github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/grpc/stream/stream_test.go (about) 1 // Copyright 2023 Gravitational, Inc 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package stream 16 17 import ( 18 "context" 19 "net" 20 "sync" 21 "testing" 22 "time" 23 24 "github.com/stretchr/testify/assert" 25 "github.com/stretchr/testify/require" 26 ) 27 28 type mockStream struct { 29 ctx context.Context 30 conn net.Conn 31 } 32 33 func newMockStream(ctx context.Context, conn net.Conn) *mockStream { 34 return &mockStream{ 35 ctx: ctx, 36 conn: conn, 37 } 38 } 39 40 func (m *mockStream) Context() context.Context { 41 return m.ctx 42 } 43 44 func (m *mockStream) Send(b []byte) error { 45 _, err := m.conn.Write(b) 46 return err 47 } 48 49 func (m *mockStream) Recv() ([]byte, error) { 50 b := make([]byte, 2*MaxChunkSize) 51 n, err := m.conn.Read(b) 52 return b[:n], err 53 } 54 55 func newStreamPipe(t *testing.T) (*ReadWriter, net.Conn) { 56 local, remote := net.Pipe() 57 stream := newMockStream(context.Background(), remote) 58 59 timeout := time.Now().Add(time.Second * 5) 60 61 require.NoError(t, local.SetReadDeadline(timeout)) 62 require.NoError(t, local.SetWriteDeadline(timeout)) 63 require.NoError(t, remote.SetReadDeadline(timeout)) 64 require.NoError(t, remote.SetWriteDeadline(timeout)) 65 66 streamConn, err := NewReadWriter(stream) 67 require.NoError(t, err) 68 69 return streamConn, local 70 } 71 72 func TestReadWriter_Write(t *testing.T) { 73 streamConn, local := newStreamPipe(t) 74 wg := &sync.WaitGroup{} 75 wg.Add(2) 76 77 data := []byte("hello world!") 78 go func() { 79 defer wg.Done() 80 n, err := streamConn.Write(data) 81 assert.NoError(t, err) 82 assert.Len(t, data, n) 83 }() 84 go func() { 85 defer wg.Done() 86 b := make([]byte, 2*MaxChunkSize) 87 n, err := local.Read(b) 88 assert.NoError(t, err) 89 assert.Len(t, data, n) 90 assert.Equal(t, data, b[:n]) 91 }() 92 93 wg.Wait() 94 } 95 96 func TestReadWriter_WriteChunk(t *testing.T) { 97 streamConn, local := newStreamPipe(t) 98 wg := &sync.WaitGroup{} 99 wg.Add(2) 100 101 data := make([]byte, MaxChunkSize+1) 102 go func() { 103 defer wg.Done() 104 n, err := streamConn.Write(data) 105 assert.NoError(t, err) 106 assert.Len(t, data, n) 107 }() 108 go func() { 109 defer wg.Done() 110 b := make([]byte, 2*MaxChunkSize) 111 n, err := local.Read(b) 112 assert.NoError(t, err) 113 assert.Equal(t, MaxChunkSize, n) 114 assert.Equal(t, data[:n], b[:n]) 115 116 n, err = local.Read(b) 117 assert.NoError(t, err) 118 assert.Equal(t, 1, n) 119 assert.Equal(t, data[:n], b[:n]) 120 }() 121 122 wg.Wait() 123 } 124 125 func TestReadWriter_Read(t *testing.T) { 126 streamConn, local := newStreamPipe(t) 127 wg := &sync.WaitGroup{} 128 wg.Add(2) 129 130 data := make([]byte, MaxChunkSize+1) 131 go func() { 132 b := make([]byte, 2*MaxChunkSize) 133 defer wg.Done() 134 n, err := streamConn.Read(b) 135 assert.NoError(t, err) 136 assert.Len(t, data, n) 137 assert.Equal(t, data, b[:n]) 138 }() 139 go func() { 140 defer wg.Done() 141 n, err := local.Write(data) 142 assert.NoError(t, err) 143 assert.Len(t, data, n) 144 }() 145 146 wg.Wait() 147 }