github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/pingconn/pingconn.go (about) 1 /* 2 Copyright 2023 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package pingconn 18 19 import ( 20 "encoding/binary" 21 "math" 22 "net" 23 "sync" 24 25 "github.com/gravitational/trace" 26 ) 27 28 // New returns a ping connection wrapping the provided net.Conn. 29 func New(conn net.Conn) *PingConn { 30 return &PingConn{Conn: conn} 31 } 32 33 // PingConn wraps a net.Conn and add ping capabilities to it, including the 34 // `WritePing` function and `Read` (which excludes ping packets). 35 // 36 // When using this connection, the packets written will contain an initial data: 37 // the packet size. When reading, this information is taken into account, but it 38 // is not returned to the caller. 39 // 40 // Ping messages have a packet size of zero and are produced only when 41 // `WritePing` is called. On `Read`, any Ping packet is discarded. 42 type PingConn struct { 43 net.Conn 44 45 muRead sync.Mutex 46 muWrite sync.Mutex 47 48 // currentSize size of bytes of the current packet. 49 currentSize uint32 50 } 51 52 // Read reads content from the underlying connection, discarding any ping 53 // messages it finds. 54 func (c *PingConn) Read(p []byte) (int, error) { 55 c.muRead.Lock() 56 defer c.muRead.Unlock() 57 58 err := c.discardPingReads() 59 if err != nil { 60 return 0, err 61 } 62 63 // Check if the current size is larger than the provided buffer. 64 readSize := c.currentSize 65 if c.currentSize > uint32(len(p)) { 66 readSize = uint32(len(p)) 67 } 68 69 n, err := c.Conn.Read(p[:readSize]) 70 c.currentSize -= uint32(n) 71 72 return n, err 73 } 74 75 // WritePing writes the ping packet to the connection. 76 func (c *PingConn) WritePing() error { 77 c.muWrite.Lock() 78 defer c.muWrite.Unlock() 79 80 return binary.Write(c.Conn, binary.BigEndian, uint32(0)) 81 } 82 83 // discardPingReads reads from the wrapped net.Conn until it encounters a 84 // non-ping packet. 85 func (c *PingConn) discardPingReads() error { 86 for c.currentSize == 0 { 87 err := binary.Read(c.Conn, binary.BigEndian, &c.currentSize) 88 if err != nil { 89 return err 90 } 91 } 92 93 return nil 94 } 95 96 // Write writes provided content to the underlying connection with proper 97 // protocol fields. 98 func (c *PingConn) Write(p []byte) (int, error) { 99 c.muWrite.Lock() 100 defer c.muWrite.Unlock() 101 102 // Avoid overflow when casting data length. It is only present to avoid 103 // panicking if the size cannot be cast. Callers should handle packet length 104 // limits, such as protocol implementations and audits. 105 if uint64(len(p)) > math.MaxUint32 { 106 return 0, trace.BadParameter("invalid content size, max size permitted is %d", uint64(math.MaxUint32)) 107 } 108 109 size := uint32(len(p)) 110 if size == 0 { 111 return 0, nil 112 } 113 114 // Write packet size. 115 if err := binary.Write(c.Conn, binary.BigEndian, size); err != nil { 116 return 0, trace.Wrap(err) 117 } 118 119 // Iterate until everything is written. 120 var written int 121 for written < len(p) { 122 n, err := c.Conn.Write(p) 123 written += n 124 125 if err != nil { 126 return written, trace.Wrap(err) 127 } 128 } 129 130 return written, nil 131 }