github.com/glycerine/xcryptossh@v7.0.4+incompatible/streamlocal.go (about)

     1  package ssh
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  )
     9  
    10  // streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message
    11  // with "direct-streamlocal@openssh.com" string.
    12  //
    13  // See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
    14  // https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
    15  type streamLocalChannelOpenDirectMsg struct {
    16  	socketPath string
    17  	reserved0  string
    18  	reserved1  uint32
    19  }
    20  
    21  // forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message
    22  // with "forwarded-streamlocal@openssh.com" string.
    23  type forwardedStreamLocalPayload struct {
    24  	SocketPath string
    25  	Reserved0  string
    26  }
    27  
    28  // streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message
    29  // with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string.
    30  type streamLocalChannelForwardMsg struct {
    31  	socketPath string
    32  }
    33  
    34  // ListenUnix is similar to ListenTCP but uses a Unix domain socket.
    35  func (c *Client) ListenUnix(ctx context.Context, socketPath string) (net.Listener, error) {
    36  	m := streamLocalChannelForwardMsg{
    37  		socketPath,
    38  	}
    39  	// send message
    40  	ok, _, err := c.SendRequest(ctx, "streamlocal-forward@openssh.com", true, Marshal(&m))
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  	if !ok {
    45  		return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
    46  	}
    47  	ch := c.Forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
    48  
    49  	return &unixListener{
    50  		socketPath: socketPath,
    51  		conn:       c,
    52  		in:         ch,
    53  		TmpCtx:     ctx,
    54  	}, nil
    55  }
    56  
    57  func (c *Client) dialStreamLocal(ctx context.Context, socketPath string) (Channel, error) {
    58  	msg := streamLocalChannelOpenDirectMsg{
    59  		socketPath: socketPath,
    60  	}
    61  	ch, in, err := c.OpenChannel(ctx, "direct-streamlocal@openssh.com", Marshal(&msg), nil)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	go DiscardRequests(ctx, in, c.Halt)
    66  	return ch, err
    67  }
    68  
    69  type unixListener struct {
    70  	socketPath string
    71  
    72  	conn *Client
    73  	in   <-chan forward
    74  
    75  	// must be set before calling Close()/Accept()
    76  	TmpCtx context.Context
    77  }
    78  
    79  // Accept waits for and returns the next connection to the listener.
    80  func (l *unixListener) Accept() (net.Conn, error) {
    81  	var ok bool
    82  	var s forward
    83  	select {
    84  	case <-l.conn.Done():
    85  		return nil, io.EOF
    86  	case s, ok = <-l.in:
    87  		if !ok {
    88  			return nil, io.EOF
    89  		}
    90  	}
    91  	ch, incoming, err := s.newCh.Accept()
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	go DiscardRequests(l.TmpCtx, incoming, l.conn.Halt)
    96  
    97  	return &chanConn{
    98  		Channel: ch,
    99  		laddr: &net.UnixAddr{
   100  			Name: l.socketPath,
   101  			Net:  "unix",
   102  		},
   103  		raddr: &net.UnixAddr{
   104  			Name: "@",
   105  			Net:  "unix",
   106  		},
   107  	}, nil
   108  }
   109  
   110  // Close closes the listener.
   111  func (l *unixListener) Close() error {
   112  	// this also closes the listener.
   113  	l.conn.Forwards.Remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
   114  	m := streamLocalChannelForwardMsg{
   115  		l.socketPath,
   116  	}
   117  	ok, _, err := l.conn.SendRequest(l.TmpCtx, "cancel-streamlocal-forward@openssh.com", true, Marshal(&m))
   118  	if err == nil && !ok {
   119  		err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
   120  	}
   121  	return err
   122  }
   123  
   124  // Addr returns the listener's network address.
   125  func (l *unixListener) Addr() net.Addr {
   126  	return &net.UnixAddr{
   127  		Name: l.socketPath,
   128  		Net:  "unix",
   129  	}
   130  }