github.com/glycerine/xcryptossh@v7.0.4+incompatible/tcpip.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"math/rand"
    13  	"net"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  )
    19  
    20  // Listen requests the remote peer open a listening socket on
    21  // addr. Incoming connections will be available by calling Accept on
    22  // the returned net.Listener. The listener must be serviced, or the
    23  // SSH connection may hang.
    24  // N must be "tcp", "tcp4", "tcp6", or "unix".
    25  func (c *Client) Listen(n, addr string) (net.Listener, error) {
    26  	ctx := c.TmpCtx
    27  	if ctx == nil {
    28  		ctx = context.Background()
    29  	}
    30  
    31  	switch n {
    32  	case "tcp", "tcp4", "tcp6":
    33  		laddr, err := net.ResolveTCPAddr(n, addr)
    34  		if err != nil {
    35  			return nil, err
    36  		}
    37  		return c.ListenTCP(ctx, laddr)
    38  	case "unix":
    39  		return c.ListenUnix(ctx, addr)
    40  	default:
    41  		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
    42  	}
    43  }
    44  
    45  // Automatic port allocation is broken with OpenSSH before 6.0. See
    46  // also https://bugzilla.mindrot.org/show_bug.cgi?id=2017.  In
    47  // particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
    48  // rather than the actual port number. This means you can never open
    49  // two different listeners with auto allocated ports. We work around
    50  // this by trying explicit ports until we succeed.
    51  
    52  const openSSHPrefix = "OpenSSH_"
    53  
    54  var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))
    55  
    56  // isBrokenOpenSSHVersion returns true if the given version string
    57  // specifies a version of OpenSSH that is known to have a bug in port
    58  // forwarding.
    59  func isBrokenOpenSSHVersion(versionStr string) bool {
    60  	i := strings.Index(versionStr, openSSHPrefix)
    61  	if i < 0 {
    62  		return false
    63  	}
    64  	i += len(openSSHPrefix)
    65  	j := i
    66  	for ; j < len(versionStr); j++ {
    67  		if versionStr[j] < '0' || versionStr[j] > '9' {
    68  			break
    69  		}
    70  	}
    71  	version, _ := strconv.Atoi(versionStr[i:j])
    72  	return version < 6
    73  }
    74  
    75  // autoPortListenWorkaround simulates automatic port allocation by
    76  // trying random ports repeatedly.
    77  func (c *Client) autoPortListenWorkaround(ctx context.Context, laddr *net.TCPAddr) (*tcpListener, error) {
    78  	var sshListener *tcpListener
    79  	var err error
    80  	const tries = 10
    81  	for i := 0; i < tries; i++ {
    82  		addr := *laddr
    83  		addr.Port = 1024 + portRandomizer.Intn(60000)
    84  		sshListener, err = c.ListenTCP(ctx, &addr)
    85  		if err == nil {
    86  			laddr.Port = addr.Port
    87  			return sshListener, err
    88  		}
    89  	}
    90  	return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err)
    91  }
    92  
// RFC 4254 7.1
//
// channelForwardMsg is the type-specific payload of the "tcpip-forward"
// and "cancel-tcpip-forward" global requests: addr is the address and
// rport the port the server should (stop) listening on. ListenTCP and
// tcpListener.Close build it with positional literals, so the field order
// must not change.
// NOTE(review): presumably Marshal also encodes the fields in declaration
// order to match the RFC wire format — confirm before editing.
type channelForwardMsg struct {
	addr  string
	rport uint32
}
    98  
    99  // ListenTCP requests the remote peer open a listening socket
   100  // on laddr. Incoming connections will be available by calling
   101  // Accept on the returned net.Listener.
   102  func (c *Client) ListenTCP(ctx context.Context, laddr *net.TCPAddr) (*tcpListener, error) {
   103  	if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
   104  		return c.autoPortListenWorkaround(ctx, laddr)
   105  	}
   106  
   107  	m := channelForwardMsg{
   108  		laddr.IP.String(),
   109  		uint32(laddr.Port),
   110  	}
   111  	// send message
   112  	ok, resp, err := c.SendRequest(ctx, "tcpip-forward", true, Marshal(&m))
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  	if !ok {
   117  		return nil, errors.New("ssh: tcpip-forward request denied by peer")
   118  	}
   119  
   120  	// If the original port was 0, then the remote side will
   121  	// supply a real port number in the response.
   122  	if laddr.Port == 0 {
   123  		var p struct {
   124  			Port uint32
   125  		}
   126  		if err := Unmarshal(resp, &p); err != nil {
   127  			return nil, err
   128  		}
   129  		laddr.Port = int(p.Port)
   130  	}
   131  
   132  	// Register this forward, using the port number we obtained.
   133  	ch := c.Forwards.add(laddr)
   134  
   135  	return &tcpListener{
   136  		laddr:  laddr,
   137  		conn:   c,
   138  		in:     ch,
   139  		TmpCtx: c.TmpCtx}, nil
   140  }
   141  
// ForwardList stores a mapping between remote forward requests and the
// listeners that serve them. Entries are keyed by listen address and may
// hold TCP addresses ("forwarded-tcpip") or Unix-socket addresses
// ("forwarded-streamlocal@openssh.com").
//
// The embedded Mutex guards entries.
type ForwardList struct {
	sync.Mutex
	entries []forwardEntry
}
   148  
// forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a listener.
//
// laddr is the listen address exactly as passed to add; lookups compare
// its Network() and String() forms. c carries incoming connections for
// laddr. add creates it with a buffer of one, and Remove or CloseAll close
// it.
type forwardEntry struct {
	laddr net.Addr
	c     chan forward
}
   155  
// forward represents an incoming forwarded connection (TCP or Unix
// socket) waiting to be accepted by a listener. newCh is the pending
// channel-open request, and raddr is the originating address to report
// for it.
//
// ForwardList lookups (Forward, Remove) match on the listen address as it
// was originally passed to add, comparing Network() and String().
type forward struct {
	newCh NewChannel // the ssh client channel underlying this forward
	raddr net.Addr   // the raddr of the incoming connection
}
   163  
   164  func (l *ForwardList) add(addr net.Addr) chan forward {
   165  	l.Lock()
   166  	defer l.Unlock()
   167  	f := forwardEntry{
   168  		laddr: addr,
   169  		c:     make(chan forward, 1),
   170  	}
   171  	l.entries = append(l.entries, f)
   172  	return f.c
   173  }
   174  
// See RFC 4254, section 7.2
//
// forwardedTCPPayload is the extra data of a "forwarded-tcpip" channel-open
// request. Addr/Port are the address and port that were connected to on
// the server (used as the listen address when dispatching). OriginAddr and
// OriginPort describe the originator (used as the remote address).
// NOTE(review): presumably Unmarshal decodes fields positionally, so keep
// this order in sync with the RFC.
type forwardedTCPPayload struct {
	Addr       string
	Port       uint32
	OriginAddr string
	OriginPort uint32
}
   182  
   183  // parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
   184  func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
   185  	if port == 0 || port > 65535 {
   186  		return nil, fmt.Errorf("ssh: port number out of range: %d", port)
   187  	}
   188  	ip := net.ParseIP(string(addr))
   189  	if ip == nil {
   190  		return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
   191  	}
   192  	return &net.TCPAddr{IP: ip, Port: int(port)}, nil
   193  }
   194  
   195  func (l *ForwardList) HandleChannels(ctx context.Context, in <-chan NewChannel, conn Conn) {
   196  	var ch NewChannel
   197  	for {
   198  		select {
   199  		case <-conn.Done():
   200  			return
   201  		case <-ctx.Done():
   202  			return
   203  		case ch = <-in:
   204  			var (
   205  				laddr net.Addr
   206  				raddr net.Addr
   207  				err   error
   208  			)
   209  			switch channelType := ch.ChannelType(); channelType {
   210  			case "forwarded-tcpip":
   211  				var payload forwardedTCPPayload
   212  				if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
   213  					ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
   214  					continue
   215  				}
   216  
   217  				// RFC 4254 section 7.2 specifies that incoming
   218  				// addresses should list the address, in string
   219  				// format. It is implied that this should be an IP
   220  				// address, as it would be impossible to connect to it
   221  				// otherwise.
   222  				laddr, err = parseTCPAddr(payload.Addr, payload.Port)
   223  				if err != nil {
   224  					ch.Reject(ConnectionFailed, err.Error())
   225  					continue
   226  				}
   227  				raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
   228  				if err != nil {
   229  					ch.Reject(ConnectionFailed, err.Error())
   230  					continue
   231  				}
   232  
   233  			case "forwarded-streamlocal@openssh.com":
   234  				var payload forwardedStreamLocalPayload
   235  				if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
   236  					ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
   237  					continue
   238  				}
   239  				laddr = &net.UnixAddr{
   240  					Name: payload.SocketPath,
   241  					Net:  "unix",
   242  				}
   243  				raddr = &net.UnixAddr{
   244  					Name: "@",
   245  					Net:  "unix",
   246  				}
   247  			default:
   248  				panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
   249  			}
   250  			ok, err := l.Forward(ctx, laddr, raddr, ch, conn)
   251  			if err != nil {
   252  				return
   253  			}
   254  			if !ok {
   255  				// Section 7.2, implementations MUST reject spurious incoming
   256  				// connections.
   257  				ch.Reject(Prohibited, "no forward for address")
   258  				continue
   259  			}
   260  		}
   261  	}
   262  }
   263  
   264  // remove removes the forward entry, and the channel feeding its
   265  // listener.
   266  func (l *ForwardList) Remove(addr net.Addr) {
   267  	l.Lock()
   268  	defer l.Unlock()
   269  	for i, f := range l.entries {
   270  		if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
   271  			l.entries = append(l.entries[:i], l.entries[i+1:]...)
   272  			close(f.c)
   273  			return
   274  		}
   275  	}
   276  }
   277  
   278  // closeAll closes and clears all forwards.
   279  func (l *ForwardList) CloseAll() {
   280  	l.Lock()
   281  	defer l.Unlock()
   282  	for _, f := range l.entries {
   283  		close(f.c)
   284  	}
   285  	l.entries = nil
   286  }
   287  
// Forward hands the channel-open request ch, whose destination is laddr
// and origin is raddr, to the listener registered for laddr.
//
// It returns (true, nil) once the request has been queued on that
// listener's channel, and (false, nil) if no forward is registered for
// laddr (compared by Network() and String()). It returns (false, io.EOF)
// if conn reports Done or ctx is cancelled while waiting to queue.
//
// The lock is deliberately held across the send: Remove and CloseAll close
// the entry's channel under the same lock, so this send can never target a
// closed channel.
// NOTE(review): the send blocks once the entry's one-slot buffer (see add)
// is full and nobody is accepting. While blocked, Remove and CloseAll (and
// hence tcpListener.Close) cannot take the lock until ctx or conn is done.
func (l *ForwardList) Forward(ctx context.Context, laddr, raddr net.Addr, ch NewChannel, conn Conn) (bool, error) {
	l.Lock()
	defer l.Unlock()
	for _, f := range l.entries {
		if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
			select {
			case f.c <- forward{newCh: ch, raddr: raddr}:
				return true, nil
			case <-conn.Done():
				return false, io.EOF
			case <-ctx.Done():
				return false, io.EOF
			}
		}
	}
	return false, nil
}
   305  
// tcpListener is the listener ListenTCP returns for a remote
// "tcpip-forward". It implements Accept, Close and Addr.
//
// laddr is the address the server listens on; ListenTCP fills in the real
// port when 0 was requested. conn is the owning client. in is the channel
// registered for laddr in conn.Forwards: ForwardList.Forward feeds it, and
// Remove or CloseAll close it.
type tcpListener struct {
	laddr *net.TCPAddr

	conn *Client
	in   <-chan forward

	// must be set for Accept() and Close() call.
	TmpCtx context.Context
}
   315  
   316  // Accept waits for and returns the next connection to the listener.
   317  func (l *tcpListener) Accept() (net.Conn, error) {
   318  	var ok bool
   319  	var s forward
   320  	select {
   321  	case <-l.conn.Done():
   322  		return nil, io.EOF
   323  	case <-l.TmpCtx.Done():
   324  		return nil, io.EOF
   325  	case s, ok = <-l.in:
   326  		if !ok {
   327  			return nil, io.EOF
   328  		}
   329  	}
   330  	ch, incoming, err := s.newCh.Accept()
   331  	if err != nil {
   332  		return nil, err
   333  	}
   334  	go DiscardRequests(l.TmpCtx, incoming, l.conn.Halt)
   335  
   336  	return &chanConn{
   337  		Channel: ch,
   338  		laddr:   l.laddr,
   339  		raddr:   s.raddr,
   340  	}, nil
   341  }
   342  
   343  // Close closes the listener.
   344  func (l *tcpListener) Close() error {
   345  	m := channelForwardMsg{
   346  		l.laddr.IP.String(),
   347  		uint32(l.laddr.Port),
   348  	}
   349  
   350  	// this also closes the listener.
   351  	l.conn.Forwards.Remove(l.laddr)
   352  	ok, _, err := l.conn.SendRequest(l.TmpCtx, "cancel-tcpip-forward", true, Marshal(&m))
   353  	if err == nil && !ok {
   354  		err = errors.New("ssh: cancel-tcpip-forward failed")
   355  	}
   356  	return err
   357  }
   358  
   359  // Addr returns the listener's network address.
   360  func (l *tcpListener) Addr() net.Addr {
   361  	return l.laddr
   362  }
   363  
   364  // Dial initiates a connection to the addr from the remote host.
   365  // The n argument is the network: "tcp", "tcp4", "tcp6", "unix".
   366  // The resulting connection has a zero LocalAddr() and RemoteAddr().
   367  func (c *Client) Dial(n, addr string) (Channel, error) {
   368  	ctx := c.TmpCtx
   369  	if ctx == nil {
   370  		ctx = context.Background()
   371  	}
   372  	return c.DialWithContext(ctx, n, addr)
   373  }
   374  
   375  // DialWithContext is the same as Dial, but with ctx.
   376  // The resulting connection has a zero LocalAddr() and RemoteAddr().
   377  func (c *Client) DialWithContext(ctx context.Context, n, addr string) (Channel, error) {
   378  
   379  	var ch Channel
   380  	switch n {
   381  	case "tcp", "tcp4", "tcp6":
   382  		// Parse the address into host and numeric port.
   383  		host, portString, err := net.SplitHostPort(addr)
   384  		if err != nil {
   385  			return nil, err
   386  		}
   387  		port, err := strconv.ParseUint(portString, 10, 16)
   388  		if err != nil {
   389  			return nil, err
   390  		}
   391  		ch, err = c.dial(ctx, net.IPv4zero.String(), 0, host, int(port))
   392  		if err != nil {
   393  			return nil, err
   394  		}
   395  		// Use a zero address for local and remote address.
   396  		zeroAddr := &net.TCPAddr{
   397  			IP:   net.IPv4zero,
   398  			Port: 0,
   399  		}
   400  		return &chanConn{
   401  			Channel: ch,
   402  			laddr:   zeroAddr,
   403  			raddr:   zeroAddr,
   404  		}, nil
   405  	case "unix":
   406  		var err error
   407  		ch, err = c.dialStreamLocal(ctx, addr)
   408  		if err != nil {
   409  			return nil, err
   410  		}
   411  		return &chanConn{
   412  			Channel: ch,
   413  			laddr: &net.UnixAddr{
   414  				Name: "@",
   415  				Net:  "unix",
   416  			},
   417  			raddr: &net.UnixAddr{
   418  				Name: addr,
   419  				Net:  "unix",
   420  			},
   421  		}, nil
   422  	default:
   423  		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
   424  	}
   425  }
   426  
   427  // DialTCP connects to the remote address raddr on the network net,
   428  // which must be "tcp", "tcp4", or "tcp6".  If laddr is not nil, it is used
   429  // as the local address for the connection.
   430  func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
   431  	ctx := c.TmpCtx
   432  	if ctx == nil {
   433  		ctx = context.Background()
   434  	}
   435  
   436  	if laddr == nil {
   437  		laddr = &net.TCPAddr{
   438  			IP:   net.IPv4zero,
   439  			Port: 0,
   440  		}
   441  	}
   442  	ch, err := c.dial(ctx, laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
   443  	if err != nil {
   444  		return nil, err
   445  	}
   446  	return &chanConn{
   447  		Channel: ch,
   448  		laddr:   laddr,
   449  		raddr:   raddr,
   450  	}, nil
   451  }
   452  
// RFC 4254 7.2
//
// channelOpenDirectMsg is the extra data of a "direct-tcpip" channel-open
// request. raddr/rport are the host and port to connect to; laddr/lport
// are the originator's address and port.
// NOTE(review): presumably Marshal encodes fields in declaration order, so
// this order must match the RFC — confirm before reordering.
type channelOpenDirectMsg struct {
	raddr string
	rport uint32
	laddr string
	lport uint32
}
   460  
   461  func (c *Client) dial(ctx context.Context, laddr string, lport int, raddr string, rport int) (Channel, error) {
   462  	msg := channelOpenDirectMsg{
   463  		raddr: raddr,
   464  		rport: uint32(rport),
   465  		laddr: laddr,
   466  		lport: uint32(lport),
   467  	}
   468  	ch, in, err := c.OpenChannel(ctx, "direct-tcpip", Marshal(&msg), nil)
   469  	if err != nil {
   470  		return nil, err
   471  	}
   472  	go DiscardRequests(ctx, in, c.Halt)
   473  	return ch, err
   474  }
   475  
// tcpChan wraps a Channel.
// NOTE(review): tcpChan is not referenced in this part of the file;
// chanConn is what Dial, DialTCP and Accept return. Confirm whether it is
// still used elsewhere before removing it.
type tcpChan struct {
	Channel // the backing channel
}
   479  
// chanConn fulfills the net.Conn interface for a Channel by pairing it
// with the local and remote addresses to report, so the Channel itself
// does not have to hold laddr or raddr. Its deadline methods always return
// an error.
type chanConn struct {
	Channel
	laddr, raddr net.Addr
}
   486  
   487  // LocalAddr returns the local network address.
   488  func (t *chanConn) LocalAddr() net.Addr {
   489  	return t.laddr
   490  }
   491  
   492  // RemoteAddr returns the remote network address.
   493  func (t *chanConn) RemoteAddr() net.Addr {
   494  	return t.raddr
   495  }
   496  
   497  // SetDeadline sets the read and write deadlines associated
   498  // with the connection.
   499  func (t *chanConn) SetDeadline(deadline time.Time) error {
   500  	if err := t.SetReadDeadline(deadline); err != nil {
   501  		return err
   502  	}
   503  	return t.SetWriteDeadline(deadline)
   504  }
   505  
   506  // SetReadDeadline sets the read deadline.
   507  // A zero value for t means Read will not time out.
   508  // After the deadline, the error from Read will implement net.Error
   509  // with Timeout() == true.
   510  func (t *chanConn) SetReadDeadline(deadline time.Time) error {
   511  	// for compatibility with previous version,
   512  	// the error message contains "tcpChan"
   513  	return errors.New("ssh: tcpChan: deadline not supported")
   514  }
   515  
   516  // SetWriteDeadline exists to satisfy the net.Conn interface
   517  // but is not implemented by this type.  It always returns an error.
   518  func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
   519  	return errors.New("ssh: tcpChan: deadline not supported")
   520  }