github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/tcpip.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"math/rand"
    12  	"net"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  )
    18  
    19  // Listen requests the remote peer open a listening socket on
    20  // addr. Incoming connections will be available by calling Accept on
    21  // the returned net.Listener. The listener must be serviced, or the
    22  // SSH connection may hang.
    23  // N must be "tcp", "tcp4", "tcp6", or "unix".
    24  func (c *Client) Listen(n, addr string) (net.Listener, error) {
    25  	switch n {
    26  	case "tcp", "tcp4", "tcp6":
    27  		laddr, err := net.ResolveTCPAddr(n, addr)
    28  		if err != nil {
    29  			return nil, err
    30  		}
    31  		return c.ListenTCP(laddr)
    32  	case "unix":
    33  		return c.ListenUnix(addr)
    34  	default:
    35  		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
    36  	}
    37  }
    38  
    39  // Automatic port allocation is broken with OpenSSH before 6.0. See
    40  // also https://bugzilla.mindrot.org/show_bug.cgi?id=2017.  In
    41  // particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
    42  // rather than the actual port number. This means you can never open
    43  // two different listeners with auto allocated ports. We work around
    44  // this by trying explicit ports until we succeed.
    45  
    46  const openSSHPrefix = "OpenSSH_"
    47  
    48  var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))
    49  
    50  // isBrokenOpenSSHVersion returns true if the given version string
    51  // specifies a version of OpenSSH that is known to have a bug in port
    52  // forwarding.
    53  func isBrokenOpenSSHVersion(versionStr string) bool {
    54  	i := strings.Index(versionStr, openSSHPrefix)
    55  	if i < 0 {
    56  		return false
    57  	}
    58  	i += len(openSSHPrefix)
    59  	j := i
    60  	for ; j < len(versionStr); j++ {
    61  		if versionStr[j] < '0' || versionStr[j] > '9' {
    62  			break
    63  		}
    64  	}
    65  	version, _ := strconv.Atoi(versionStr[i:j])
    66  	return version < 6
    67  }
    68  
    69  // autoPortListenWorkaround simulates automatic port allocation by
    70  // trying random ports repeatedly.
    71  func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
    72  	var sshListener net.Listener
    73  	var err error
    74  	const tries = 10
    75  	for i := 0; i < tries; i++ {
    76  		addr := *laddr
    77  		addr.Port = 1024 + portRandomizer.Intn(60000)
    78  		sshListener, err = c.ListenTCP(&addr)
    79  		if err == nil {
    80  			laddr.Port = addr.Port
    81  			return sshListener, err
    82  		}
    83  	}
    84  	return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err)
    85  }
    86  
// channelForwardMsg is the payload of the "tcpip-forward" and
// "cancel-tcpip-forward" global requests (RFC 4254 section 7.1). It holds
// the address to bind on the peer and the port to bind.
// Field order follows the RFC wire layout, so do not reorder the fields.
// This assumes Marshal encodes fields in declaration order.
type channelForwardMsg struct {
	addr  string
	rport uint32
}
    92  
    93  // handleForwards starts goroutines handling forwarded connections.
    94  // It's called on first use by (*Client).ListenTCP to not launch
    95  // goroutines until needed.
    96  func (c *Client) handleForwards() {
    97  	go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
    98  	go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
    99  }
   100  
   101  // ListenTCP requests the remote peer open a listening socket
   102  // on laddr. Incoming connections will be available by calling
   103  // Accept on the returned net.Listener.
   104  func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
   105  	c.handleForwardsOnce.Do(c.handleForwards)
   106  	if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
   107  		return c.autoPortListenWorkaround(laddr)
   108  	}
   109  
   110  	m := channelForwardMsg{
   111  		laddr.IP.String(),
   112  		uint32(laddr.Port),
   113  	}
   114  	// send message
   115  	ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  	if !ok {
   120  		return nil, errors.New("ssh: tcpip-forward request denied by peer")
   121  	}
   122  
   123  	// If the original port was 0, then the remote side will
   124  	// supply a real port number in the response.
   125  	if laddr.Port == 0 {
   126  		var p struct {
   127  			Port uint32
   128  		}
   129  		if err := Unmarshal(resp, &p); err != nil {
   130  			return nil, err
   131  		}
   132  		laddr.Port = int(p.Port)
   133  	}
   134  
   135  	// Register this forward, using the port number we obtained.
   136  	ch := c.forwards.add(laddr)
   137  
   138  	return &tcpListener{laddr, c, ch}, nil
   139  }
   140  
// forwardList stores a mapping between remote forward requests and the
// listeners serving them. Each entry pairs a forwarded local address with
// the channel that feeds that address's listener.
// The embedded mutex guards entries.
type forwardList struct {
	sync.Mutex
	entries []forwardEntry
}
   147  
// forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a listener (for example a
// tcpListener).
type forwardEntry struct {
	laddr net.Addr     // address as given in the original forward request; the lookup key
	c     chan forward // delivers incoming connections for laddr; closed on removal
}
   154  
// forward represents an incoming forwarded connection. handleChannels
// passes it to the listener whose registered address matches.
// Note: forwardList's add, remove and forward methods match entries by the
// address as specified in the original forward request.
type forward struct {
	newCh NewChannel // the ssh client channel underlying this forward
	raddr net.Addr   // the raddr of the incoming connection
}
   162  
   163  func (l *forwardList) add(addr net.Addr) chan forward {
   164  	l.Lock()
   165  	defer l.Unlock()
   166  	f := forwardEntry{
   167  		laddr: addr,
   168  		c:     make(chan forward, 1),
   169  	}
   170  	l.entries = append(l.entries, f)
   171  	return f.c
   172  }
   173  
// forwardedTCPPayload is the extra data of a "forwarded-tcpip" channel-open
// request (RFC 4254, section 7.2). It holds the address and port that were
// connected on the peer, followed by the originator's address and port.
type forwardedTCPPayload struct {
	Addr       string
	Port       uint32
	OriginAddr string
	OriginPort uint32
}
   181  
   182  // parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
   183  func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
   184  	if port == 0 || port > 65535 {
   185  		return nil, fmt.Errorf("ssh: port number out of range: %d", port)
   186  	}
   187  	ip := net.ParseIP(string(addr))
   188  	if ip == nil {
   189  		return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
   190  	}
   191  	return &net.TCPAddr{IP: ip, Port: int(port)}, nil
   192  }
   193  
   194  func (l *forwardList) handleChannels(in <-chan NewChannel) {
   195  	for ch := range in {
   196  		var (
   197  			laddr net.Addr
   198  			raddr net.Addr
   199  			err   error
   200  		)
   201  		switch channelType := ch.ChannelType(); channelType {
   202  		case "forwarded-tcpip":
   203  			var payload forwardedTCPPayload
   204  			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
   205  				ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
   206  				continue
   207  			}
   208  
   209  			// RFC 4254 section 7.2 specifies that incoming
   210  			// addresses should list the address, in string
   211  			// format. It is implied that this should be an IP
   212  			// address, as it would be impossible to connect to it
   213  			// otherwise.
   214  			laddr, err = parseTCPAddr(payload.Addr, payload.Port)
   215  			if err != nil {
   216  				ch.Reject(ConnectionFailed, err.Error())
   217  				continue
   218  			}
   219  			raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
   220  			if err != nil {
   221  				ch.Reject(ConnectionFailed, err.Error())
   222  				continue
   223  			}
   224  
   225  		case "forwarded-streamlocal@openssh.com":
   226  			var payload forwardedStreamLocalPayload
   227  			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
   228  				ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
   229  				continue
   230  			}
   231  			laddr = &net.UnixAddr{
   232  				Name: payload.SocketPath,
   233  				Net:  "unix",
   234  			}
   235  			raddr = &net.UnixAddr{
   236  				Name: "@",
   237  				Net:  "unix",
   238  			}
   239  		default:
   240  			panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
   241  		}
   242  		if ok := l.forward(laddr, raddr, ch); !ok {
   243  			// Section 7.2, implementations MUST reject spurious incoming
   244  			// connections.
   245  			ch.Reject(Prohibited, "no forward for address")
   246  			continue
   247  		}
   248  
   249  	}
   250  }
   251  
   252  // remove removes the forward entry, and the channel feeding its
   253  // listener.
   254  func (l *forwardList) remove(addr net.Addr) {
   255  	l.Lock()
   256  	defer l.Unlock()
   257  	for i, f := range l.entries {
   258  		if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
   259  			l.entries = append(l.entries[:i], l.entries[i+1:]...)
   260  			close(f.c)
   261  			return
   262  		}
   263  	}
   264  }
   265  
   266  // closeAll closes and clears all forwards.
   267  func (l *forwardList) closeAll() {
   268  	l.Lock()
   269  	defer l.Unlock()
   270  	for _, f := range l.entries {
   271  		close(f.c)
   272  	}
   273  	l.entries = nil
   274  }
   275  
   276  func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
   277  	l.Lock()
   278  	defer l.Unlock()
   279  	for _, f := range l.entries {
   280  		if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
   281  			f.c <- forward{newCh: ch, raddr: raddr}
   282  			return true
   283  		}
   284  	}
   285  	return false
   286  }
   287  
// tcpListener implements net.Listener for a remote TCP forward created by
// Client.ListenTCP.
type tcpListener struct {
	laddr *net.TCPAddr // forwarded address on the peer; also the forwardList key

	conn *Client        // client used to cancel the forward on Close
	in   <-chan forward // incoming connections; closed when the forward is removed
}
   294  
   295  // Accept waits for and returns the next connection to the listener.
   296  func (l *tcpListener) Accept() (net.Conn, error) {
   297  	s, ok := <-l.in
   298  	if !ok {
   299  		return nil, io.EOF
   300  	}
   301  	ch, incoming, err := s.newCh.Accept()
   302  	if err != nil {
   303  		return nil, err
   304  	}
   305  	go DiscardRequests(incoming)
   306  
   307  	return &chanConn{
   308  		Channel: ch,
   309  		laddr:   l.laddr,
   310  		raddr:   s.raddr,
   311  	}, nil
   312  }
   313  
   314  // Close closes the listener.
   315  func (l *tcpListener) Close() error {
   316  	m := channelForwardMsg{
   317  		l.laddr.IP.String(),
   318  		uint32(l.laddr.Port),
   319  	}
   320  
   321  	// this also closes the listener.
   322  	l.conn.forwards.remove(l.laddr)
   323  	ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
   324  	if err == nil && !ok {
   325  		err = errors.New("ssh: cancel-tcpip-forward failed")
   326  	}
   327  	return err
   328  }
   329  
// Addr returns the listener's network address: the forwarded address on
// the remote peer, including the port chosen by the peer when 0 was
// requested.
func (l *tcpListener) Addr() net.Addr {
	return l.laddr
}
   334  
// [Psiphon]
// directTCPIPNoSplitTunnel is the same as "direct-tcpip", except that it
// indicates custom split tunnel behavior. It uses the same payload. Client.Dial
// accepts it as its network argument and then opens this channel type instead
// of "direct-tcpip".
const directTCPIPNoSplitTunnel = "direct-tcpip-no-split-tunnel@psiphon.ca"
   340  
   341  // Dial initiates a connection to the addr from the remote host.
   342  // The resulting connection has a zero LocalAddr() and RemoteAddr().
   343  func (c *Client) Dial(n, addr string) (net.Conn, error) {
   344  	var ch Channel
   345  	switch n {
   346  	case "tcp", "tcp4", "tcp6", "direct-tcpip", directTCPIPNoSplitTunnel:
   347  		// Parse the address into host and numeric port.
   348  		host, portString, err := net.SplitHostPort(addr)
   349  		if err != nil {
   350  			return nil, err
   351  		}
   352  		port, err := strconv.ParseUint(portString, 10, 16)
   353  		if err != nil {
   354  			return nil, err
   355  		}
   356  
   357  		// [Psiphon]
   358  		channelType := "direct-tcpip"
   359  		if n == directTCPIPNoSplitTunnel {
   360  			channelType = directTCPIPNoSplitTunnel
   361  		}
   362  
   363  		ch, err = c.dial(channelType, net.IPv4zero.String(), 0, host, int(port))
   364  		if err != nil {
   365  			return nil, err
   366  		}
   367  		// Use a zero address for local and remote address.
   368  		zeroAddr := &net.TCPAddr{
   369  			IP:   net.IPv4zero,
   370  			Port: 0,
   371  		}
   372  		return &chanConn{
   373  			Channel: ch,
   374  			laddr:   zeroAddr,
   375  			raddr:   zeroAddr,
   376  		}, nil
   377  	case "unix":
   378  		var err error
   379  		ch, err = c.dialStreamLocal(addr)
   380  		if err != nil {
   381  			return nil, err
   382  		}
   383  		return &chanConn{
   384  			Channel: ch,
   385  			laddr: &net.UnixAddr{
   386  				Name: "@",
   387  				Net:  "unix",
   388  			},
   389  			raddr: &net.UnixAddr{
   390  				Name: addr,
   391  				Net:  "unix",
   392  			},
   393  		}, nil
   394  	default:
   395  		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
   396  	}
   397  }
   398  
   399  // DialTCP connects to the remote address raddr on the network net,
   400  // which must be "tcp", "tcp4", or "tcp6".  If laddr is not nil, it is used
   401  // as the local address for the connection.
   402  func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
   403  	if laddr == nil {
   404  		laddr = &net.TCPAddr{
   405  			IP:   net.IPv4zero,
   406  			Port: 0,
   407  		}
   408  	}
   409  	ch, err := c.dial("direct-tcpip", laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
   410  	if err != nil {
   411  		return nil, err
   412  	}
   413  	return &chanConn{
   414  		Channel: ch,
   415  		laddr:   laddr,
   416  		raddr:   raddr,
   417  	}, nil
   418  }
   419  
// channelOpenDirectMsg is the extra data of a "direct-tcpip" channel-open
// request (RFC 4254 section 7.2). It holds the host and port to connect to,
// followed by the originator's IP address and port.
// Field order follows the RFC wire layout, so do not reorder the fields.
// This assumes Marshal encodes fields in declaration order.
type channelOpenDirectMsg struct {
	raddr string
	rport uint32
	laddr string
	lport uint32
}
   427  
   428  func (c *Client) dial(channelType string, laddr string, lport int, raddr string, rport int) (Channel, error) {
   429  	msg := channelOpenDirectMsg{
   430  		raddr: raddr,
   431  		rport: uint32(rport),
   432  		laddr: laddr,
   433  		lport: uint32(lport),
   434  	}
   435  	ch, in, err := c.OpenChannel(channelType, Marshal(&msg))
   436  	if err != nil {
   437  		return nil, err
   438  	}
   439  	go DiscardRequests(in)
   440  	return ch, err
   441  }
   442  
// tcpChan wraps a Channel.
// NOTE(review): not referenced elsewhere in this file. It may be kept for
// compatibility with other files; confirm before removing.
type tcpChan struct {
	Channel // the backing channel
}
   446  
// chanConn adapts an SSH Channel to net.Conn. It embeds the Channel and
// records the addresses to report from LocalAddr and RemoteAddr, so the
// Channel itself does not need to hold them.
// Deadlines are not supported.
type chanConn struct {
	Channel
	laddr, raddr net.Addr
}
   453  
// LocalAddr returns the local network address recorded when the
// connection was created.
func (t *chanConn) LocalAddr() net.Addr {
	return t.laddr
}

// RemoteAddr returns the remote network address recorded when the
// connection was created.
func (t *chanConn) RemoteAddr() net.Addr {
	return t.raddr
}
   463  
   464  // SetDeadline sets the read and write deadlines associated
   465  // with the connection.
   466  func (t *chanConn) SetDeadline(deadline time.Time) error {
   467  	if err := t.SetReadDeadline(deadline); err != nil {
   468  		return err
   469  	}
   470  	return t.SetWriteDeadline(deadline)
   471  }
   472  
// SetReadDeadline exists to satisfy the net.Conn interface but is not
// implemented by this type. The deadline argument is ignored and an error
// is always returned; no read timeout is ever applied.
func (t *chanConn) SetReadDeadline(deadline time.Time) error {
	// for compatibility with previous version,
	// the error message contains "tcpChan"
	return errors.New("ssh: tcpChan: deadline not supported")
}
   482  
// SetWriteDeadline exists to satisfy the net.Conn interface but is not
// implemented by this type. The deadline argument is ignored and an error
// is always returned.
func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
	// The "tcpChan" wording matches SetReadDeadline's message.
	return errors.New("ssh: tcpChan: deadline not supported")
}