github.com/sagernet/sing-box@v1.2.7/outbound/ssh.go (about)

     1  package outbound
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/base64"
     7  	"math/rand"
     8  	"net"
     9  	"os"
    10  	"strconv"
    11  	"sync"
    12  
    13  	"github.com/sagernet/sing-box/adapter"
    14  	"github.com/sagernet/sing-box/common/dialer"
    15  	C "github.com/sagernet/sing-box/constant"
    16  	"github.com/sagernet/sing-box/log"
    17  	"github.com/sagernet/sing-box/option"
    18  	"github.com/sagernet/sing/common"
    19  	E "github.com/sagernet/sing/common/exceptions"
    20  	M "github.com/sagernet/sing/common/metadata"
    21  	N "github.com/sagernet/sing/common/network"
    22  
    23  	"golang.org/x/crypto/ssh"
    24  )
    25  
    26  var (
    27  	_ adapter.Outbound                = (*SSH)(nil)
    28  	_ adapter.InterfaceUpdateListener = (*SSH)(nil)
    29  )
    30  
    31  type SSH struct {
    32  	myOutboundAdapter
    33  	ctx               context.Context
    34  	dialer            N.Dialer
    35  	serverAddr        M.Socksaddr
    36  	user              string
    37  	hostKey           []ssh.PublicKey
    38  	hostKeyAlgorithms []string
    39  	clientVersion     string
    40  	authMethod        []ssh.AuthMethod
    41  	clientAccess      sync.Mutex
    42  	clientConn        net.Conn
    43  	client            *ssh.Client
    44  }
    45  
    46  func NewSSH(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SSHOutboundOptions) (*SSH, error) {
    47  	outbound := &SSH{
    48  		myOutboundAdapter: myOutboundAdapter{
    49  			protocol: C.TypeSSH,
    50  			network:  []string{N.NetworkTCP},
    51  			router:   router,
    52  			logger:   logger,
    53  			tag:      tag,
    54  		},
    55  		ctx:               ctx,
    56  		dialer:            dialer.New(router, options.DialerOptions),
    57  		serverAddr:        options.ServerOptions.Build(),
    58  		user:              options.User,
    59  		hostKeyAlgorithms: options.HostKeyAlgorithms,
    60  		clientVersion:     options.ClientVersion,
    61  	}
    62  	if outbound.serverAddr.Port == 0 {
    63  		outbound.serverAddr.Port = 22
    64  	}
    65  	if outbound.user == "" {
    66  		outbound.user = "root"
    67  	}
    68  	if outbound.clientVersion == "" {
    69  		outbound.clientVersion = randomVersion()
    70  	}
    71  	if options.Password != "" {
    72  		outbound.authMethod = append(outbound.authMethod, ssh.Password(options.Password))
    73  	}
    74  	if options.PrivateKey != "" || options.PrivateKeyPath != "" {
    75  		var privateKey []byte
    76  		if options.PrivateKey != "" {
    77  			privateKey = []byte(options.PrivateKey)
    78  		} else {
    79  			var err error
    80  			privateKey, err = os.ReadFile(os.ExpandEnv(options.PrivateKeyPath))
    81  			if err != nil {
    82  				return nil, E.Cause(err, "read private key")
    83  			}
    84  		}
    85  		var signer ssh.Signer
    86  		var err error
    87  		if options.PrivateKeyPassphrase == "" {
    88  			signer, err = ssh.ParsePrivateKey(privateKey)
    89  		} else {
    90  			signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(options.PrivateKeyPassphrase))
    91  		}
    92  		if err != nil {
    93  			return nil, E.Cause(err, "parse private key")
    94  		}
    95  		outbound.authMethod = append(outbound.authMethod, ssh.PublicKeys(signer))
    96  	}
    97  	if len(options.HostKey) > 0 {
    98  		for _, hostKey := range options.HostKey {
    99  			key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(hostKey))
   100  			if err != nil {
   101  				return nil, E.New("parse host key ", key)
   102  			}
   103  			outbound.hostKey = append(outbound.hostKey, key)
   104  		}
   105  	}
   106  	return outbound, nil
   107  }
   108  
   109  func randomVersion() string {
   110  	version := "SSH-2.0-OpenSSH_"
   111  	if rand.Intn(2) == 0 {
   112  		version += "7." + strconv.Itoa(rand.Intn(10))
   113  	} else {
   114  		version += "8." + strconv.Itoa(rand.Intn(9))
   115  	}
   116  	return version
   117  }
   118  
   119  func (s *SSH) connect() (*ssh.Client, error) {
   120  	if s.client != nil {
   121  		return s.client, nil
   122  	}
   123  
   124  	s.clientAccess.Lock()
   125  	defer s.clientAccess.Unlock()
   126  
   127  	if s.client != nil {
   128  		return s.client, nil
   129  	}
   130  
   131  	conn, err := s.dialer.DialContext(s.ctx, N.NetworkTCP, s.serverAddr)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	config := &ssh.ClientConfig{
   136  		User:              s.user,
   137  		Auth:              s.authMethod,
   138  		ClientVersion:     s.clientVersion,
   139  		HostKeyAlgorithms: s.hostKeyAlgorithms,
   140  		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
   141  			if len(s.hostKey) == 0 {
   142  				return nil
   143  			}
   144  			serverKey := key.Marshal()
   145  			for _, hostKey := range s.hostKey {
   146  				if bytes.Equal(serverKey, hostKey.Marshal()) {
   147  					return nil
   148  				}
   149  			}
   150  			return E.New("host key mismatch, server send ", key.Type(), " ", base64.StdEncoding.EncodeToString(serverKey))
   151  		},
   152  	}
   153  	clientConn, chans, reqs, err := ssh.NewClientConn(conn, s.serverAddr.Addr.String(), config)
   154  	if err != nil {
   155  		conn.Close()
   156  		return nil, E.Cause(err, "connect to ssh server")
   157  	}
   158  
   159  	client := ssh.NewClient(clientConn, chans, reqs)
   160  
   161  	s.clientConn = conn
   162  	s.client = client
   163  
   164  	go func() {
   165  		client.Wait()
   166  		conn.Close()
   167  		s.clientAccess.Lock()
   168  		s.client = nil
   169  		s.clientConn = nil
   170  		s.clientAccess.Unlock()
   171  	}()
   172  
   173  	return client, nil
   174  }
   175  
   176  func (s *SSH) InterfaceUpdated() error {
   177  	common.Close(s.clientConn)
   178  	return nil
   179  }
   180  
   181  func (s *SSH) Close() error {
   182  	return common.Close(s.clientConn)
   183  }
   184  
   185  func (s *SSH) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
   186  	client, err := s.connect()
   187  	if err != nil {
   188  		return nil, err
   189  	}
   190  	return client.Dial(network, destination.String())
   191  }
   192  
   193  func (s *SSH) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   194  	return nil, os.ErrInvalid
   195  }
   196  
   197  func (s *SSH) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   198  	return NewConnection(ctx, s, conn, metadata)
   199  }
   200  
   201  func (s *SSH) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   202  	return os.ErrInvalid
   203  }