github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/agent/client_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package agent
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"errors"
    11  	"io"
    12  	"net"
    13  	"os"
    14  	"os/exec"
    15  	"path/filepath"
    16  	"runtime"
    17  	"strconv"
    18  	"strings"
    19  	"sync"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
    24  )
    25  
    26  // startOpenSSHAgent executes ssh-agent, and returns an Agent interface to it.
    27  func startOpenSSHAgent(t *testing.T) (client ExtendedAgent, socket string, cleanup func()) {
    28  	if testing.Short() {
    29  		// ssh-agent is not always available, and the key
    30  		// types supported vary by platform.
    31  		t.Skip("skipping test due to -short")
    32  	}
    33  
    34  	bin, err := exec.LookPath("ssh-agent")
    35  	if err != nil {
    36  		t.Skip("could not find ssh-agent")
    37  	}
    38  
    39  	cmd := exec.Command(bin, "-s")
    40  	cmd.Env = []string{} // Do not let the user's environment influence ssh-agent behavior.
    41  	cmd.Stderr = new(bytes.Buffer)
    42  	out, err := cmd.Output()
    43  	if err != nil {
    44  		t.Fatalf("%s failed: %v\n%s", strings.Join(cmd.Args, " "), err, cmd.Stderr)
    45  	}
    46  
    47  	// Output looks like:
    48  	//
    49  	//	SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
    50  	//	SSH_AGENT_PID=15542; export SSH_AGENT_PID;
    51  	//	echo Agent pid 15542;
    52  
    53  	fields := bytes.Split(out, []byte(";"))
    54  	line := bytes.SplitN(fields[0], []byte("="), 2)
    55  	line[0] = bytes.TrimLeft(line[0], "\n")
    56  	if string(line[0]) != "SSH_AUTH_SOCK" {
    57  		t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
    58  	}
    59  	socket = string(line[1])
    60  
    61  	line = bytes.SplitN(fields[2], []byte("="), 2)
    62  	line[0] = bytes.TrimLeft(line[0], "\n")
    63  	if string(line[0]) != "SSH_AGENT_PID" {
    64  		t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
    65  	}
    66  	pidStr := line[1]
    67  	pid, err := strconv.Atoi(string(pidStr))
    68  	if err != nil {
    69  		t.Fatalf("Atoi(%q): %v", pidStr, err)
    70  	}
    71  
    72  	conn, err := net.Dial("unix", string(socket))
    73  	if err != nil {
    74  		t.Fatalf("net.Dial: %v", err)
    75  	}
    76  
    77  	ac := NewClient(conn)
    78  	return ac, socket, func() {
    79  		proc, _ := os.FindProcess(pid)
    80  		if proc != nil {
    81  			proc.Kill()
    82  		}
    83  		conn.Close()
    84  		os.RemoveAll(filepath.Dir(socket))
    85  	}
    86  }
    87  
    88  func startAgent(t *testing.T, agent Agent) (client ExtendedAgent, cleanup func()) {
    89  	c1, c2, err := netPipe()
    90  	if err != nil {
    91  		t.Fatalf("netPipe: %v", err)
    92  	}
    93  	go ServeAgent(agent, c2)
    94  
    95  	return NewClient(c1), func() {
    96  		c1.Close()
    97  		c2.Close()
    98  	}
    99  }
   100  
   101  // startKeyringAgent uses Keyring to simulate a ssh-agent Server and returns a client.
   102  func startKeyringAgent(t *testing.T) (client ExtendedAgent, cleanup func()) {
   103  	return startAgent(t, NewKeyring())
   104  }
   105  
   106  func testOpenSSHAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
   107  	agent, _, cleanup := startOpenSSHAgent(t)
   108  	defer cleanup()
   109  
   110  	testAgentInterface(t, agent, key, cert, lifetimeSecs)
   111  }
   112  
   113  func testKeyringAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
   114  	agent, cleanup := startKeyringAgent(t)
   115  	defer cleanup()
   116  
   117  	testAgentInterface(t, agent, key, cert, lifetimeSecs)
   118  }
   119  
   120  func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
   121  	signer, err := ssh.NewSignerFromKey(key)
   122  	if err != nil {
   123  		t.Fatalf("NewSignerFromKey(%T): %v", key, err)
   124  	}
   125  	// The agent should start up empty.
   126  	if keys, err := agent.List(); err != nil {
   127  		t.Fatalf("RequestIdentities: %v", err)
   128  	} else if len(keys) > 0 {
   129  		t.Fatalf("got %d keys, want 0: %v", len(keys), keys)
   130  	}
   131  
   132  	// Attempt to insert the key, with certificate if specified.
   133  	var pubKey ssh.PublicKey
   134  	if cert != nil {
   135  		err = agent.Add(AddedKey{
   136  			PrivateKey:   key,
   137  			Certificate:  cert,
   138  			Comment:      "comment",
   139  			LifetimeSecs: lifetimeSecs,
   140  		})
   141  		pubKey = cert
   142  	} else {
   143  		err = agent.Add(AddedKey{PrivateKey: key, Comment: "comment", LifetimeSecs: lifetimeSecs})
   144  		pubKey = signer.PublicKey()
   145  	}
   146  	if err != nil {
   147  		t.Fatalf("insert(%T): %v", key, err)
   148  	}
   149  
   150  	// Did the key get inserted successfully?
   151  	if keys, err := agent.List(); err != nil {
   152  		t.Fatalf("List: %v", err)
   153  	} else if len(keys) != 1 {
   154  		t.Fatalf("got %v, want 1 key", keys)
   155  	} else if keys[0].Comment != "comment" {
   156  		t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment")
   157  	} else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) {
   158  		t.Fatalf("key mismatch")
   159  	}
   160  
   161  	// Can the agent make a valid signature?
   162  	data := []byte("hello")
   163  	sig, err := agent.Sign(pubKey, data)
   164  	if err != nil {
   165  		t.Fatalf("Sign(%s): %v", pubKey.Type(), err)
   166  	}
   167  
   168  	if err := pubKey.Verify(data, sig); err != nil {
   169  		t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
   170  	}
   171  
   172  	// For tests on RSA keys, try signing with SHA-256 and SHA-512 flags
   173  	if pubKey.Type() == "ssh-rsa" {
   174  		sshFlagTest := func(flag SignatureFlags, expectedSigFormat string) {
   175  			sig, err = agent.SignWithFlags(pubKey, data, flag)
   176  			if err != nil {
   177  				t.Fatalf("SignWithFlags(%s): %v", pubKey.Type(), err)
   178  			}
   179  			if sig.Format != expectedSigFormat {
   180  				t.Fatalf("Signature format didn't match expected value: %s != %s", sig.Format, expectedSigFormat)
   181  			}
   182  			if err := pubKey.Verify(data, sig); err != nil {
   183  				t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
   184  			}
   185  		}
   186  		sshFlagTest(0, ssh.SigAlgoRSA)
   187  		sshFlagTest(SignatureFlagRsaSha256, ssh.SigAlgoRSASHA2256)
   188  		sshFlagTest(SignatureFlagRsaSha512, ssh.SigAlgoRSASHA2512)
   189  	}
   190  
   191  	// If the key has a lifetime, is it removed when it should be?
   192  	if lifetimeSecs > 0 {
   193  		time.Sleep(time.Second*time.Duration(lifetimeSecs) + 100*time.Millisecond)
   194  		keys, err := agent.List()
   195  		if err != nil {
   196  			t.Fatalf("List: %v", err)
   197  		}
   198  		if len(keys) > 0 {
   199  			t.Fatalf("key not expired")
   200  		}
   201  	}
   202  
   203  }
   204  
   205  func TestMalformedRequests(t *testing.T) {
   206  	keyringAgent := NewKeyring()
   207  	listener, err := netListener()
   208  	if err != nil {
   209  		t.Fatalf("netListener: %v", err)
   210  	}
   211  	defer listener.Close()
   212  
   213  	testCase := func(t *testing.T, requestBytes []byte, wantServerErr bool) {
   214  		var wg sync.WaitGroup
   215  		wg.Add(1)
   216  		go func() {
   217  			defer wg.Done()
   218  			c, err := listener.Accept()
   219  			if err != nil {
   220  				t.Errorf("listener.Accept: %v", err)
   221  				return
   222  			}
   223  			defer c.Close()
   224  
   225  			err = ServeAgent(keyringAgent, c)
   226  			if err == nil {
   227  				t.Error("ServeAgent should have returned an error to malformed input")
   228  			} else {
   229  				if (err != io.EOF) != wantServerErr {
   230  					t.Errorf("ServeAgent returned expected error: %v", err)
   231  				}
   232  			}
   233  		}()
   234  
   235  		c, err := net.Dial("tcp", listener.Addr().String())
   236  		if err != nil {
   237  			t.Fatalf("net.Dial: %v", err)
   238  		}
   239  		_, err = c.Write(requestBytes)
   240  		if err != nil {
   241  			t.Errorf("Unexpected error writing raw bytes on connection: %v", err)
   242  		}
   243  		c.Close()
   244  		wg.Wait()
   245  	}
   246  
   247  	var testCases = []struct {
   248  		name          string
   249  		requestBytes  []byte
   250  		wantServerErr bool
   251  	}{
   252  		{"Empty request", []byte{}, false},
   253  		{"Short header", []byte{0x00}, true},
   254  		{"Empty body", []byte{0x00, 0x00, 0x00, 0x00}, true},
   255  		{"Short body", []byte{0x00, 0x00, 0x00, 0x01}, false},
   256  	}
   257  	for _, tc := range testCases {
   258  		t.Run(tc.name, func(t *testing.T) { testCase(t, tc.requestBytes, tc.wantServerErr) })
   259  	}
   260  }
   261  
   262  func TestAgent(t *testing.T) {
   263  	for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
   264  		testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
   265  		testKeyringAgent(t, testPrivateKeys[keyType], nil, 0)
   266  	}
   267  }
   268  
   269  func TestCert(t *testing.T) {
   270  	cert := &ssh.Certificate{
   271  		Key:         testPublicKeys["rsa"],
   272  		ValidBefore: ssh.CertTimeInfinity,
   273  		CertType:    ssh.UserCert,
   274  	}
   275  	cert.SignCert(rand.Reader, testSigners["ecdsa"])
   276  
   277  	testOpenSSHAgent(t, testPrivateKeys["rsa"], cert, 0)
   278  	testKeyringAgent(t, testPrivateKeys["rsa"], cert, 0)
   279  }
   280  
   281  // netListener creates a localhost network listener.
   282  func netListener() (net.Listener, error) {
   283  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   284  	if err != nil {
   285  		listener, err = net.Listen("tcp", "[::1]:0")
   286  		if err != nil {
   287  			return nil, err
   288  		}
   289  	}
   290  	return listener, nil
   291  }
   292  
   293  // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
   294  // therefore is buffered (net.Pipe deadlocks if both sides start with
   295  // a write.)
   296  func netPipe() (net.Conn, net.Conn, error) {
   297  	listener, err := netListener()
   298  	if err != nil {
   299  		return nil, nil, err
   300  	}
   301  	defer listener.Close()
   302  	c1, err := net.Dial("tcp", listener.Addr().String())
   303  	if err != nil {
   304  		return nil, nil, err
   305  	}
   306  
   307  	c2, err := listener.Accept()
   308  	if err != nil {
   309  		c1.Close()
   310  		return nil, nil, err
   311  	}
   312  
   313  	return c1, c2, nil
   314  }
   315  
   316  func TestServerResponseTooLarge(t *testing.T) {
   317  	a, b, err := netPipe()
   318  	if err != nil {
   319  		t.Fatalf("netPipe: %v", err)
   320  	}
   321  	done := make(chan struct{})
   322  	defer func() { <-done }()
   323  
   324  	defer a.Close()
   325  	defer b.Close()
   326  
   327  	var response identitiesAnswerAgentMsg
   328  	response.NumKeys = 1
   329  	response.Keys = make([]byte, maxAgentResponseBytes+1)
   330  
   331  	agent := NewClient(a)
   332  	go func() {
   333  		defer close(done)
   334  		n, err := b.Write(ssh.Marshal(response))
   335  		if n < 4 {
   336  			if runtime.GOOS == "plan9" {
   337  				if e1, ok := err.(*net.OpError); ok {
   338  					if e2, ok := e1.Err.(*os.PathError); ok {
   339  						switch e2.Err.Error() {
   340  						case "Hangup", "i/o on hungup channel":
   341  							// syscall.Pwrite returns -1 in this case even when some data did get written.
   342  							return
   343  						}
   344  					}
   345  				}
   346  			}
   347  			t.Errorf("At least 4 bytes (the response size) should have been successfully written: %d < 4: %v", n, err)
   348  		}
   349  	}()
   350  	_, err = agent.List()
   351  	if err == nil {
   352  		t.Fatal("Did not get error result")
   353  	}
   354  	if err.Error() != "agent: client error: response too large" {
   355  		t.Fatal("Did not get expected error result")
   356  	}
   357  }
   358  
   359  func TestAuth(t *testing.T) {
   360  	agent, _, cleanup := startOpenSSHAgent(t)
   361  	defer cleanup()
   362  
   363  	a, b, err := netPipe()
   364  	if err != nil {
   365  		t.Fatalf("netPipe: %v", err)
   366  	}
   367  
   368  	defer a.Close()
   369  	defer b.Close()
   370  
   371  	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil {
   372  		t.Errorf("Add: %v", err)
   373  	}
   374  
   375  	serverConf := ssh.ServerConfig{}
   376  	serverConf.AddHostKey(testSigners["rsa"])
   377  	serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
   378  		if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
   379  			return nil, nil
   380  		}
   381  
   382  		return nil, errors.New("pubkey rejected")
   383  	}
   384  
   385  	go func() {
   386  		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
   387  		if err != nil {
   388  			t.Fatalf("Server: %v", err)
   389  		}
   390  		conn.Close()
   391  	}()
   392  
   393  	conf := ssh.ClientConfig{
   394  		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
   395  	}
   396  	conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers))
   397  	conn, _, _, err := ssh.NewClientConn(b, "", &conf)
   398  	if err != nil {
   399  		t.Fatalf("NewClientConn: %v", err)
   400  	}
   401  	conn.Close()
   402  }
   403  
   404  func TestLockOpenSSHAgent(t *testing.T) {
   405  	agent, _, cleanup := startOpenSSHAgent(t)
   406  	defer cleanup()
   407  	testLockAgent(agent, t)
   408  }
   409  
   410  func TestLockKeyringAgent(t *testing.T) {
   411  	agent, cleanup := startKeyringAgent(t)
   412  	defer cleanup()
   413  	testLockAgent(agent, t)
   414  }
   415  
   416  func testLockAgent(agent Agent, t *testing.T) {
   417  	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil {
   418  		t.Errorf("Add: %v", err)
   419  	}
   420  	if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil {
   421  		t.Errorf("Add: %v", err)
   422  	}
   423  	if keys, err := agent.List(); err != nil {
   424  		t.Errorf("List: %v", err)
   425  	} else if len(keys) != 2 {
   426  		t.Errorf("Want 2 keys, got %v", keys)
   427  	}
   428  
   429  	passphrase := []byte("secret")
   430  	if err := agent.Lock(passphrase); err != nil {
   431  		t.Errorf("Lock: %v", err)
   432  	}
   433  
   434  	if keys, err := agent.List(); err != nil {
   435  		t.Errorf("List: %v", err)
   436  	} else if len(keys) != 0 {
   437  		t.Errorf("Want 0 keys, got %v", keys)
   438  	}
   439  
   440  	signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"])
   441  	if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil {
   442  		t.Fatalf("Sign did not fail")
   443  	}
   444  
   445  	if err := agent.Remove(signer.PublicKey()); err == nil {
   446  		t.Fatalf("Remove did not fail")
   447  	}
   448  
   449  	if err := agent.RemoveAll(); err == nil {
   450  		t.Fatalf("RemoveAll did not fail")
   451  	}
   452  
   453  	if err := agent.Unlock(nil); err == nil {
   454  		t.Errorf("Unlock with wrong passphrase succeeded")
   455  	}
   456  	if err := agent.Unlock(passphrase); err != nil {
   457  		t.Errorf("Unlock: %v", err)
   458  	}
   459  
   460  	if err := agent.Remove(signer.PublicKey()); err != nil {
   461  		t.Fatalf("Remove: %v", err)
   462  	}
   463  
   464  	if keys, err := agent.List(); err != nil {
   465  		t.Errorf("List: %v", err)
   466  	} else if len(keys) != 1 {
   467  		t.Errorf("Want 1 keys, got %v", keys)
   468  	}
   469  }
   470  
   471  func testOpenSSHAgentLifetime(t *testing.T) {
   472  	agent, _, cleanup := startOpenSSHAgent(t)
   473  	defer cleanup()
   474  	testAgentLifetime(t, agent)
   475  }
   476  
   477  func testKeyringAgentLifetime(t *testing.T) {
   478  	agent, cleanup := startKeyringAgent(t)
   479  	defer cleanup()
   480  	testAgentLifetime(t, agent)
   481  }
   482  
   483  func testAgentLifetime(t *testing.T, agent Agent) {
   484  	for _, keyType := range []string{"rsa", "dsa", "ecdsa"} {
   485  		// Add private keys to the agent.
   486  		err := agent.Add(AddedKey{
   487  			PrivateKey:   testPrivateKeys[keyType],
   488  			Comment:      "comment",
   489  			LifetimeSecs: 1,
   490  		})
   491  		if err != nil {
   492  			t.Fatalf("add: %v", err)
   493  		}
   494  		// Add certs to the agent.
   495  		cert := &ssh.Certificate{
   496  			Key:         testPublicKeys[keyType],
   497  			ValidBefore: ssh.CertTimeInfinity,
   498  			CertType:    ssh.UserCert,
   499  		}
   500  		cert.SignCert(rand.Reader, testSigners[keyType])
   501  		err = agent.Add(AddedKey{
   502  			PrivateKey:   testPrivateKeys[keyType],
   503  			Certificate:  cert,
   504  			Comment:      "comment",
   505  			LifetimeSecs: 1,
   506  		})
   507  		if err != nil {
   508  			t.Fatalf("add: %v", err)
   509  		}
   510  	}
   511  	time.Sleep(1100 * time.Millisecond)
   512  	if keys, err := agent.List(); err != nil {
   513  		t.Errorf("List: %v", err)
   514  	} else if len(keys) != 0 {
   515  		t.Errorf("Want 0 keys, got %v", len(keys))
   516  	}
   517  }
   518  
   519  type keyringExtended struct {
   520  	*keyring
   521  }
   522  
   523  func (r *keyringExtended) Extension(extensionType string, contents []byte) ([]byte, error) {
   524  	if extensionType != "my-extension@example.com" {
   525  		return []byte{agentExtensionFailure}, nil
   526  	}
   527  	return append([]byte{agentSuccess}, contents...), nil
   528  }
   529  
   530  func TestAgentExtensions(t *testing.T) {
   531  	agent, _, cleanup := startOpenSSHAgent(t)
   532  	defer cleanup()
   533  	_, err := agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02})
   534  	if err == nil {
   535  		t.Fatal("should have gotten agent extension failure")
   536  	}
   537  
   538  	agent, cleanup = startAgent(t, &keyringExtended{})
   539  	defer cleanup()
   540  	result, err := agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02})
   541  	if err != nil {
   542  		t.Fatalf("agent extension failure: %v", err)
   543  	}
   544  	if len(result) != 4 || !bytes.Equal(result, []byte{agentSuccess, 0x00, 0x01, 0x02}) {
   545  		t.Fatalf("agent extension result invalid: %v", result)
   546  	}
   547  
   548  	_, err = agent.Extension("bad-extension@example.com", []byte{0x00, 0x01, 0x02})
   549  	if err == nil {
   550  		t.Fatal("should have gotten agent extension failure")
   551  	}
   552  }