github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/outbound/ssh.go (about)

     1  package outbound
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/base64"
     7  	"math/rand"
     8  	"net"
     9  	"os"
    10  	"strconv"
    11  	"sync"
    12  
    13  	"github.com/inazumav/sing-box/adapter"
    14  	"github.com/inazumav/sing-box/common/dialer"
    15  	C "github.com/inazumav/sing-box/constant"
    16  	"github.com/inazumav/sing-box/log"
    17  	"github.com/inazumav/sing-box/option"
    18  	"github.com/sagernet/sing/common"
    19  	E "github.com/sagernet/sing/common/exceptions"
    20  	M "github.com/sagernet/sing/common/metadata"
    21  	N "github.com/sagernet/sing/common/network"
    22  
    23  	"golang.org/x/crypto/ssh"
    24  )
    25  
    26  var (
    27  	_ adapter.Outbound                = (*SSH)(nil)
    28  	_ adapter.InterfaceUpdateListener = (*SSH)(nil)
    29  )
    30  
    31  type SSH struct {
    32  	myOutboundAdapter
    33  	ctx               context.Context
    34  	dialer            N.Dialer
    35  	serverAddr        M.Socksaddr
    36  	user              string
    37  	hostKey           []ssh.PublicKey
    38  	hostKeyAlgorithms []string
    39  	clientVersion     string
    40  	authMethod        []ssh.AuthMethod
    41  	clientAccess      sync.Mutex
    42  	clientConn        net.Conn
    43  	client            *ssh.Client
    44  }
    45  
    46  func NewSSH(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SSHOutboundOptions) (*SSH, error) {
    47  	outboundDialer, err := dialer.New(router, options.DialerOptions)
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  	outbound := &SSH{
    52  		myOutboundAdapter: myOutboundAdapter{
    53  			protocol:     C.TypeSSH,
    54  			network:      []string{N.NetworkTCP},
    55  			router:       router,
    56  			logger:       logger,
    57  			tag:          tag,
    58  			dependencies: withDialerDependency(options.DialerOptions),
    59  		},
    60  		ctx:               ctx,
    61  		dialer:            outboundDialer,
    62  		serverAddr:        options.ServerOptions.Build(),
    63  		user:              options.User,
    64  		hostKeyAlgorithms: options.HostKeyAlgorithms,
    65  		clientVersion:     options.ClientVersion,
    66  	}
    67  	if outbound.serverAddr.Port == 0 {
    68  		outbound.serverAddr.Port = 22
    69  	}
    70  	if outbound.user == "" {
    71  		outbound.user = "root"
    72  	}
    73  	if outbound.clientVersion == "" {
    74  		outbound.clientVersion = randomVersion()
    75  	}
    76  	if options.Password != "" {
    77  		outbound.authMethod = append(outbound.authMethod, ssh.Password(options.Password))
    78  	}
    79  	if options.PrivateKey != "" || options.PrivateKeyPath != "" {
    80  		var privateKey []byte
    81  		if options.PrivateKey != "" {
    82  			privateKey = []byte(options.PrivateKey)
    83  		} else {
    84  			var err error
    85  			privateKey, err = os.ReadFile(os.ExpandEnv(options.PrivateKeyPath))
    86  			if err != nil {
    87  				return nil, E.Cause(err, "read private key")
    88  			}
    89  		}
    90  		var signer ssh.Signer
    91  		var err error
    92  		if options.PrivateKeyPassphrase == "" {
    93  			signer, err = ssh.ParsePrivateKey(privateKey)
    94  		} else {
    95  			signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(options.PrivateKeyPassphrase))
    96  		}
    97  		if err != nil {
    98  			return nil, E.Cause(err, "parse private key")
    99  		}
   100  		outbound.authMethod = append(outbound.authMethod, ssh.PublicKeys(signer))
   101  	}
   102  	if len(options.HostKey) > 0 {
   103  		for _, hostKey := range options.HostKey {
   104  			key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(hostKey))
   105  			if err != nil {
   106  				return nil, E.New("parse host key ", key)
   107  			}
   108  			outbound.hostKey = append(outbound.hostKey, key)
   109  		}
   110  	}
   111  	return outbound, nil
   112  }
   113  
   114  func randomVersion() string {
   115  	version := "SSH-2.0-OpenSSH_"
   116  	if rand.Intn(2) == 0 {
   117  		version += "7." + strconv.Itoa(rand.Intn(10))
   118  	} else {
   119  		version += "8." + strconv.Itoa(rand.Intn(9))
   120  	}
   121  	return version
   122  }
   123  
   124  func (s *SSH) connect() (*ssh.Client, error) {
   125  	if s.client != nil {
   126  		return s.client, nil
   127  	}
   128  
   129  	s.clientAccess.Lock()
   130  	defer s.clientAccess.Unlock()
   131  
   132  	if s.client != nil {
   133  		return s.client, nil
   134  	}
   135  
   136  	conn, err := s.dialer.DialContext(s.ctx, N.NetworkTCP, s.serverAddr)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	config := &ssh.ClientConfig{
   141  		User:              s.user,
   142  		Auth:              s.authMethod,
   143  		ClientVersion:     s.clientVersion,
   144  		HostKeyAlgorithms: s.hostKeyAlgorithms,
   145  		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
   146  			if len(s.hostKey) == 0 {
   147  				return nil
   148  			}
   149  			serverKey := key.Marshal()
   150  			for _, hostKey := range s.hostKey {
   151  				if bytes.Equal(serverKey, hostKey.Marshal()) {
   152  					return nil
   153  				}
   154  			}
   155  			return E.New("host key mismatch, server send ", key.Type(), " ", base64.StdEncoding.EncodeToString(serverKey))
   156  		},
   157  	}
   158  	clientConn, chans, reqs, err := ssh.NewClientConn(conn, s.serverAddr.Addr.String(), config)
   159  	if err != nil {
   160  		conn.Close()
   161  		return nil, E.Cause(err, "connect to ssh server")
   162  	}
   163  
   164  	client := ssh.NewClient(clientConn, chans, reqs)
   165  
   166  	s.clientConn = conn
   167  	s.client = client
   168  
   169  	go func() {
   170  		client.Wait()
   171  		conn.Close()
   172  		s.clientAccess.Lock()
   173  		s.client = nil
   174  		s.clientConn = nil
   175  		s.clientAccess.Unlock()
   176  	}()
   177  
   178  	return client, nil
   179  }
   180  
   181  func (s *SSH) InterfaceUpdated() {
   182  	common.Close(s.clientConn)
   183  	return
   184  }
   185  
   186  func (s *SSH) Close() error {
   187  	return common.Close(s.clientConn)
   188  }
   189  
   190  func (s *SSH) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
   191  	client, err := s.connect()
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  	return client.Dial(network, destination.String())
   196  }
   197  
   198  func (s *SSH) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   199  	return nil, os.ErrInvalid
   200  }
   201  
   202  func (s *SSH) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   203  	return NewConnection(ctx, s, conn, metadata)
   204  }
   205  
   206  func (s *SSH) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   207  	return os.ErrInvalid
   208  }