github.com/zooyer/miskit@v1.0.71/ssh/pool.go (about)

     1  package ssh
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"github.com/pkg/sftp"
     7  	"github.com/zooyer/miskit/utils/pool"
     8  	"golang.org/x/crypto/ssh"
     9  	"io"
    10  	"os"
    11  	"sync"
    12  	"time"
    13  )
    14  
    15  type Remote struct {
    16  	Addr     string
    17  	Username string
    18  	Password string
    19  }
    20  
    21  type PoolOption struct {
    22  	MaxConn int
    23  	MinConn int
    24  }
    25  
    26  type Pool struct {
    27  	min    int
    28  	max    int
    29  	idle   time.Duration
    30  	conn   map[string]chan *ssh.Client
    31  	remote map[string]Remote
    32  	errors []error
    33  
    34  	wg    sync.WaitGroup
    35  	mutex sync.Mutex
    36  	close chan struct{}
    37  }
    38  
    39  func NewPool(min, max int, idle time.Duration) *Pool {
    40  	var pool = Pool{
    41  		min:    min,
    42  		max:    max,
    43  		idle:   idle,
    44  		conn:   make(map[string]chan *ssh.Client),
    45  		remote: make(map[string]Remote),
    46  		close:  make(chan struct{}, 10),
    47  	}
    48  
    49  	//go pool.loop()
    50  
    51  	return &pool
    52  }
    53  
    54  func NewPool2(min, max int, idle time.Duration) *Pool2 {
    55  	return &Pool2{
    56  		min:  min,
    57  		max:  max,
    58  		idle: idle,
    59  		pool: make(map[string]*pool.Pool),
    60  	}
    61  }
    62  
    63  func (p *Pool) Add(addr, user, password string) {
    64  	p.remote[addr+user+password] = Remote{
    65  		Addr:     addr,
    66  		Username: user,
    67  		Password: password,
    68  	}
    69  }
    70  
    71  func (p *Pool) initOne(addr, user, password string) {
    72  	var key = p.key(addr, user, password)
    73  	if p.conn[key] == nil {
    74  		p.conn[key] = make(chan *ssh.Client, p.max+p.min)
    75  	}
    76  
    77  	start := time.Now()
    78  	if client, err := Client(user, password, addr); err == nil {
    79  		p.conn[key] <- client
    80  	}
    81  	fmt.Println("ssh connect:", user, addr, time.Since(start))
    82  
    83  	return
    84  }
    85  
    86  func (p *Pool) init(addr, user, password string) {
    87  	var key = p.key(addr, user, password)
    88  	if p.conn[key] == nil {
    89  		p.conn[key] = make(chan *ssh.Client, p.max+p.min)
    90  	}
    91  
    92  	var wg sync.WaitGroup
    93  	var count = p.min - len(p.conn[key])
    94  	wg.Add(count)
    95  	start := time.Now()
    96  	for i := 0; i < count; i++ {
    97  		go func() {
    98  			defer wg.Done()
    99  			if client, err := Client(user, password, addr); err == nil {
   100  				p.conn[key] <- client
   101  			}
   102  		}()
   103  	}
   104  	wg.Wait()
   105  	fmt.Println("ssh connect:", user, addr, time.Since(start))
   106  
   107  	return
   108  }
   109  
   110  func (p *Pool) key(addr, user, password string) string {
   111  	return fmt.Sprintf("%s@%s:%s", user, addr, password)
   112  }
   113  
   114  func (p *Pool) getConn(addr, user, password string) (*ssh.Client, error) {
   115  	p.mutex.Lock()
   116  	defer p.mutex.Unlock()
   117  
   118  	var key = p.key(addr, user, password)
   119  	if len(p.conn[key]) == 0 {
   120  		p.init(addr, user, password)
   121  	}
   122  	if len(p.conn[key]) == 0 {
   123  		return nil, fmt.Errorf("%s@%s no connection available", user, addr)
   124  	}
   125  
   126  	return <-p.conn[key], nil
   127  }
   128  
   129  func (p *Pool) putConn(client *ssh.Client, addr, user, password string) {
   130  	if client == nil {
   131  		return
   132  	}
   133  	p.mutex.Lock()
   134  	defer p.mutex.Unlock()
   135  
   136  	var key = p.key(addr, user, password)
   137  
   138  	p.conn[key] <- client
   139  }
   140  
   141  func (p *Pool) Session(addr, user, password string) (session *ssh.Session, err error) {
   142  	client, err := p.getConn(addr, user, password)
   143  	if err != nil {
   144  		return
   145  	}
   146  	defer p.putConn(client, addr, user, password)
   147  
   148  	return client.NewSession()
   149  }
   150  
   151  func (p *Pool) SftpClient(addr, user, password string) (client *sftp.Client, err error) {
   152  	sshClient, err := p.getConn(addr, user, password)
   153  	if err != nil {
   154  		return
   155  	}
   156  	defer p.putConn(sshClient, addr, user, password)
   157  
   158  	return sftp.NewClient(sshClient)
   159  }
   160  
   161  func (p *Pool) ScpReader(reader io.Reader, remote, password string, fn func(size int)) (err error) {
   162  	user, addr, filename, err := parse(remote)
   163  	if err != nil {
   164  		return
   165  	}
   166  
   167  	client, err := p.SftpClient(addr, user, password)
   168  	if err != nil {
   169  		return
   170  	}
   171  	defer client.Close()
   172  
   173  	return ScpReader(client, filename, newReader(reader, fn))
   174  }
   175  
   176  func (p *Pool) Scp(local, remote, password string, fn func(current, total int64)) (err error) {
   177  	file, err := os.Open(local)
   178  	if err != nil {
   179  		return
   180  	}
   181  	defer file.Close()
   182  
   183  	stat, err := file.Stat()
   184  	if err != nil {
   185  		return
   186  	}
   187  
   188  	var (
   189  		total   = stat.Size()
   190  		current int64
   191  	)
   192  
   193  	return p.ScpReader(file, remote, password, func(size int) {
   194  		current += int64(size)
   195  		fn(current, total)
   196  	})
   197  }
   198  
   199  func (p *Pool) Command(remote, password, cmd string) (output string, err error) {
   200  	user, addr, _, err := parse(remote)
   201  	if err != nil {
   202  		return
   203  	}
   204  
   205  	session, err := p.Session(addr, user, password)
   206  	if err != nil {
   207  		return
   208  	}
   209  	defer session.Close()
   210  
   211  	return CommandSession(session, cmd)
   212  }
   213  
   214  func (p *Pool) loop() {
   215  	for range p.close {
   216  		for _, remote := range p.remote {
   217  			p.mutex.Lock()
   218  			p.init(remote.Addr, remote.Username, remote.Password)
   219  			p.mutex.Unlock()
   220  		}
   221  		time.Sleep(time.Second)
   222  	}
   223  }
   224  
   225  func (p *Pool) Init(remote ...Remote) {
   226  	p.mutex.Lock()
   227  	defer p.mutex.Unlock()
   228  	for _, remote := range remote {
   229  		p.init(remote.Addr, remote.Username, remote.Password)
   230  	}
   231  }
   232  
   233  func (p *Pool) Close() error {
   234  	p.close <- struct{}{}
   235  	p.wg.Wait()
   236  	var err error
   237  	for _, err := range p.errors {
   238  		if err != nil {
   239  			return err
   240  		}
   241  	}
   242  	return err
   243  }
   244  
   245  type client struct {
   246  	*ssh.Client
   247  }
   248  
   249  func (c *client) Ping() error {
   250  	session, err := c.NewSession()
   251  	if err != nil {
   252  		return err
   253  	}
   254  	return session.Close()
   255  }
   256  
   257  type Pool2 struct {
   258  	min   int
   259  	max   int
   260  	idle  time.Duration
   261  	pool  map[string]*pool.Pool
   262  	mutex sync.Mutex
   263  }
   264  
   265  func (p *Pool2) key(addr, username, password string) string {
   266  	return fmt.Sprintf("%s@%s:%s", username, addr, password)
   267  }
   268  
   269  func (p *Pool2) get(addr, username, password string) (*ssh.Client, error) {
   270  	var key = p.key(addr, username, password)
   271  
   272  	p.mutex.Lock()
   273  	defer p.mutex.Unlock()
   274  
   275  	if p.pool[key] == nil {
   276  		var factory = func() (entry pool.Entry, err error) {
   277  			cli, err := Client(username, password, addr)
   278  			if err != nil {
   279  				return
   280  			}
   281  			return &client{Client: cli}, nil
   282  		}
   283  
   284  		p.pool[key] = pool.New(p.min, p.max, p.idle, factory)
   285  	}
   286  
   287  	var ctx = context.Background()
   288  
   289  	cli, err := p.pool[key].Get(ctx)
   290  	if err != nil {
   291  		return nil, err
   292  	}
   293  
   294  	return cli.(*client).Client, nil
   295  }
   296  
   297  func (p *Pool2) put(client *client, addr, username, password string) (err error) {
   298  	var key = p.key(addr, username, password)
   299  
   300  	p.mutex.Lock()
   301  	defer p.mutex.Unlock()
   302  
   303  	if err = p.pool[key].Put(client); err != nil {
   304  		return
   305  	}
   306  
   307  	return
   308  }
   309  
   310  func (p *Pool2) Session(addr, username, password string) (session *ssh.Session, err error) {
   311  	cli, err := p.get(addr, username, password)
   312  	if err != nil {
   313  		return
   314  	}
   315  	defer p.put(&client{Client: cli}, addr, username, password)
   316  
   317  	return cli.NewSession()
   318  }
   319  
   320  func (p *Pool2) SftpClient(addr, username, password string) (*sftp.Client, error) {
   321  	cli, err := p.get(addr, username, password)
   322  	if err != nil {
   323  		return nil, err
   324  	}
   325  	defer p.put(&client{Client: cli}, addr, username, password)
   326  
   327  	return sftp.NewClient(cli)
   328  }
   329  
   330  func (p *Pool2) ScpReader(reader io.Reader, remote, password string, fn func(size int)) (err error) {
   331  	user, addr, filename, err := parse(remote)
   332  	if err != nil {
   333  		return
   334  	}
   335  
   336  	client, err := p.SftpClient(addr, user, password)
   337  	if err != nil {
   338  		return
   339  	}
   340  	defer client.Close()
   341  
   342  	return ScpReader(client, filename, newReader(reader, fn))
   343  }
   344  
   345  func (p *Pool2) Scp(local, remote, password string, fn func(current, total int64)) (err error) {
   346  	file, err := os.Open(local)
   347  	if err != nil {
   348  		return
   349  	}
   350  	defer file.Close()
   351  
   352  	stat, err := file.Stat()
   353  	if err != nil {
   354  		return
   355  	}
   356  
   357  	var (
   358  		total   = stat.Size()
   359  		current int64
   360  	)
   361  
   362  	return p.ScpReader(file, remote, password, func(size int) {
   363  		current += int64(size)
   364  		fn(current, total)
   365  	})
   366  }
   367  
   368  func (p *Pool2) Command(remote, password, cmd string) (output string, err error) {
   369  	user, addr, _, err := parse(remote)
   370  	if err != nil {
   371  		return
   372  	}
   373  
   374  	session, err := p.Session(addr, user, password)
   375  	if err != nil {
   376  		return
   377  	}
   378  	defer session.Close()
   379  
   380  	return CommandSession(session, cmd)
   381  }