github.com/sagernet/sing-box@v1.9.0-rc.20/outbound/ssh.go (about) 1 package outbound 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/base64" 7 "math/rand" 8 "net" 9 "os" 10 "strconv" 11 "strings" 12 "sync" 13 14 "github.com/sagernet/sing-box/adapter" 15 "github.com/sagernet/sing-box/common/dialer" 16 C "github.com/sagernet/sing-box/constant" 17 "github.com/sagernet/sing-box/log" 18 "github.com/sagernet/sing-box/option" 19 "github.com/sagernet/sing/common" 20 E "github.com/sagernet/sing/common/exceptions" 21 M "github.com/sagernet/sing/common/metadata" 22 N "github.com/sagernet/sing/common/network" 23 24 "golang.org/x/crypto/ssh" 25 ) 26 27 var ( 28 _ adapter.Outbound = (*SSH)(nil) 29 _ adapter.InterfaceUpdateListener = (*SSH)(nil) 30 ) 31 32 type SSH struct { 33 myOutboundAdapter 34 ctx context.Context 35 dialer N.Dialer 36 serverAddr M.Socksaddr 37 user string 38 hostKey []ssh.PublicKey 39 hostKeyAlgorithms []string 40 clientVersion string 41 authMethod []ssh.AuthMethod 42 clientAccess sync.Mutex 43 clientConn net.Conn 44 client *ssh.Client 45 } 46 47 func NewSSH(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SSHOutboundOptions) (*SSH, error) { 48 outboundDialer, err := dialer.New(router, options.DialerOptions) 49 if err != nil { 50 return nil, err 51 } 52 outbound := &SSH{ 53 myOutboundAdapter: myOutboundAdapter{ 54 protocol: C.TypeSSH, 55 network: []string{N.NetworkTCP}, 56 router: router, 57 logger: logger, 58 tag: tag, 59 dependencies: withDialerDependency(options.DialerOptions), 60 }, 61 ctx: ctx, 62 dialer: outboundDialer, 63 serverAddr: options.ServerOptions.Build(), 64 user: options.User, 65 hostKeyAlgorithms: options.HostKeyAlgorithms, 66 clientVersion: options.ClientVersion, 67 } 68 if outbound.serverAddr.Port == 0 { 69 outbound.serverAddr.Port = 22 70 } 71 if outbound.user == "" { 72 outbound.user = "root" 73 } 74 if outbound.clientVersion == "" { 75 outbound.clientVersion = randomVersion() 76 } 77 if options.Password != "" { 78 outbound.authMethod = append(outbound.authMethod, ssh.Password(options.Password)) 79 } 80 if len(options.PrivateKey) > 0 || options.PrivateKeyPath != "" { 81 var privateKey []byte 82 if len(options.PrivateKey) > 0 { 83 privateKey = []byte(strings.Join(options.PrivateKey, "\n")) 84 } else { 85 var err error 86 privateKey, err = os.ReadFile(os.ExpandEnv(options.PrivateKeyPath)) 87 if err != nil { 88 return nil, E.Cause(err, "read private key") 89 } 90 } 91 var signer ssh.Signer 92 var err error 93 if options.PrivateKeyPassphrase == "" { 94 signer, err = ssh.ParsePrivateKey(privateKey) 95 } else { 96 signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, []byte(options.PrivateKeyPassphrase)) 97 } 98 if err != nil { 99 return nil, E.Cause(err, "parse private key") 100 } 101 outbound.authMethod = append(outbound.authMethod, ssh.PublicKeys(signer)) 102 } 103 if len(options.HostKey) > 0 { 104 for _, hostKey := range options.HostKey { 105 key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(hostKey)) 106 if err != nil { 107 return nil, E.New("parse host key ", key) 108 } 109 outbound.hostKey = append(outbound.hostKey, key) 110 } 111 } 112 return outbound, nil 113 } 114 115 func randomVersion() string { 116 version := "SSH-2.0-OpenSSH_" 117 if rand.Intn(2) == 0 { 118 version += "7." + strconv.Itoa(rand.Intn(10)) 119 } else { 120 version += "8." + strconv.Itoa(rand.Intn(9)) 121 } 122 return version 123 } 124 125 func (s *SSH) connect() (*ssh.Client, error) { 126 if s.client != nil { 127 return s.client, nil 128 } 129 130 s.clientAccess.Lock() 131 defer s.clientAccess.Unlock() 132 133 if s.client != nil { 134 return s.client, nil 135 } 136 137 conn, err := s.dialer.DialContext(s.ctx, N.NetworkTCP, s.serverAddr) 138 if err != nil { 139 return nil, err 140 } 141 config := &ssh.ClientConfig{ 142 User: s.user, 143 Auth: s.authMethod, 144 ClientVersion: s.clientVersion, 145 HostKeyAlgorithms: s.hostKeyAlgorithms, 146 HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { 147 if len(s.hostKey) == 0 { 148 return nil 149 } 150 serverKey := key.Marshal() 151 for _, hostKey := range s.hostKey { 152 if bytes.Equal(serverKey, hostKey.Marshal()) { 153 return nil 154 } 155 } 156 return E.New("host key mismatch, server send ", key.Type(), " ", base64.StdEncoding.EncodeToString(serverKey)) 157 }, 158 } 159 clientConn, chans, reqs, err := ssh.NewClientConn(conn, s.serverAddr.Addr.String(), config) 160 if err != nil { 161 conn.Close() 162 return nil, E.Cause(err, "connect to ssh server") 163 } 164 165 client := ssh.NewClient(clientConn, chans, reqs) 166 167 s.clientConn = conn 168 s.client = client 169 170 go func() { 171 client.Wait() 172 conn.Close() 173 s.clientAccess.Lock() 174 s.client = nil 175 s.clientConn = nil 176 s.clientAccess.Unlock() 177 }() 178 179 return client, nil 180 } 181 182 func (s *SSH) InterfaceUpdated() { 183 common.Close(s.clientConn) 184 return 185 } 186 187 func (s *SSH) Close() error { 188 return common.Close(s.clientConn) 189 } 190 191 func (s *SSH) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { 192 client, err := s.connect() 193 if err != nil { 194 return nil, err 195 } 196 return client.Dial(network, destination.String()) 197 } 198 199 func (s *SSH) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { 200 return nil, os.ErrInvalid 201 } 202 203 func (s *SSH) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 204 return NewConnection(ctx, s, conn, metadata) 205 } 206 207 func (s *SSH) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { 208 return os.ErrInvalid 209 }