github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/sshutils/chconn.go (about)

     1  /*
     2  Copyright 2015-2021 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sshutils
    18  
    19  import (
    20  	"errors"
    21  	"io"
    22  	"net"
    23  	"sync"
    24  	"time"
    25  
    26  	"github.com/gravitational/trace"
    27  	"golang.org/x/crypto/ssh"
    28  
    29  	"github.com/gravitational/teleport/api/constants"
    30  )
    31  
    32  type Conn interface {
    33  	io.Closer
    34  	// RemoteAddr returns the remote address for this connection.
    35  	RemoteAddr() net.Addr
    36  	// LocalAddr returns the local address for this connection.
    37  	LocalAddr() net.Addr
    38  }
    39  
    40  // NewChConn returns a new net.Conn implemented over
    41  // SSH channel
    42  func NewChConn(conn Conn, ch ssh.Channel) *ChConn {
    43  	return newChConn(conn, ch, false)
    44  }
    45  
    46  // NewExclusiveChConn returns a new net.Conn implemented over
    47  // SSH channel, whenever this connection closes
    48  func NewExclusiveChConn(conn Conn, ch ssh.Channel) *ChConn {
    49  	return newChConn(conn, ch, true)
    50  }
    51  
    52  func newChConn(conn Conn, ch ssh.Channel, exclusive bool) *ChConn {
    53  	reader, writer := net.Pipe()
    54  	c := &ChConn{
    55  		Channel:   ch,
    56  		conn:      conn,
    57  		exclusive: exclusive,
    58  		reader:    reader,
    59  		writer:    writer,
    60  	}
    61  	// Start copying from the SSH channel to the writer part of the pipe. The
    62  	// clients are reading from the reader part of the pipe (see Read below).
    63  	//
    64  	// This goroutine stops when either the SSH channel closes or this
    65  	// connection is closed e.g. by a http.Server (see Close below).
    66  	go func() {
    67  		io.Copy(writer, ch)
    68  		// propagate EOF across the pipe to the read half.
    69  		writer.Close()
    70  	}()
    71  	return c
    72  }
    73  
    74  // ChConn is a net.Conn like object
    75  // that uses SSH channel
    76  type ChConn struct {
    77  	mu sync.Mutex
    78  
    79  	ssh.Channel
    80  	conn Conn
    81  	// exclusive indicates that whenever this channel connection
    82  	// is getting closed, the underlying connection is closed as well
    83  	exclusive bool
    84  
    85  	// reader is the part of the pipe that clients read from.
    86  	reader net.Conn
    87  	// writer is the part of the pipe that receives data from SSH channel.
    88  	writer net.Conn
    89  
    90  	// closed prevents double-close
    91  	closed bool
    92  }
    93  
    94  // Close closes channel and if the ChConn is exclusive, connection as well
    95  func (c *ChConn) Close() error {
    96  	c.mu.Lock()
    97  	defer c.mu.Unlock()
    98  	if c.closed {
    99  		return nil
   100  	}
   101  	c.closed = true
   102  	var errors []error
   103  	if err := c.Channel.Close(); err != nil {
   104  		errors = append(errors, err)
   105  	}
   106  	if err := c.reader.Close(); err != nil {
   107  		errors = append(errors, err)
   108  	}
   109  	if err := c.writer.Close(); err != nil {
   110  		errors = append(errors, err)
   111  	}
   112  	// Exclusive means close the underlying SSH connection as well.
   113  	if !c.exclusive {
   114  		return trace.NewAggregate(errors...)
   115  	}
   116  	if err := c.conn.Close(); err != nil {
   117  		errors = append(errors, err)
   118  	}
   119  	return trace.NewAggregate(errors...)
   120  }
   121  
   122  // LocalAddr returns a local address of a connection
   123  // Uses underlying net.Conn implementation
   124  func (c *ChConn) LocalAddr() net.Addr {
   125  	return c.conn.LocalAddr()
   126  }
   127  
   128  // RemoteAddr returns a remote address of a connection
   129  // Uses underlying net.Conn implementation
   130  func (c *ChConn) RemoteAddr() net.Addr {
   131  	return c.conn.RemoteAddr()
   132  }
   133  
   134  // Read reads from the channel.
   135  func (c *ChConn) Read(data []byte) (int, error) {
   136  	n, err := c.reader.Read(data)
   137  	// A lot of code relies on "use of closed network connection" error to
   138  	// gracefully handle terminated connections so convert the closed pipe
   139  	// error to it.
   140  	if err != nil && errors.Is(err, io.ErrClosedPipe) {
   141  		return n, trace.ConnectionProblem(err, constants.UseOfClosedNetworkConnection)
   142  	}
   143  	// Do not wrap the error to avoid masking the underlying error such as
   144  	// timeout error which is returned when read deadline is exceeded.
   145  	return n, err
   146  }
   147  
   148  // SetDeadline sets a connection deadline.
   149  func (c *ChConn) SetDeadline(t time.Time) error {
   150  	return c.reader.SetDeadline(t)
   151  }
   152  
   153  // SetReadDeadline sets a connection read deadline.
   154  func (c *ChConn) SetReadDeadline(t time.Time) error {
   155  	return c.reader.SetReadDeadline(t)
   156  }
   157  
   158  // SetWriteDeadline sets write deadline on a connection
   159  // ignored for the channel connection
   160  func (c *ChConn) SetWriteDeadline(t time.Time) error {
   161  	return nil
   162  }
   163  
   164  const (
   165  	// ConnectionTypeRequest is a request sent over a SSH channel that returns a
   166  	// boolean which indicates the connection type (direct or tunnel).
   167  	ConnectionTypeRequest = "x-teleport-connection-type"
   168  )