github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/pingconn/pingconn_test.go (about) 1 // Copyright 2022 Gravitational, Inc 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package pingconn 16 17 import ( 18 "bytes" 19 "context" 20 "crypto/tls" 21 "errors" 22 "fmt" 23 "io" 24 "math" 25 "net" 26 "testing" 27 "time" 28 29 "github.com/stretchr/testify/require" 30 31 "github.com/gravitational/teleport/api/fixtures" 32 ) 33 34 type pingConn interface { 35 net.Conn 36 WritePing() error 37 } 38 39 func TestPingConnection(t *testing.T) { 40 connTypes := []struct { 41 name string 42 makeFunc func(t *testing.T) (pingConn, pingConn) 43 }{ 44 { 45 name: "PingConn", 46 makeFunc: makePingConn, 47 }, 48 { 49 name: "PingTLSConn", 50 makeFunc: makePingTLSConn, 51 }, 52 } 53 54 for _, connType := range connTypes { 55 t.Run(connType.name, func(t *testing.T) { 56 t.Run("BufferSize", func(t *testing.T) { 57 nWrites := 10 58 dataWritten := []byte("message") 59 60 for _, tt := range []struct { 61 desc string 62 bufSize int 63 }{ 64 {desc: "Same", bufSize: len(dataWritten)}, 65 {desc: "Large", bufSize: len(dataWritten) * 2}, 66 {desc: "Short", bufSize: len(dataWritten) / 2}, 67 } { 68 t.Run(tt.desc, func(t *testing.T) { 69 r, w := makePingConn(t) 70 71 // Write routine 72 errChan := make(chan error, 2) 73 go func() { 74 defer w.Close() 75 76 for i := 0; i < nWrites; i++ { 77 // Eventually write some pings. 78 if i%2 == 0 { 79 err := w.WritePing() 80 if err != nil { 81 errChan <- err 82 return 83 } 84 } 85 86 _, err := w.Write(dataWritten) 87 if err != nil { 88 errChan <- err 89 return 90 } 91 } 92 93 errChan <- nil 94 }() 95 96 // Read routine. 97 go func() { 98 defer r.Close() 99 100 buf := make([]byte, tt.bufSize) 101 102 for i := 0; i < nWrites; i++ { 103 var ( 104 aggregator []byte 105 n int 106 err error 107 ) 108 109 for n < len(dataWritten) { 110 n, err = r.Read(buf) 111 if err != nil { 112 errChan <- err 113 return 114 } 115 116 aggregator = append(aggregator, buf[:n]...) 117 } 118 119 if !bytes.Equal(dataWritten, aggregator) { 120 errChan <- fmt.Errorf("wrong content read, expected '%s', got '%s'", string(dataWritten), string(buf[:n])) 121 return 122 } 123 } 124 125 errChan <- nil 126 }() 127 128 // Expect routines to finish. 129 timer := time.NewTimer(10 * time.Second) 130 defer timer.Stop() 131 for i := 0; i < 1; i++ { 132 select { 133 case err := <-errChan: 134 require.NoError(t, err) 135 case <-timer.C: 136 require.Fail(t, "routing didn't finished in time") 137 } 138 } 139 }) 140 } 141 }) 142 143 // Given a connection, read from it concurrently, asserting all content 144 // written is read. 145 // 146 // Messages can be out of order due to concurrent reads. Other tests must 147 // guarantee message ordering. 148 t.Run("ConcurrentReads", func(t *testing.T) { 149 // Number of writes performed. 150 nWrites := 10 151 // Data that is going to be written/read on the connection. 152 dataWritten := []byte("message") 153 // Size of each read call. 154 readSize := 2 155 // Number of reads necessary to read the full message 156 readNum := int(math.Ceil(float64(len(dataWritten)) / float64(readSize))) 157 158 r, w := makePingConn(t) 159 defer r.Close() 160 defer w.Close() // This call may be a noop, but it's here just in case. 161 162 // readResult struct is used to store the result of a read. 163 type readResult struct { 164 data []byte 165 err error 166 } 167 168 // Channel is used to store the result of a read. 169 resChan := make(chan readResult) 170 171 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 172 defer cancel() 173 174 // Write routine 175 go func() { 176 for i := 0; i < nWrites; i++ { 177 _, err := w.Write(dataWritten) 178 if err != nil { 179 return 180 } 181 } 182 }() 183 184 // Read routines. 185 for i := 0; i < nWrites/2; i++ { 186 go func() { 187 buf := make([]byte, readSize) 188 for { 189 n, err := r.Read(buf) 190 if err != nil { 191 switch { 192 // Since we're partially reading the message, the last 193 // read will return an EOF. In this case, do nothing 194 // and send the remaining bytes. 195 case errors.Is(err, io.EOF): 196 // The connection will be closed only if the test is 197 // completed. The read result will be empty, so return 198 // the function to complete the goroutine. 199 case errors.Is(err, io.ErrClosedPipe): 200 return 201 // Any other error should fail the test and complete the 202 // goroutine. 203 default: 204 resChan <- readResult{err: err} 205 return 206 } 207 } 208 209 chanBytes := make([]byte, n) 210 copy(chanBytes, buf[:n]) 211 resChan <- readResult{data: chanBytes} 212 } 213 }() 214 } 215 216 var aggregator []byte 217 for i := 0; i < nWrites; i++ { 218 for j := 0; j < readNum; j++ { 219 select { 220 case <-ctx.Done(): 221 require.Fail(t, "Failed to read message (context timeout)") 222 case res := <-resChan: 223 require.NoError(t, res.err, "Failed to read message") 224 aggregator = append(aggregator, res.data...) 225 } 226 } 227 } 228 229 require.Len(t, aggregator, len(dataWritten)*nWrites, "Wrong messages written") 230 231 require.NoError(t, w.Close()) 232 233 res := <-resChan 234 // If there's an error here, it means the error was not io.EOF or io.ErrPipeClosed, as those should have been discarded 235 // by the goroutine above. This likely means that the errors in PingConn were wrapped with trace.Wrap, which can break 236 // callers of net.Conn. 237 require.NoError(t, res.err, "there should be no error on close, check if the errors have been wrapped with trace.Wrap") 238 }) 239 240 t.Run("ConcurrentWrites", func(t *testing.T) { 241 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 242 defer cancel() 243 244 w, r := makeBufferedPingConn(t) 245 defer w.Close() 246 defer r.Close() 247 248 nWrites := 10 249 dataWritten := []byte("message") 250 writeChan := make(chan error) 251 252 // Start write routines. 253 for i := 0; i < nWrites/2; i++ { 254 go func() { 255 for writes := 0; writes < 2; writes++ { 256 err := w.WritePing() 257 if err != nil { 258 writeChan <- err 259 return 260 } 261 262 _, err = w.Write(dataWritten) 263 if err != nil { 264 writeChan <- err 265 return 266 } 267 } 268 269 writeChan <- nil 270 }() 271 } 272 273 // Expect all writes to succeed. 274 for i := 0; i < nWrites/2; i++ { 275 select { 276 case <-ctx.Done(): 277 require.Fail(t, "timeout write") 278 case err := <-writeChan: 279 require.NoError(t, err) 280 } 281 } 282 283 // Read all messages. 284 buf := make([]byte, len(dataWritten)) 285 for i := 0; i < nWrites; i++ { 286 n, err := r.Read(buf) 287 require.NoError(t, err) 288 require.Equal(t, dataWritten, buf[:n]) 289 } 290 }) 291 }) 292 } 293 } 294 295 // makePingConn creates a piped ping connection. 296 func makePingConn(t *testing.T) (pingConn, pingConn) { 297 t.Helper() 298 299 writer, reader := net.Pipe() 300 return New(writer), New(reader) 301 } 302 303 // makePingTLSConn creates a piped TLS ping connection. 304 func makePingTLSConn(t *testing.T) (pingConn, pingConn) { 305 t.Helper() 306 307 writer, reader := net.Pipe() 308 tlsWriter, tlsReader := makeTLSConn(t, writer, reader) 309 310 return NewTLS(tlsWriter), NewTLS(tlsReader) 311 } 312 313 // makeBufferedPingConn creates connections to have asynchronous writes. 314 func makeBufferedPingConn(t *testing.T) (*PingConn, *PingConn) { 315 t.Helper() 316 317 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 318 defer cancel() 319 320 l, err := net.Listen("tcp", "localhost:0") 321 require.NoError(t, err) 322 323 connChan := make(chan struct { 324 net.Conn 325 error 326 }, 2) 327 328 // Accept 329 go func() { 330 conn, err := l.Accept() 331 connChan <- struct { 332 net.Conn 333 error 334 }{conn, err} 335 }() 336 337 // Dial 338 go func() { 339 conn, err := net.Dial("tcp", l.Addr().String()) 340 connChan <- struct { 341 net.Conn 342 error 343 }{conn, err} 344 }() 345 346 connSlice := make([]net.Conn, 2) 347 for i := 0; i < 2; i++ { 348 select { 349 case <-ctx.Done(): 350 require.Fail(t, "failed waiting for connections") 351 case res := <-connChan: 352 require.NoError(t, res.error) 353 connSlice[i] = res.Conn 354 } 355 } 356 357 tlsConnA, tlsConnB := makeTLSConn(t, connSlice[0], connSlice[1]) 358 return New(tlsConnA), New(tlsConnB) 359 } 360 361 // makeTLSConn take two connections (client and server) and wrap them into TLS 362 // connections. 363 func makeTLSConn(t *testing.T, server, client net.Conn) (*tls.Conn, *tls.Conn) { 364 tlsConnChan := make(chan struct { 365 *tls.Conn 366 error 367 }, 2) 368 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 369 defer cancel() 370 371 cert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM)) 372 require.NoError(t, err) 373 374 // Server 375 go func() { 376 tlsConn := tls.Server(server, &tls.Config{ 377 Certificates: []tls.Certificate{cert}, 378 }) 379 tlsConnChan <- struct { 380 *tls.Conn 381 error 382 }{tlsConn, tlsConn.HandshakeContext(ctx)} 383 }() 384 385 // Client 386 go func() { 387 tlsConn := tls.Client(client, &tls.Config{InsecureSkipVerify: true}) 388 tlsConnChan <- struct { 389 *tls.Conn 390 error 391 }{tlsConn, tlsConn.HandshakeContext(ctx)} 392 }() 393 394 tlsConnSlice := make([]*tls.Conn, 2) 395 for i := 0; i < 2; i++ { 396 select { 397 case <-ctx.Done(): 398 server.Close() 399 client.Close() 400 401 require.Fail(t, "failed waiting for TLS connections", "%d connections remaining", 2-i) 402 case res := <-tlsConnChan: 403 require.NoError(t, res.error) 404 tlsConnSlice[i] = res.Conn 405 } 406 } 407 408 return tlsConnSlice[0], tlsConnSlice[1] 409 }