github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/agent/server_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package agent
     6  
     7  import (
     8  	"crypto"
     9  	"crypto/rand"
    10  	"fmt"
    11  	pseudorand "math/rand"
    12  	"reflect"
    13  	"strings"
    14  	"testing"
    15  
    16  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
    17  )
    18  
    19  func TestServer(t *testing.T) {
    20  	c1, c2, err := netPipe()
    21  	if err != nil {
    22  		t.Fatalf("netPipe: %v", err)
    23  	}
    24  	defer c1.Close()
    25  	defer c2.Close()
    26  	client := NewClient(c1)
    27  
    28  	go ServeAgent(NewKeyring(), c2)
    29  
    30  	testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0)
    31  }
    32  
    33  func TestLockServer(t *testing.T) {
    34  	testLockAgent(NewKeyring(), t)
    35  }
    36  
    37  func TestSetupForwardAgent(t *testing.T) {
    38  	a, b, err := netPipe()
    39  	if err != nil {
    40  		t.Fatalf("netPipe: %v", err)
    41  	}
    42  
    43  	defer a.Close()
    44  	defer b.Close()
    45  
    46  	_, socket, cleanup := startOpenSSHAgent(t)
    47  	defer cleanup()
    48  
    49  	serverConf := ssh.ServerConfig{
    50  		NoClientAuth: true,
    51  	}
    52  	serverConf.AddHostKey(testSigners["rsa"])
    53  	incoming := make(chan *ssh.ServerConn, 1)
    54  	go func() {
    55  		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
    56  		if err != nil {
    57  			t.Fatalf("Server: %v", err)
    58  		}
    59  		incoming <- conn
    60  	}()
    61  
    62  	conf := ssh.ClientConfig{
    63  		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
    64  	}
    65  	conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
    66  	if err != nil {
    67  		t.Fatalf("NewClientConn: %v", err)
    68  	}
    69  	client := ssh.NewClient(conn, chans, reqs)
    70  
    71  	if err := ForwardToRemote(client, socket); err != nil {
    72  		t.Fatalf("SetupForwardAgent: %v", err)
    73  	}
    74  
    75  	server := <-incoming
    76  	ch, reqs, err := server.OpenChannel(channelType, nil)
    77  	if err != nil {
    78  		t.Fatalf("OpenChannel(%q): %v", channelType, err)
    79  	}
    80  	go ssh.DiscardRequests(reqs)
    81  
    82  	agentClient := NewClient(ch)
    83  	testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0)
    84  	conn.Close()
    85  }
    86  
    87  func TestV1ProtocolMessages(t *testing.T) {
    88  	c1, c2, err := netPipe()
    89  	if err != nil {
    90  		t.Fatalf("netPipe: %v", err)
    91  	}
    92  	defer c1.Close()
    93  	defer c2.Close()
    94  	c := NewClient(c1)
    95  
    96  	go ServeAgent(NewKeyring(), c2)
    97  
    98  	testV1ProtocolMessages(t, c.(*client))
    99  }
   100  
   101  func testV1ProtocolMessages(t *testing.T, c *client) {
   102  	reply, err := c.call([]byte{agentRequestV1Identities})
   103  	if err != nil {
   104  		t.Fatalf("v1 request all failed: %v", err)
   105  	}
   106  	if msg, ok := reply.(*agentV1IdentityMsg); !ok || msg.Numkeys != 0 {
   107  		t.Fatalf("invalid request all response: %#v", reply)
   108  	}
   109  
   110  	reply, err = c.call([]byte{agentRemoveAllV1Identities})
   111  	if err != nil {
   112  		t.Fatalf("v1 remove all failed: %v", err)
   113  	}
   114  	if _, ok := reply.(*successAgentMsg); !ok {
   115  		t.Fatalf("invalid remove all response: %#v", reply)
   116  	}
   117  }
   118  
   119  func verifyKey(sshAgent Agent) error {
   120  	keys, err := sshAgent.List()
   121  	if err != nil {
   122  		return fmt.Errorf("listing keys: %v", err)
   123  	}
   124  
   125  	if len(keys) != 1 {
   126  		return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys))
   127  	}
   128  
   129  	buf := make([]byte, 128)
   130  	if _, err := rand.Read(buf); err != nil {
   131  		return fmt.Errorf("rand: %v", err)
   132  	}
   133  
   134  	sig, err := sshAgent.Sign(keys[0], buf)
   135  	if err != nil {
   136  		return fmt.Errorf("sign: %v", err)
   137  	}
   138  
   139  	if err := keys[0].Verify(buf, sig); err != nil {
   140  		return fmt.Errorf("verify: %v", err)
   141  	}
   142  	return nil
   143  }
   144  
   145  func addKeyToAgent(key crypto.PrivateKey) error {
   146  	sshAgent := NewKeyring()
   147  	if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil {
   148  		return fmt.Errorf("add: %v", err)
   149  	}
   150  	return verifyKey(sshAgent)
   151  }
   152  
   153  func TestKeyTypes(t *testing.T) {
   154  	for k, v := range testPrivateKeys {
   155  		if err := addKeyToAgent(v); err != nil {
   156  			t.Errorf("error adding key type %s, %v", k, err)
   157  		}
   158  		if err := addCertToAgentSock(v, nil); err != nil {
   159  			t.Errorf("error adding key type %s, %v", k, err)
   160  		}
   161  	}
   162  }
   163  
   164  func addCertToAgentSock(key crypto.PrivateKey, cert *ssh.Certificate) error {
   165  	a, b, err := netPipe()
   166  	if err != nil {
   167  		return err
   168  	}
   169  	agentServer := NewKeyring()
   170  	go ServeAgent(agentServer, a)
   171  
   172  	agentClient := NewClient(b)
   173  	if err := agentClient.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
   174  		return fmt.Errorf("add: %v", err)
   175  	}
   176  	return verifyKey(agentClient)
   177  }
   178  
   179  func addCertToAgent(key crypto.PrivateKey, cert *ssh.Certificate) error {
   180  	sshAgent := NewKeyring()
   181  	if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
   182  		return fmt.Errorf("add: %v", err)
   183  	}
   184  	return verifyKey(sshAgent)
   185  }
   186  
   187  func TestCertTypes(t *testing.T) {
   188  	for keyType, key := range testPublicKeys {
   189  		cert := &ssh.Certificate{
   190  			ValidPrincipals: []string{"gopher1"},
   191  			ValidAfter:      0,
   192  			ValidBefore:     ssh.CertTimeInfinity,
   193  			Key:             key,
   194  			Serial:          1,
   195  			CertType:        ssh.UserCert,
   196  			SignatureKey:    testPublicKeys["rsa"],
   197  			Permissions: ssh.Permissions{
   198  				CriticalOptions: map[string]string{},
   199  				Extensions:      map[string]string{},
   200  			},
   201  		}
   202  		if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
   203  			t.Fatalf("signcert: %v", err)
   204  		}
   205  		if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil {
   206  			t.Fatalf("%v", err)
   207  		}
   208  		if err := addCertToAgentSock(testPrivateKeys[keyType], cert); err != nil {
   209  			t.Fatalf("%v", err)
   210  		}
   211  	}
   212  }
   213  
   214  func TestParseConstraints(t *testing.T) {
   215  	// Test LifetimeSecs
   216  	var msg = constrainLifetimeAgentMsg{pseudorand.Uint32()}
   217  	lifetimeSecs, _, _, err := parseConstraints(ssh.Marshal(msg))
   218  	if err != nil {
   219  		t.Fatalf("parseConstraints: %v", err)
   220  	}
   221  	if lifetimeSecs != msg.LifetimeSecs {
   222  		t.Errorf("got lifetime %v, want %v", lifetimeSecs, msg.LifetimeSecs)
   223  	}
   224  
   225  	// Test ConfirmBeforeUse
   226  	_, confirmBeforeUse, _, err := parseConstraints([]byte{agentConstrainConfirm})
   227  	if err != nil {
   228  		t.Fatalf("%v", err)
   229  	}
   230  	if !confirmBeforeUse {
   231  		t.Error("got comfirmBeforeUse == false")
   232  	}
   233  
   234  	// Test ConstraintExtensions
   235  	var data []byte
   236  	var expect []ConstraintExtension
   237  	for i := 0; i < 10; i++ {
   238  		var ext = ConstraintExtension{
   239  			ExtensionName:    fmt.Sprintf("name%d", i),
   240  			ExtensionDetails: []byte(fmt.Sprintf("details: %d", i)),
   241  		}
   242  		expect = append(expect, ext)
   243  		data = append(data, agentConstrainExtension)
   244  		data = append(data, ssh.Marshal(ext)...)
   245  	}
   246  	_, _, extensions, err := parseConstraints(data)
   247  	if err != nil {
   248  		t.Fatalf("%v", err)
   249  	}
   250  	if !reflect.DeepEqual(expect, extensions) {
   251  		t.Errorf("got extension %v, want %v", extensions, expect)
   252  	}
   253  
   254  	// Test Unknown Constraint
   255  	_, _, _, err = parseConstraints([]byte{128})
   256  	if err == nil || !strings.Contains(err.Error(), "unknown constraint") {
   257  		t.Errorf("unexpected error: %v", err)
   258  	}
   259  }