github.com/zooyer/miskit@v1.0.71/ssh/pool.go (about) 1 package ssh 2 3 import ( 4 "context" 5 "fmt" 6 "github.com/pkg/sftp" 7 "github.com/zooyer/miskit/utils/pool" 8 "golang.org/x/crypto/ssh" 9 "io" 10 "os" 11 "sync" 12 "time" 13 ) 14 15 type Remote struct { 16 Addr string 17 Username string 18 Password string 19 } 20 21 type PoolOption struct { 22 MaxConn int 23 MinConn int 24 } 25 26 type Pool struct { 27 min int 28 max int 29 idle time.Duration 30 conn map[string]chan *ssh.Client 31 remote map[string]Remote 32 errors []error 33 34 wg sync.WaitGroup 35 mutex sync.Mutex 36 close chan struct{} 37 } 38 39 func NewPool(min, max int, idle time.Duration) *Pool { 40 var pool = Pool{ 41 min: min, 42 max: max, 43 idle: idle, 44 conn: make(map[string]chan *ssh.Client), 45 remote: make(map[string]Remote), 46 close: make(chan struct{}, 10), 47 } 48 49 //go pool.loop() 50 51 return &pool 52 } 53 54 func NewPool2(min, max int, idle time.Duration) *Pool2 { 55 return &Pool2{ 56 min: min, 57 max: max, 58 idle: idle, 59 pool: make(map[string]*pool.Pool), 60 } 61 } 62 63 func (p *Pool) Add(addr, user, password string) { 64 p.remote[addr+user+password] = Remote{ 65 Addr: addr, 66 Username: user, 67 Password: password, 68 } 69 } 70 71 func (p *Pool) initOne(addr, user, password string) { 72 var key = p.key(addr, user, password) 73 if p.conn[key] == nil { 74 p.conn[key] = make(chan *ssh.Client, p.max+p.min) 75 } 76 77 start := time.Now() 78 if client, err := Client(user, password, addr); err == nil { 79 p.conn[key] <- client 80 } 81 fmt.Println("ssh connect:", user, addr, time.Since(start)) 82 83 return 84 } 85 86 func (p *Pool) init(addr, user, password string) { 87 var key = p.key(addr, user, password) 88 if p.conn[key] == nil { 89 p.conn[key] = make(chan *ssh.Client, p.max+p.min) 90 } 91 92 var wg sync.WaitGroup 93 var count = p.min - len(p.conn[key]) 94 wg.Add(count) 95 start := time.Now() 96 for i := 0; i < count; i++ { 97 go func() { 98 defer wg.Done() 99 if client, err := Client(user, password, addr); err == nil { 100 p.conn[key] <- client 101 } 102 }() 103 } 104 wg.Wait() 105 fmt.Println("ssh connect:", user, addr, time.Since(start)) 106 107 return 108 } 109 110 func (p *Pool) key(addr, user, password string) string { 111 return fmt.Sprintf("%s@%s:%s", user, addr, password) 112 } 113 114 func (p *Pool) getConn(addr, user, password string) (*ssh.Client, error) { 115 p.mutex.Lock() 116 defer p.mutex.Unlock() 117 118 var key = p.key(addr, user, password) 119 if len(p.conn[key]) == 0 { 120 p.init(addr, user, password) 121 } 122 if len(p.conn[key]) == 0 { 123 return nil, fmt.Errorf("%s@%s no connection available", user, addr) 124 } 125 126 return <-p.conn[key], nil 127 } 128 129 func (p *Pool) putConn(client *ssh.Client, addr, user, password string) { 130 if client == nil { 131 return 132 } 133 p.mutex.Lock() 134 defer p.mutex.Unlock() 135 136 var key = p.key(addr, user, password) 137 138 p.conn[key] <- client 139 } 140 141 func (p *Pool) Session(addr, user, password string) (session *ssh.Session, err error) { 142 client, err := p.getConn(addr, user, password) 143 if err != nil { 144 return 145 } 146 defer p.putConn(client, addr, user, password) 147 148 return client.NewSession() 149 } 150 151 func (p *Pool) SftpClient(addr, user, password string) (client *sftp.Client, err error) { 152 sshClient, err := p.getConn(addr, user, password) 153 if err != nil { 154 return 155 } 156 defer p.putConn(sshClient, addr, user, password) 157 158 return sftp.NewClient(sshClient) 159 } 160 161 func (p *Pool) ScpReader(reader io.Reader, remote, password string, fn func(size int)) (err error) { 162 user, addr, filename, err := parse(remote) 163 if err != nil { 164 return 165 } 166 167 client, err := p.SftpClient(addr, user, password) 168 if err != nil { 169 return 170 } 171 defer client.Close() 172 173 return ScpReader(client, filename, newReader(reader, fn)) 174 } 175 176 func (p *Pool) Scp(local, remote, password string, fn func(current, total int64)) (err error) { 177 file, err := os.Open(local) 178 if err != nil { 179 return 180 } 181 defer file.Close() 182 183 stat, err := file.Stat() 184 if err != nil { 185 return 186 } 187 188 var ( 189 total = stat.Size() 190 current int64 191 ) 192 193 return p.ScpReader(file, remote, password, func(size int) { 194 current += int64(size) 195 fn(current, total) 196 }) 197 } 198 199 func (p *Pool) Command(remote, password, cmd string) (output string, err error) { 200 user, addr, _, err := parse(remote) 201 if err != nil { 202 return 203 } 204 205 session, err := p.Session(addr, user, password) 206 if err != nil { 207 return 208 } 209 defer session.Close() 210 211 return CommandSession(session, cmd) 212 } 213 214 func (p *Pool) loop() { 215 for range p.close { 216 for _, remote := range p.remote { 217 p.mutex.Lock() 218 p.init(remote.Addr, remote.Username, remote.Password) 219 p.mutex.Unlock() 220 } 221 time.Sleep(time.Second) 222 } 223 } 224 225 func (p *Pool) Init(remote ...Remote) { 226 p.mutex.Lock() 227 defer p.mutex.Unlock() 228 for _, remote := range remote { 229 p.init(remote.Addr, remote.Username, remote.Password) 230 } 231 } 232 233 func (p *Pool) Close() error { 234 p.close <- struct{}{} 235 p.wg.Wait() 236 var err error 237 for _, err := range p.errors { 238 if err != nil { 239 return err 240 } 241 } 242 return err 243 } 244 245 type client struct { 246 *ssh.Client 247 } 248 249 func (c *client) Ping() error { 250 session, err := c.NewSession() 251 if err != nil { 252 return err 253 } 254 return session.Close() 255 } 256 257 type Pool2 struct { 258 min int 259 max int 260 idle time.Duration 261 pool map[string]*pool.Pool 262 mutex sync.Mutex 263 } 264 265 func (p *Pool2) key(addr, username, password string) string { 266 return fmt.Sprintf("%s@%s:%s", username, addr, password) 267 } 268 269 func (p *Pool2) get(addr, username, password string) (*ssh.Client, error) { 270 var key = p.key(addr, username, password) 271 272 p.mutex.Lock() 273 defer p.mutex.Unlock() 274 275 if p.pool[key] == nil { 276 var factory = func() (entry pool.Entry, err error) { 277 cli, err := Client(username, password, addr) 278 if err != nil { 279 return 280 } 281 return &client{Client: cli}, nil 282 } 283 284 p.pool[key] = pool.New(p.min, p.max, p.idle, factory) 285 } 286 287 var ctx = context.Background() 288 289 cli, err := p.pool[key].Get(ctx) 290 if err != nil { 291 return nil, err 292 } 293 294 return cli.(*client).Client, nil 295 } 296 297 func (p *Pool2) put(client *client, addr, username, password string) (err error) { 298 var key = p.key(addr, username, password) 299 300 p.mutex.Lock() 301 defer p.mutex.Unlock() 302 303 if err = p.pool[key].Put(client); err != nil { 304 return 305 } 306 307 return 308 } 309 310 func (p *Pool2) Session(addr, username, password string) (session *ssh.Session, err error) { 311 cli, err := p.get(addr, username, password) 312 if err != nil { 313 return 314 } 315 defer p.put(&client{Client: cli}, addr, username, password) 316 317 return cli.NewSession() 318 } 319 320 func (p *Pool2) SftpClient(addr, username, password string) (*sftp.Client, error) { 321 cli, err := p.get(addr, username, password) 322 if err != nil { 323 return nil, err 324 } 325 defer p.put(&client{Client: cli}, addr, username, password) 326 327 return sftp.NewClient(cli) 328 } 329 330 func (p *Pool2) ScpReader(reader io.Reader, remote, password string, fn func(size int)) (err error) { 331 user, addr, filename, err := parse(remote) 332 if err != nil { 333 return 334 } 335 336 client, err := p.SftpClient(addr, user, password) 337 if err != nil { 338 return 339 } 340 defer client.Close() 341 342 return ScpReader(client, filename, newReader(reader, fn)) 343 } 344 345 func (p *Pool2) Scp(local, remote, password string, fn func(current, total int64)) (err error) { 346 file, err := os.Open(local) 347 if err != nil { 348 return 349 } 350 defer file.Close() 351 352 stat, err := file.Stat() 353 if err != nil { 354 return 355 } 356 357 var ( 358 total = stat.Size() 359 current int64 360 ) 361 362 return p.ScpReader(file, remote, password, func(size int) { 363 current += int64(size) 364 fn(current, total) 365 }) 366 } 367 368 func (p *Pool2) Command(remote, password, cmd string) (output string, err error) { 369 user, addr, _, err := parse(remote) 370 if err != nil { 371 return 372 } 373 374 session, err := p.Session(addr, user, password) 375 if err != nil { 376 return 377 } 378 defer session.Close() 379 380 return CommandSession(session, cmd) 381 }