github.com/glycerine/xcryptossh@v7.0.4+incompatible/test/test_unix_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build darwin dragonfly freebsd linux netbsd openbsd plan9
     6  
     7  package test
     8  
     9  // functional test harness for unix.
    10  
    11  import (
    12  	"bytes"
    13  	"context"
    14  	"fmt"
    15  	"io/ioutil"
    16  	"log"
    17  	"net"
    18  	"os"
    19  	"os/exec"
    20  	"os/user"
    21  	"path/filepath"
    22  	"testing"
    23  	"text/template"
    24  
    25  	"github.com/glycerine/xcryptossh"
    26  	"github.com/glycerine/xcryptossh/testdata"
    27  )
    28  
// sshd_config is the text/template source for the configuration file handed
// to the sshd process under test. The {{.Dir}} placeholder is filled in with
// the temporary directory that holds the generated host keys, certificates,
// authorized_keys file and the rendered config itself.
const sshd_config = `
Protocol 2
HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa
HostCertificate {{.Dir}}/id_rsa-cert.pub
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
#deprecated: KeyRegenerationInterval 3600
#deprecated: ServerKeyBits 768
SyslogFacility AUTH
LogLevel DEBUG2
LoginGraceTime 120
PermitRootLogin no
StrictModes no
#deprecated: RSAAuthentication yes
PubkeyAuthentication yes
AuthorizedKeysFile	{{.Dir}}/authorized_keys
TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
IgnoreRhosts yes
#deprecated: RhostsRSAAuthentication no
HostbasedAuthentication no
PubkeyAcceptedKeyTypes=*
`
    53  
// configTmpl is the parsed form of sshd_config. template.Must panics at
// package initialization if the template text fails to parse.
var configTmpl = template.Must(template.New("").Parse(sshd_config))
    55  
// server bundles the state needed to run one sshd process for a test: the
// owning test, the rendered config file, the running command and its captured
// stderr, plus the client side of the socket connected to sshd.
type server struct {
	t          *testing.T   // test that owns this server; used for logging and failure reporting
	cleanup    func()       // executed during Shutdown
	configfile string       // path of the sshd config file passed to sshd via -f
	cmd        *exec.Cmd    // sshd process started by TryDialWithAddr
	output     bytes.Buffer // holds stderr from sshd process

	// Client half of the network connection. Set by TryDialWithAddr.
	clientConn net.Conn
}
    66  
    67  func username() string {
    68  	var username string
    69  	if user, err := user.Current(); err == nil {
    70  		username = user.Username
    71  	} else {
    72  		// user.Current() currently requires cgo. If an error is
    73  		// returned attempt to get the username from the environment.
    74  		log.Printf("user.Current: %v; falling back on $USER", err)
    75  		username = os.Getenv("USER")
    76  	}
    77  	if username == "" {
    78  		panic("Unable to get username")
    79  	}
    80  	return username
    81  }
    82  
// storedHostKey is an in-memory set of known host keys, indexed by key
// algorithm, together with a counter of how many times Check was called.
type storedHostKey struct {
	// keys map from an algorithm string to binary key data.
	// It is allocated lazily by Add.
	keys map[string][]byte

	// checkCount counts the Check calls. Used for testing
	// rekeying.
	checkCount int
}
    91  
    92  func (k *storedHostKey) Add(key ssh.PublicKey) {
    93  	if k.keys == nil {
    94  		k.keys = map[string][]byte{}
    95  	}
    96  	k.keys[key.Type()] = key.Marshal()
    97  }
    98  
    99  func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
   100  	k.checkCount++
   101  	algo := key.Type()
   102  
   103  	if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
   104  		return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
   105  	}
   106  	return nil
   107  }
   108  
   109  func hostKeyDB() *storedHostKey {
   110  	keyChecker := &storedHostKey{}
   111  	keyChecker.Add(testPublicKeys["ecdsa"])
   112  	keyChecker.Add(testPublicKeys["rsa"])
   113  	keyChecker.Add(testPublicKeys["dsa"])
   114  	return keyChecker
   115  }
   116  
   117  func clientConfig(halt *ssh.Halter) *ssh.ClientConfig {
   118  	config := &ssh.ClientConfig{
   119  		User: username(),
   120  		Auth: []ssh.AuthMethod{
   121  			ssh.PublicKeys(testSigners["user"]),
   122  		},
   123  		HostKeyCallback: hostKeyDB().Check,
   124  		HostKeyAlgorithms: []string{ // by default, don't allow certs as this affects the hostKeyDB checker
   125  			ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521,
   126  			ssh.KeyAlgoRSA, ssh.KeyAlgoDSA,
   127  			ssh.KeyAlgoED25519,
   128  		},
   129  		Config: ssh.Config{Halt: halt},
   130  	}
   131  	return config
   132  }
   133  
   134  // unixConnection creates two halves of a connected net.UnixConn.  It
   135  // is used for connecting the Go SSH client with sshd without opening
   136  // ports.
   137  func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
   138  	dir, err := ioutil.TempDir("", "unixConnection")
   139  	if err != nil {
   140  		return nil, nil, err
   141  	}
   142  	defer os.Remove(dir)
   143  
   144  	addr := filepath.Join(dir, "ssh")
   145  	listener, err := net.Listen("unix", addr)
   146  	if err != nil {
   147  		return nil, nil, err
   148  	}
   149  	defer listener.Close()
   150  	c1, err := net.Dial("unix", addr)
   151  	if err != nil {
   152  		return nil, nil, err
   153  	}
   154  
   155  	c2, err := listener.Accept()
   156  	if err != nil {
   157  		c1.Close()
   158  		return nil, nil, err
   159  	}
   160  
   161  	return c1.(*net.UnixConn), c2.(*net.UnixConn), nil
   162  }
   163  
   164  func (s *server) TryDial(ctx context.Context, config *ssh.ClientConfig) (*ssh.Client, error) {
   165  	return s.TryDialWithAddr(ctx, config, "")
   166  }
   167  
   168  // addr is the user specified host:port. While we don't actually dial it,
   169  // we need to know this for host key matching
   170  func (s *server) TryDialWithAddr(ctx context.Context, config *ssh.ClientConfig, addr string) (*ssh.Client, error) {
   171  	sshd, err := exec.LookPath("sshd")
   172  	if err != nil {
   173  		s.t.Skipf("skipping test: %v", err)
   174  	}
   175  
   176  	c1, c2, err := unixConnection()
   177  	if err != nil {
   178  		s.t.Fatalf("unixConnection: %v", err)
   179  	}
   180  
   181  	s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
   182  	f, err := c2.File()
   183  	if err != nil {
   184  		s.t.Fatalf("UnixConn.File: %v", err)
   185  	}
   186  	defer f.Close()
   187  	s.cmd.Stdin = f
   188  	s.cmd.Stdout = f
   189  	s.cmd.Stderr = &s.output
   190  	if err := s.cmd.Start(); err != nil {
   191  		s.t.Fail()
   192  		s.Shutdown()
   193  		s.t.Fatalf("s.cmd.Start: %v", err)
   194  	}
   195  	s.clientConn = c1
   196  	conn, chans, reqs, err := ssh.NewClientConn(ctx, c1, addr, config)
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  	return ssh.NewClient(ctx, conn, chans, reqs, config.Halt), nil
   201  }
   202  
   203  func (s *server) Dial(ctx context.Context, config *ssh.ClientConfig) *ssh.Client {
   204  	conn, err := s.TryDial(ctx, config)
   205  	if err != nil {
   206  		s.t.Fail()
   207  		s.Shutdown()
   208  		s.t.Fatalf("ssh.Client: %v", err)
   209  	}
   210  	return conn
   211  }
   212  
   213  func (s *server) Shutdown() {
   214  	if s.cmd != nil && s.cmd.Process != nil {
   215  		// Don't check for errors; if it fails it's most
   216  		// likely "os: process already finished", and we don't
   217  		// care about that. Use os.Interrupt, so child
   218  		// processes are killed too.
   219  		s.cmd.Process.Signal(os.Interrupt)
   220  		s.cmd.Wait()
   221  	}
   222  	if s.t.Failed() {
   223  		// log any output from sshd process
   224  		s.t.Logf("sshd: %s", s.output.String())
   225  	}
   226  	s.cleanup()
   227  }
   228  
   229  func writeFile(path string, contents []byte) {
   230  	f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
   231  	if err != nil {
   232  		panic(err)
   233  	}
   234  	defer f.Close()
   235  	if _, err := f.Write(contents); err != nil {
   236  		panic(err)
   237  	}
   238  }
   239  
   240  // newServer returns a new mock ssh server.
   241  func newServer(t *testing.T) *server {
   242  	if testing.Short() {
   243  		t.Skip("skipping test due to -short")
   244  	}
   245  	dir, err := ioutil.TempDir("", "sshtest")
   246  	if err != nil {
   247  		t.Fatal(err)
   248  	}
   249  	f, err := os.Create(filepath.Join(dir, "sshd_config"))
   250  	if err != nil {
   251  		t.Fatal(err)
   252  	}
   253  	err = configTmpl.Execute(f, map[string]string{
   254  		"Dir": dir,
   255  	})
   256  	if err != nil {
   257  		t.Fatal(err)
   258  	}
   259  	f.Close()
   260  
   261  	for k, v := range testdata.PEMBytes {
   262  		filename := "id_" + k
   263  		writeFile(filepath.Join(dir, filename), v)
   264  		writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
   265  	}
   266  
   267  	for k, v := range testdata.SSHCertificates {
   268  		filename := "id_" + k + "-cert.pub"
   269  		writeFile(filepath.Join(dir, filename), v)
   270  	}
   271  
   272  	var authkeys bytes.Buffer
   273  	for k, _ := range testdata.PEMBytes {
   274  		authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
   275  	}
   276  	writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())
   277  
   278  	return &server{
   279  		t:          t,
   280  		configfile: f.Name(),
   281  		cleanup: func() {
   282  			if err := os.RemoveAll(dir); err != nil {
   283  				t.Error(err)
   284  			}
   285  		},
   286  	}
   287  }
   288  
   289  func newTempSocket(t *testing.T) (string, func()) {
   290  	dir, err := ioutil.TempDir("", "socket")
   291  	if err != nil {
   292  		t.Fatal(err)
   293  	}
   294  	deferFunc := func() { os.RemoveAll(dir) }
   295  	addr := filepath.Join(dir, "sock")
   296  	return addr, deferFunc
   297  }