github.com/zooyer/miskit@v1.0.71/ssh/ssh.go (about)

     1  package ssh
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"net"
     7  	"os"
     8  	"os/user"
     9  	"runtime"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/pkg/sftp"
    15  	"golang.org/x/crypto/ssh"
    16  )
    17  
    18  // 创建ssh客户端
    19  func Client(user, password, addr string) (client *ssh.Client, err error) {
    20  	var config = ssh.ClientConfig{
    21  		User:    user,
    22  		Auth:    []ssh.AuthMethod{ssh.Password(password)},
    23  		Timeout: time.Second * 30,
    24  		//这个是问你要不要验证远程主机,以保证安全性。这里不验证
    25  		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
    26  			return nil
    27  		},
    28  	}
    29  
    30  	if client, err = ssh.Dial("tcp", addr, &config); err != nil {
    31  		return
    32  	}
    33  
    34  	return
    35  }
    36  
    37  // 创建ssh客户端与会话
    38  func Session(user, password, addr string) (client *ssh.Client, session *ssh.Session, err error) {
    39  	if client, err = Client(user, password, addr); err != nil {
    40  		return
    41  	}
    42  
    43  	if session, err = client.NewSession(); err != nil {
    44  		return
    45  	}
    46  
    47  	return
    48  }
    49  
    50  // 创建sftp客户端
    51  func SftpClient(user, password, addr string) (client *sftp.Client, err error) {
    52  	sshClient, err := Client(user, password, addr)
    53  	if err != nil {
    54  		return
    55  	}
    56  
    57  	if client, err = sftp.NewClient(sshClient); err != nil {
    58  		return
    59  	}
    60  
    61  	return
    62  }
    63  
    64  // 数据拷贝
    65  func ScpReader(client *sftp.Client, filename string, reader io.Reader) (err error) {
    66  	if client == nil {
    67  		return errors.New("sftp client is nil")
    68  	}
    69  	if reader == nil {
    70  		return errors.New("sftp reader is nil")
    71  	}
    72  
    73  	file, err := client.Create(filename)
    74  	if err != nil {
    75  		return
    76  	}
    77  	defer file.Close()
    78  
    79  	if _, err = io.Copy(file, reader); err != nil {
    80  		return
    81  	}
    82  
    83  	return
    84  }
    85  
    86  // 文件拷贝
    87  func Scp(local, remote, password string, fn func(current, total int64)) (err error) {
    88  	user, addr, filename, err := parse(remote)
    89  	if err != nil {
    90  		return
    91  	}
    92  
    93  	sftpClient, err := SftpClient(user, password, addr)
    94  	if err != nil {
    95  		return err
    96  	}
    97  	defer sftpClient.Close()
    98  
    99  	src, err := os.Open(local)
   100  	if err != nil {
   101  		return err
   102  	}
   103  	defer src.Close()
   104  
   105  	var total int64
   106  	if fn != nil {
   107  		stat, err := src.Stat()
   108  		if err != nil {
   109  			return err
   110  		}
   111  		total = stat.Size()
   112  	}
   113  
   114  	if err = ScpReader(sftpClient, filename, newReader(src, func(size int) {
   115  		if fn != nil {
   116  			fn(int64(size), total)
   117  		}
   118  	})); err != nil {
   119  
   120  	}
   121  
   122  	return
   123  }
   124  
   125  // 会话执行命令
   126  func CommandSession(session *ssh.Session, cmd string) (output string, err error) {
   127  	data, err := session.CombinedOutput(cmd)
   128  	if err != nil {
   129  		return
   130  	}
   131  	return strings.TrimSpace(string(data)), nil
   132  }
   133  
   134  // 执行命令
   135  func Command(remote, password, cmd string) (output string, err error) {
   136  	user, addr, _, err := parse(remote)
   137  	if err != nil {
   138  		return
   139  	}
   140  
   141  	client, session, err := Session(user, password, addr)
   142  	if err != nil {
   143  		return
   144  	}
   145  	defer client.Close()
   146  	defer session.Close()
   147  
   148  	return CommandSession(session, cmd)
   149  }
   150  
   151  func username() string {
   152  	u, err := user.Current()
   153  	if err != nil {
   154  		return ""
   155  	}
   156  	if runtime.GOOS == "windows" {
   157  		if fields := strings.Split(u.Username, "\\"); len(fields) > 1 {
   158  			return fields[1]
   159  		}
   160  	}
   161  	return u.Username
   162  }
   163  
   164  func parse(target string) (user, addr, filename string, err error) {
   165  	var port = "22"
   166  
   167  	defer func() {
   168  		addr += ":" + port
   169  		if user == "" {
   170  			user = username()
   171  		}
   172  		if filename == "" {
   173  			filename = "/"
   174  		}
   175  	}()
   176  
   177  	if index := strings.Index(target, "@"); index != -1 {
   178  		user = target[:index]
   179  		target = target[index+1:]
   180  	}
   181  
   182  	switch index := strings.Index(target, ":"); strings.Count(target, ":") {
   183  	case 1:
   184  		addr = target[:index]
   185  		if target = target[index+1:]; len(target) == 0 {
   186  			return
   187  		}
   188  		if c := target[0]; c >= '0' && c <= '9' {
   189  			if _, err = strconv.Atoi(target[1:]); err != nil {
   190  				return
   191  			}
   192  			port = target
   193  		} else {
   194  			filename = target
   195  		}
   196  	case 2:
   197  		addr = target[:index]
   198  		target = target[index+1:]
   199  		if index = strings.Index(target, ":"); index != -1 {
   200  			port = target[:index]
   201  			target = target[index+1:]
   202  		}
   203  		filename = target
   204  	default:
   205  		addr = target
   206  	}
   207  
   208  	return
   209  }