github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/pingconn/pingconn.go (about)

     1  /*
     2  Copyright 2023 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package pingconn
    18  
    19  import (
    20  	"encoding/binary"
    21  	"math"
    22  	"net"
    23  	"sync"
    24  
    25  	"github.com/gravitational/trace"
    26  )
    27  
    28  // New returns a ping connection wrapping the provided net.Conn.
    29  func New(conn net.Conn) *PingConn {
    30  	return &PingConn{Conn: conn}
    31  }
    32  
    33  // PingConn wraps a net.Conn and add ping capabilities to it, including the
    34  // `WritePing` function and `Read` (which excludes ping packets).
    35  //
    36  // When using this connection, the packets written will contain an initial data:
    37  // the packet size. When reading, this information is taken into account, but it
    38  // is not returned to the caller.
    39  //
    40  // Ping messages have a packet size of zero and are produced only when
    41  // `WritePing` is called. On `Read`, any Ping packet is discarded.
    42  type PingConn struct {
    43  	net.Conn
    44  
    45  	muRead  sync.Mutex
    46  	muWrite sync.Mutex
    47  
    48  	// currentSize size of bytes of the current packet.
    49  	currentSize uint32
    50  }
    51  
    52  // Read reads content from the underlying connection, discarding any ping
    53  // messages it finds.
    54  func (c *PingConn) Read(p []byte) (int, error) {
    55  	c.muRead.Lock()
    56  	defer c.muRead.Unlock()
    57  
    58  	err := c.discardPingReads()
    59  	if err != nil {
    60  		return 0, err
    61  	}
    62  
    63  	// Check if the current size is larger than the provided buffer.
    64  	readSize := c.currentSize
    65  	if c.currentSize > uint32(len(p)) {
    66  		readSize = uint32(len(p))
    67  	}
    68  
    69  	n, err := c.Conn.Read(p[:readSize])
    70  	c.currentSize -= uint32(n)
    71  
    72  	return n, err
    73  }
    74  
    75  // WritePing writes the ping packet to the connection.
    76  func (c *PingConn) WritePing() error {
    77  	c.muWrite.Lock()
    78  	defer c.muWrite.Unlock()
    79  
    80  	return binary.Write(c.Conn, binary.BigEndian, uint32(0))
    81  }
    82  
    83  // discardPingReads reads from the wrapped net.Conn until it encounters a
    84  // non-ping packet.
    85  func (c *PingConn) discardPingReads() error {
    86  	for c.currentSize == 0 {
    87  		err := binary.Read(c.Conn, binary.BigEndian, &c.currentSize)
    88  		if err != nil {
    89  			return err
    90  		}
    91  	}
    92  
    93  	return nil
    94  }
    95  
    96  // Write writes provided content to the underlying connection with proper
    97  // protocol fields.
    98  func (c *PingConn) Write(p []byte) (int, error) {
    99  	c.muWrite.Lock()
   100  	defer c.muWrite.Unlock()
   101  
   102  	// Avoid overflow when casting data length. It is only present to avoid
   103  	// panicking if the size cannot be cast. Callers should handle packet length
   104  	// limits, such as protocol implementations and audits.
   105  	if uint64(len(p)) > math.MaxUint32 {
   106  		return 0, trace.BadParameter("invalid content size, max size permitted is %d", uint64(math.MaxUint32))
   107  	}
   108  
   109  	size := uint32(len(p))
   110  	if size == 0 {
   111  		return 0, nil
   112  	}
   113  
   114  	// Write packet size.
   115  	if err := binary.Write(c.Conn, binary.BigEndian, size); err != nil {
   116  		return 0, trace.Wrap(err)
   117  	}
   118  
   119  	// Iterate until everything is written.
   120  	var written int
   121  	for written < len(p) {
   122  		n, err := c.Conn.Write(p)
   123  		written += n
   124  
   125  		if err != nil {
   126  			return written, trace.Wrap(err)
   127  		}
   128  	}
   129  
   130  	return written, nil
   131  }