github.com/sagernet/sing-box@v1.9.0-rc.20/outbound/ssh.go (about)

     1  package outbound
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/base64"
     7  	"math/rand"
     8  	"net"
     9  	"os"
    10  	"strconv"
    11  	"strings"
    12  	"sync"
    13  
    14  	"github.com/sagernet/sing-box/adapter"
    15  	"github.com/sagernet/sing-box/common/dialer"
    16  	C "github.com/sagernet/sing-box/constant"
    17  	"github.com/sagernet/sing-box/log"
    18  	"github.com/sagernet/sing-box/option"
    19  	"github.com/sagernet/sing/common"
    20  	E "github.com/sagernet/sing/common/exceptions"
    21  	M "github.com/sagernet/sing/common/metadata"
    22  	N "github.com/sagernet/sing/common/network"
    23  
    24  	"golang.org/x/crypto/ssh"
    25  )
    26  
// Compile-time assertions that *SSH implements the adapter interfaces it is
// used as; the build fails here if a required method is missing.
var (
	_ adapter.Outbound                = (*SSH)(nil)
	_ adapter.InterfaceUpdateListener = (*SSH)(nil)
)
    31  
// SSH is an outbound that forwards TCP connections through a single SSH
// client session, which connect establishes on demand and then reuses.
type SSH struct {
	myOutboundAdapter
	ctx               context.Context // context used when dialing the SSH server
	dialer            N.Dialer
	serverAddr        M.Socksaddr
	user              string
	hostKey           []ssh.PublicKey // accepted server keys; when empty, any server key is accepted
	hostKeyAlgorithms []string
	clientVersion     string
	authMethod        []ssh.AuthMethod
	clientAccess      sync.Mutex // held by connect while it replaces client and clientConn
	clientConn        net.Conn   // transport connection underneath client
	client            *ssh.Client
}
    46  
    47  func NewSSH(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SSHOutboundOptions) (*SSH, error) {
    48  	outboundDialer, err := dialer.New(router, options.DialerOptions)
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  	outbound := &SSH{
    53  		myOutboundAdapter: myOutboundAdapter{
    54  			protocol:     C.TypeSSH,
    55  			network:      []string{N.NetworkTCP},
    56  			router:       router,
    57  			logger:       logger,
    58  			tag:          tag,
    59  			dependencies: withDialerDependency(options.DialerOptions),
    60  		},
    61  		ctx:               ctx,
    62  		dialer:            outboundDialer,
    63  		serverAddr:        options.ServerOptions.Build(),
    64  		user:              options.User,
    65  		hostKeyAlgorithms: options.HostKeyAlgorithms,
    66  		clientVersion:     options.ClientVersion,
    67  	}
    68  	if outbound.serverAddr.Port == 0 {
    69  		outbound.serverAddr.Port = 22
    70  	}
    71  	if outbound.user == "" {
    72  		outbound.user = "root"
    73  	}
    74  	if outbound.clientVersion == "" {
    75  		outbound.clientVersion = randomVersion()
    76  	}
    77  	if options.Password != "" {
    78  		outbound.authMethod = append(outbound.authMethod, ssh.Password(options.Password))
    79  	}
    80  	if len(options.PrivateKey) > 0 || options.PrivateKeyPath != "" {
    81  		var privateKey []byte
    82  		if len(options.PrivateKey) > 0 {
    83  			privateKey = []byte(strings.Join(options.PrivateKey, "\n"))
    84  		} else {
    85  			var err error
    86  			privateKey, err = os.ReadFile(os.ExpandEnv(options.PrivateKeyPath))
    87  			if err != nil {
    88  				return nil, E.Cause(err, "read private key")
    89  			}
    90  		}
    91  		var signer ssh.Signer
    92  		var err error
    93  		if options.PrivateKeyPassphrase == "" {
    94  			signer, err = ssh.ParsePrivateKey(privateKey)
    95  		} else {
    96  			signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(options.PrivateKeyPassphrase))
    97  		}
    98  		if err != nil {
    99  			return nil, E.Cause(err, "parse private key")
   100  		}
   101  		outbound.authMethod = append(outbound.authMethod, ssh.PublicKeys(signer))
   102  	}
   103  	if len(options.HostKey) > 0 {
   104  		for _, hostKey := range options.HostKey {
   105  			key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(hostKey))
   106  			if err != nil {
   107  				return nil, E.New("parse host key ", key)
   108  			}
   109  			outbound.hostKey = append(outbound.hostKey, key)
   110  		}
   111  	}
   112  	return outbound, nil
   113  }
   114  
   115  func randomVersion() string {
   116  	version := "SSH-2.0-OpenSSH_"
   117  	if rand.Intn(2) == 0 {
   118  		version += "7." + strconv.Itoa(rand.Intn(10))
   119  	} else {
   120  		version += "8." + strconv.Itoa(rand.Intn(9))
   121  	}
   122  	return version
   123  }
   124  
   125  func (s *SSH) connect() (*ssh.Client, error) {
   126  	if s.client != nil {
   127  		return s.client, nil
   128  	}
   129  
   130  	s.clientAccess.Lock()
   131  	defer s.clientAccess.Unlock()
   132  
   133  	if s.client != nil {
   134  		return s.client, nil
   135  	}
   136  
   137  	conn, err := s.dialer.DialContext(s.ctx, N.NetworkTCP, s.serverAddr)
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  	config := &ssh.ClientConfig{
   142  		User:              s.user,
   143  		Auth:              s.authMethod,
   144  		ClientVersion:     s.clientVersion,
   145  		HostKeyAlgorithms: s.hostKeyAlgorithms,
   146  		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
   147  			if len(s.hostKey) == 0 {
   148  				return nil
   149  			}
   150  			serverKey := key.Marshal()
   151  			for _, hostKey := range s.hostKey {
   152  				if bytes.Equal(serverKey, hostKey.Marshal()) {
   153  					return nil
   154  				}
   155  			}
   156  			return E.New("host key mismatch, server send ", key.Type(), " ", base64.StdEncoding.EncodeToString(serverKey))
   157  		},
   158  	}
   159  	clientConn, chans, reqs, err := ssh.NewClientConn(conn, s.serverAddr.Addr.String(), config)
   160  	if err != nil {
   161  		conn.Close()
   162  		return nil, E.Cause(err, "connect to ssh server")
   163  	}
   164  
   165  	client := ssh.NewClient(clientConn, chans, reqs)
   166  
   167  	s.clientConn = conn
   168  	s.client = client
   169  
   170  	go func() {
   171  		client.Wait()
   172  		conn.Close()
   173  		s.clientAccess.Lock()
   174  		s.client = nil
   175  		s.clientConn = nil
   176  		s.clientAccess.Unlock()
   177  	}()
   178  
   179  	return client, nil
   180  }
   181  
   182  func (s *SSH) InterfaceUpdated() {
   183  	common.Close(s.clientConn)
   184  	return
   185  }
   186  
   187  func (s *SSH) Close() error {
   188  	return common.Close(s.clientConn)
   189  }
   190  
   191  func (s *SSH) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
   192  	client, err := s.connect()
   193  	if err != nil {
   194  		return nil, err
   195  	}
   196  	return client.Dial(network, destination.String())
   197  }
   198  
   199  func (s *SSH) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   200  	return nil, os.ErrInvalid
   201  }
   202  
   203  func (s *SSH) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   204  	return NewConnection(ctx, s, conn, metadata)
   205  }
   206  
   207  func (s *SSH) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   208  	return os.ErrInvalid
   209  }