github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/pingconn/pingconn_test.go

// Copyright 2022 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pingconn

import (
	"bytes"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"math"
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/gravitational/teleport/api/fixtures"
)

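// pingConn is the interface shared by PingConn and its TLS variant: a
// net.Conn that can additionally send ping frames, which carry no
// application data and are discarded transparently by the peer's Read
// (the BufferSize subtest below relies on exactly that invariant).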
type pingConn interface {
	net.Conn
	WritePing() error
}

func TestPingConnection(t *testing.T) {
	connTypes := []struct {
		name     string
		makeFunc func(t *testing.T) (pingConn, pingConn)
	}{
		{
			name:     "PingConn",
			makeFunc: makePingConn,
		},
		{
			name:     "PingTLSConn",
			makeFunc: makePingTLSConn,
		},
	}

	for _, connType := range connTypes {
		t.Run(connType.name, func(t *testing.T) {
			t.Run("BufferSize", func(t *testing.T) {
				nWrites := 10
				dataWritten := []byte("message")

				for _, tt := range []struct {
					desc    string
					bufSize int
				}{
					{desc: "Same", bufSize: len(dataWritten)},
					{desc: "Large", bufSize: len(dataWritten) * 2},
					{desc: "Short", bufSize: len(dataWritten) / 2},
				} {
					t.Run(tt.desc, func(t *testing.T) {
						r, w := makePingConn(t)

						// Write routine.
						errChan := make(chan error, 2)
						go func() {
							defer w.Close()

							for i := 0; i < nWrites; i++ {
								// Write a ping before every other message.
								if i%2 == 0 {
									err := w.WritePing()
									if err != nil {
										errChan <- err
										return
									}
								}

								_, err := w.Write(dataWritten)
								if err != nil {
									errChan <- err
									return
								}
							}

							errChan <- nil
						}()

						// Read routine.
						go func() {
							defer r.Close()

							buf := make([]byte, tt.bufSize)

							for i := 0; i < nWrites; i++ {
								var (
									aggregator []byte
									n          int
									err        error
								)

								// Keep reading until a full message has been
								// aggregated.
								for len(aggregator) < len(dataWritten) {
									n, err = r.Read(buf)
									if err != nil {
										errChan <- err
										return
									}

									aggregator = append(aggregator, buf[:n]...)
								}

								if !bytes.Equal(dataWritten, aggregator) {
									errChan <- fmt.Errorf("wrong content read, expected '%s', got '%s'", string(dataWritten), string(aggregator))
									return
								}
							}

							errChan <- nil
						}()

						// Expect both routines to finish.
						timer := time.NewTimer(10 * time.Second)
						defer timer.Stop()
						for i := 0; i < 2; i++ {
							select {
							case err := <-errChan:
								require.NoError(t, err)
							case <-timer.C:
								require.Fail(t, "routines didn't finish in time")
							}
						}
					})
				}
			})

			// Given a connection, read from it concurrently, asserting all
			// content written is read.
			//
			// Messages can arrive out of order due to the concurrent reads;
			// ordering is covered by the other subtests.
			t.Run("ConcurrentReads", func(t *testing.T) {
				// Number of writes performed.
				nWrites := 10
				// Data that is going to be written/read on the connection.
				dataWritten := []byte("message")
				// Size of each read call.
				readSize := 2
				// Number of reads necessary to read the full message.
				readNum := int(math.Ceil(float64(len(dataWritten)) / float64(readSize)))

				r, w := makePingConn(t)
				defer r.Close()
				defer w.Close() // This call may be a noop, but it's here just in case.

				// readResult stores the result of a single read.
				type readResult struct {
					data []byte
					err  error
				}

				// Channel used to deliver the result of each read.
				resChan := make(chan readResult)

				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()

				// Write routine.
				go func() {
					for i := 0; i < nWrites; i++ {
						_, err := w.Write(dataWritten)
						if err != nil {
							return
						}
					}
				}()

				// Read routines.
				for i := 0; i < nWrites/2; i++ {
					go func() {
						buf := make([]byte, readSize)
						for {
							n, err := r.Read(buf)
							if err != nil {
								switch {
								// Since we're partially reading the message,
								// the last read will return an EOF. In this
								// case, do nothing and send the remaining
								// bytes.
								case errors.Is(err, io.EOF):
								// The connection will be closed only once the
								// test is completed. The read result will be
								// empty, so return to complete the goroutine.
								case errors.Is(err, io.ErrClosedPipe):
									return
								// Any other error should fail the test and
								// complete the goroutine.
								default:
									resChan <- readResult{err: err}
									return
								}
							}

							chanBytes := make([]byte, n)
							copy(chanBytes, buf[:n])
							resChan <- readResult{data: chanBytes}
						}
					}()
				}

				var aggregator []byte
				for i := 0; i < nWrites; i++ {
					for j := 0; j < readNum; j++ {
						select {
						case <-ctx.Done():
							require.Fail(t, "Failed to read message (context timeout)")
						case res := <-resChan:
							require.NoError(t, res.err, "Failed to read message")
							aggregator = append(aggregator, res.data...)
						}
					}
				}

				require.Len(t, aggregator, len(dataWritten)*nWrites, "Wrong messages written")

				require.NoError(t, w.Close())

				res := <-resChan
				// If there's an error here, it means it was not io.EOF or
				// io.ErrClosedPipe, as those should have been discarded by the
				// goroutines above. This likely means the errors in PingConn
				// were wrapped with trace.Wrap, which can break callers of
				// net.Conn.
				require.NoError(t, res.err, "there should be no error on close, check if the errors have been wrapped with trace.Wrap")
			})

			t.Run("ConcurrentWrites", func(t *testing.T) {
				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()

				w, r := makeBufferedPingConn(t)
				defer w.Close()
				defer r.Close()

				nWrites := 10
				dataWritten := []byte("message")
				writeChan := make(chan error)

				// Start write routines.
				for i := 0; i < nWrites/2; i++ {
					go func() {
						for writes := 0; writes < 2; writes++ {
							err := w.WritePing()
							if err != nil {
								writeChan <- err
								return
							}

							_, err = w.Write(dataWritten)
							if err != nil {
								writeChan <- err
								return
							}
						}

						writeChan <- nil
					}()
				}

				// Expect all writes to succeed.
				for i := 0; i < nWrites/2; i++ {
					select {
					case <-ctx.Done():
						require.Fail(t, "write timed out")
					case err := <-writeChan:
						require.NoError(t, err)
					}
				}

				// Read all messages.
				buf := make([]byte, len(dataWritten))
				for i := 0; i < nWrites; i++ {
					n, err := r.Read(buf)
					require.NoError(t, err)
					require.Equal(t, dataWritten, buf[:n])
				}
			})
		})
	}
}

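// TestPingDiscardedByReader is an illustrative sketch, not part of the
// upstream suite: it pins down the core invariant the tests above rely on,
// namely that ping frames carry no application data and are never surfaced
// by Read on the peer side.
func TestPingDiscardedByReader(t *testing.T) {
	w, r := makePingConn(t)
	defer w.Close()
	defer r.Close()

	payload := []byte("payload")
	go func() {
		// Send a ping before the payload; the reader should only ever
		// observe the payload bytes.
		_ = w.WritePing()
		_, _ = w.Write(payload)
	}()

	buf := make([]byte, len(payload))
	_, err := io.ReadFull(r, buf)
	require.NoError(t, err)
	require.Equal(t, payload, buf)
}
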
// makePingConn creates a piped ping connection.
func makePingConn(t *testing.T) (pingConn, pingConn) {
	t.Helper()

	writer, reader := net.Pipe()
	return New(writer), New(reader)
}

// makePingTLSConn creates a piped TLS ping connection.
func makePingTLSConn(t *testing.T) (pingConn, pingConn) {
	t.Helper()

	writer, reader := net.Pipe()
	tlsWriter, tlsReader := makeTLSConn(t, writer, reader)

	return NewTLS(tlsWriter), NewTLS(tlsReader)
}

// makeBufferedPingConn creates a pair of TCP-backed ping connections whose
// writes are buffered by the socket, so they don't block until a matching
// read (unlike net.Pipe), allowing asynchronous writes.
func makeBufferedPingConn(t *testing.T) (*PingConn, *PingConn) {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	l, err := net.Listen("tcp", "localhost:0")
	require.NoError(t, err)

	connChan := make(chan struct {
		net.Conn
		error
	}, 2)

	// Accept
	go func() {
		conn, err := l.Accept()
		connChan <- struct {
			net.Conn
			error
		}{conn, err}
	}()

	// Dial
	go func() {
		conn, err := net.Dial("tcp", l.Addr().String())
		connChan <- struct {
			net.Conn
			error
		}{conn, err}
	}()

	connSlice := make([]net.Conn, 2)
	for i := 0; i < 2; i++ {
		select {
		case <-ctx.Done():
			require.Fail(t, "failed waiting for connections")
		case res := <-connChan:
			require.NoError(t, res.error)
			connSlice[i] = res.Conn
		}
	}

	tlsConnA, tlsConnB := makeTLSConn(t, connSlice[0], connSlice[1])
	return New(tlsConnA), New(tlsConnB)
}

// makeTLSConn takes two connections (server and client) and wraps them in
// TLS, returning both ends after the handshakes complete.
func makeTLSConn(t *testing.T, server, client net.Conn) (*tls.Conn, *tls.Conn) {
	t.Helper()

	tlsConnChan := make(chan struct {
		*tls.Conn
		error
	}, 2)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM))
	require.NoError(t, err)

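	// Run both handshakes concurrently: each side blocks until it receives
	// the peer's handshake messages, so performing them sequentially would
	// deadlock.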
	// Server
	go func() {
		tlsConn := tls.Server(server, &tls.Config{
			Certificates: []tls.Certificate{cert},
		})
		tlsConnChan <- struct {
			*tls.Conn
			error
		}{tlsConn, tlsConn.HandshakeContext(ctx)}
	}()

	// Client
	go func() {
		tlsConn := tls.Client(client, &tls.Config{InsecureSkipVerify: true})
		tlsConnChan <- struct {
			*tls.Conn
			error
		}{tlsConn, tlsConn.HandshakeContext(ctx)}
	}()

	tlsConnSlice := make([]*tls.Conn, 2)
	for i := 0; i < 2; i++ {
		select {
		case <-ctx.Done():
			server.Close()
			client.Close()

			require.Fail(t, "failed waiting for TLS connections", "%d connections remaining", 2-i)
		case res := <-tlsConnChan:
			require.NoError(t, res.error)
			tlsConnSlice[i] = res.Conn
		}
	}

	return tlsConnSlice[0], tlsConnSlice[1]
}
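
// TestPingOverTLSSketch is another illustrative sketch, not part of the
// upstream suite: the ping-framing invariant should hold unchanged when the
// transport is TLS, which is what the PingTLSConn variant above exercises.
func TestPingOverTLSSketch(t *testing.T) {
	w, r := makePingTLSConn(t)
	defer w.Close()
	defer r.Close()

	payload := []byte("tls payload")
	go func() {
		// Two consecutive pings followed by a payload; Read should skip
		// both pings and surface only the payload.
		_ = w.WritePing()
		_ = w.WritePing()
		_, _ = w.Write(payload)
	}()

	buf := make([]byte, len(payload))
	_, err := io.ReadFull(r, buf)
	require.NoError(t, err)
	require.Equal(t, payload, buf)
}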