github.com/digineo/afero@v1.1.1/sftpfs/sftp_test_go (about)

     1  // Copyright © 2015 Jerry Jacobs <jerry.jacobs@xor-gate.org>.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package afero
    15  
    16  import (
    17  	"testing"
    18  	"os"
    19  	"log"
    20  	"fmt"
    21  	"net"
    22  	"flag"
    23  	"time"
    24  	"io/ioutil"
    25  	"crypto/rsa"
    26  	_rand "crypto/rand"
    27  	"encoding/pem"
    28  	"crypto/x509"
    29  
    30  	"golang.org/x/crypto/ssh"
    31  	"github.com/pkg/sftp"
    32  )
    33  
    34  type SftpFsContext struct {
    35  	sshc   *ssh.Client
    36  	sshcfg *ssh.ClientConfig
    37  	sftpc  *sftp.Client
    38  }
    39  
    40  // TODO we only connect with hardcoded user+pass for now
    41  // it should be possible to use $HOME/.ssh/id_rsa to login into the stub sftp server
    42  func SftpConnect(user, password, host string) (*SftpFsContext, error) {
    43  /*
    44  	pemBytes, err := ioutil.ReadFile(os.Getenv("HOME") + "/.ssh/id_rsa")
    45  	if err != nil {
    46  		return nil,err
    47  	}
    48  
    49  	signer, err := ssh.ParsePrivateKey(pemBytes)
    50  	if err != nil {
    51  		return nil,err
    52  	}
    53  
    54  	sshcfg := &ssh.ClientConfig{
    55  		User: user,
    56  		Auth: []ssh.AuthMethod{
    57  			ssh.Password(password),
    58  			ssh.PublicKeys(signer),
    59  		},
    60  	}
    61  */
    62  
    63  	sshcfg := &ssh.ClientConfig{
    64  		User: user,
    65  		Auth: []ssh.AuthMethod{
    66  			ssh.Password(password),
    67  		},
    68  	}
    69  
    70  	sshc, err := ssh.Dial("tcp", host, sshcfg)
    71  	if err != nil {
    72  		return nil,err
    73  	}
    74  
    75  	sftpc, err := sftp.NewClient(sshc)
    76  	if err != nil {
    77  		return nil,err
    78  	}
    79  
    80  	ctx := &SftpFsContext{
    81  		sshc: sshc,
    82  		sshcfg: sshcfg,
    83  		sftpc: sftpc,
    84  	}
    85  
    86  	return ctx,nil
    87  }
    88  
    89  func (ctx *SftpFsContext) Disconnect() error {
    90  	ctx.sftpc.Close()
    91  	ctx.sshc.Close()
    92  	return nil
    93  }
    94  
    95  // TODO for such a weird reason rootpath is "." when writing "file1" with afero sftp backend
    96  func RunSftpServer(rootpath string) {
    97  	var (
    98  		readOnly      bool
    99  		debugLevelStr string
   100  		debugLevel    int
   101  		debugStderr   bool
   102  		rootDir       string
   103  	)
   104  
   105  	flag.BoolVar(&readOnly, "R", false, "read-only server")
   106  	flag.BoolVar(&debugStderr, "e", true, "debug to stderr")
   107  	flag.StringVar(&debugLevelStr, "l", "none", "debug level")
   108  	flag.StringVar(&rootDir, "root", rootpath, "root directory")
   109  	flag.Parse()
   110  
   111  	debugStream := ioutil.Discard
   112  	if debugStderr {
   113  		debugStream = os.Stderr
   114  		debugLevel = 1
   115  	}
   116  
   117  	// An SSH server is represented by a ServerConfig, which holds
   118  	// certificate details and handles authentication of ServerConns.
   119  	config := &ssh.ServerConfig{
   120  		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
   121  			// Should use constant-time compare (or better, salt+hash) in
   122  			// a production setting.
   123  			fmt.Fprintf(debugStream, "Login: %s\n", c.User())
   124  			if c.User() == "test" && string(pass) == "test" {
   125  				return nil, nil
   126  			}
   127  			return nil, fmt.Errorf("password rejected for %q", c.User())
   128  		},
   129  	}
   130  
   131  	privateBytes, err := ioutil.ReadFile("./test/id_rsa")
   132  	if err != nil {
   133  		log.Fatal("Failed to load private key", err)
   134  	}
   135  
   136  	private, err := ssh.ParsePrivateKey(privateBytes)
   137  	if err != nil {
   138  		log.Fatal("Failed to parse private key", err)
   139  	}
   140  
   141  	config.AddHostKey(private)
   142  
   143  	// Once a ServerConfig has been configured, connections can be
   144  	// accepted.
   145  	listener, err := net.Listen("tcp", "0.0.0.0:2022")
   146  	if err != nil {
   147  		log.Fatal("failed to listen for connection", err)
   148  	}
   149  	fmt.Printf("Listening on %v\n", listener.Addr())
   150  
   151  	nConn, err := listener.Accept()
   152  	if err != nil {
   153  		log.Fatal("failed to accept incoming connection", err)
   154  	}
   155  
   156  	// Before use, a handshake must be performed on the incoming
   157  	// net.Conn.
   158  	_, chans, reqs, err := ssh.NewServerConn(nConn, config)
   159  	if err != nil {
   160  		log.Fatal("failed to handshake", err)
   161  	}
   162  	fmt.Fprintf(debugStream, "SSH server established\n")
   163  
   164  	// The incoming Request channel must be serviced.
   165  	go ssh.DiscardRequests(reqs)
   166  
   167  	// Service the incoming Channel channel.
   168  	for newChannel := range chans {
   169  		// Channels have a type, depending on the application level
   170  		// protocol intended. In the case of an SFTP session, this is "subsystem"
   171  		// with a payload string of "<length=4>sftp"
   172  		fmt.Fprintf(debugStream, "Incoming channel: %s\n", newChannel.ChannelType())
   173  		if newChannel.ChannelType() != "session" {
   174  			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
   175  			fmt.Fprintf(debugStream, "Unknown channel type: %s\n", newChannel.ChannelType())
   176  			continue
   177  		}
   178  		channel, requests, err := newChannel.Accept()
   179  		if err != nil {
   180  			log.Fatal("could not accept channel.", err)
   181  		}
   182  		fmt.Fprintf(debugStream, "Channel accepted\n")
   183  
   184  		// Sessions have out-of-band requests such as "shell",
   185  		// "pty-req" and "env".  Here we handle only the
   186  		// "subsystem" request.
   187  		go func(in <-chan *ssh.Request) {
   188  			for req := range in {
   189  				fmt.Fprintf(debugStream, "Request: %v\n", req.Type)
   190  				ok := false
   191  				switch req.Type {
   192  				case "subsystem":
   193  					fmt.Fprintf(debugStream, "Subsystem: %s\n", req.Payload[4:])
   194  					if string(req.Payload[4:]) == "sftp" {
   195  						ok = true
   196  					}
   197  				}
   198  				fmt.Fprintf(debugStream, " - accepted: %v\n", ok)
   199  				req.Reply(ok, nil)
   200  			}
   201  		}(requests)
   202  
   203  		server, err := sftp.NewServer(channel, channel, debugStream, debugLevel, readOnly, rootpath)
   204  		if err != nil {
   205  			log.Fatal(err)
   206  		}
   207  		if err := server.Serve(); err != nil {
   208  			log.Fatal("sftp server completed with error:", err)
   209  		}
   210  	}
   211  }
   212  
   213  // MakeSSHKeyPair make a pair of public and private keys for SSH access.
   214  // Public key is encoded in the format for inclusion in an OpenSSH authorized_keys file.
   215  // Private Key generated is PEM encoded
   216  func MakeSSHKeyPair(bits int, pubKeyPath, privateKeyPath string) error {
   217  	privateKey, err := rsa.GenerateKey(_rand.Reader, bits)
   218  	if err != nil {
   219  		return err
   220  	}
   221  
   222  	// generate and write private key as PEM
   223  	privateKeyFile, err := os.Create(privateKeyPath)
   224  	defer privateKeyFile.Close()
   225  	if err != nil {
   226  		return err
   227  	}
   228  
   229  	privateKeyPEM := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}
   230  	if err := pem.Encode(privateKeyFile, privateKeyPEM); err != nil {
   231  		return err
   232  	}
   233  
   234  	// generate and write public key
   235  	pub, err := ssh.NewPublicKey(&privateKey.PublicKey)
   236  	if err != nil {
   237  		return err
   238  	}
   239  
   240  	return ioutil.WriteFile(pubKeyPath, ssh.MarshalAuthorizedKey(pub), 0655)
   241  }
   242  
   243  func TestSftpCreate(t *testing.T) {
   244  	os.Mkdir("./test", 0777)
   245  	MakeSSHKeyPair(1024, "./test/id_rsa.pub", "./test/id_rsa")
   246  
   247  	go RunSftpServer("./test/")
   248  	time.Sleep(5 * time.Second)
   249  
   250  	ctx, err := SftpConnect("test", "test", "localhost:2022")
   251  	if err != nil {
   252  		t.Fatal(err)
   253  	}
   254  	defer ctx.Disconnect()
   255  
   256  	var AppFs Fs = SftpFs{
   257  		SftpClient: ctx.sftpc,
   258  	}
   259  
   260  	AppFs.MkdirAll("test/dir1/dir2/dir3", os.FileMode(0777))
   261  	AppFs.Mkdir("test/foo", os.FileMode(0000))
   262  	AppFs.Chmod("test/foo", os.FileMode(0700))
   263  	AppFs.Mkdir("test/bar", os.FileMode(0777))
   264  
   265  	file, err := AppFs.Create("file1")
   266  	if err != nil {
   267  		t.Error(err)
   268  	}
   269  	defer file.Close()
   270  
   271  	file.Write([]byte("hello\t"))
   272  	file.WriteString("world!\n")
   273  
   274  	f1, err := AppFs.Open("file1")
   275  	if err != nil {
   276  		log.Fatalf("open: %v", err)
   277  	}
   278  	defer f1.Close()
   279  
   280  	b := make([]byte, 100)
   281  
   282  	_, err = f1.Read(b)
   283  	fmt.Println(string(b))
   284  
   285  	// TODO check here if "hello\tworld\n" is in buffer b
   286  }