github.com/glycerine/xcryptossh@v7.0.4+incompatible/agent/server_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package agent
     6  
     7  import (
     8  	"context"
     9  	"crypto"
    10  	"crypto/rand"
    11  	"fmt"
    12  	pseudorand "math/rand"
    13  	"reflect"
    14  	"strings"
    15  	"testing"
    16  
    17  	"github.com/glycerine/xcryptossh"
    18  )
    19  
    20  func TestServer(t *testing.T) {
    21  	c1, c2, err := netPipe()
    22  	if err != nil {
    23  		t.Fatalf("netPipe: %v", err)
    24  	}
    25  	defer c1.Close()
    26  	defer c2.Close()
    27  	client := NewClient(c1)
    28  
    29  	go ServeAgent(NewKeyring(), c2)
    30  
    31  	testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0)
    32  }
    33  
    34  func TestLockServer(t *testing.T) {
    35  	testLockAgent(NewKeyring(), t)
    36  }
    37  
    38  func TestSetupForwardAgent(t *testing.T) {
    39  	ctx, cancelctx := context.WithCancel(context.Background())
    40  	defer cancelctx()
    41  
    42  	halt := ssh.NewHalter()
    43  	defer halt.ReqStop.Close()
    44  
    45  	a, b, err := netPipe()
    46  	if err != nil {
    47  		t.Fatalf("netPipe: %v", err)
    48  	}
    49  
    50  	defer a.Close()
    51  	defer b.Close()
    52  
    53  	_, socket, cleanup := startOpenSSHAgent(t)
    54  	defer cleanup()
    55  
    56  	serverConf := ssh.ServerConfig{
    57  		NoClientAuth: true,
    58  		Config:       ssh.Config{Halt: halt},
    59  	}
    60  	serverConf.AddHostKey(testSigners["rsa"])
    61  	incoming := make(chan *ssh.ServerConn, 1)
    62  	go func() {
    63  		conn, _, _, err := ssh.NewServerConn(ctx, a, &serverConf)
    64  		if err != nil {
    65  			t.Fatalf("Server: %v", err)
    66  		}
    67  		incoming <- conn
    68  	}()
    69  
    70  	conf := ssh.ClientConfig{
    71  		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
    72  		Config:          ssh.Config{Halt: halt},
    73  	}
    74  	conn, chans, reqs, err := ssh.NewClientConn(ctx, b, "", &conf)
    75  	if err != nil {
    76  		t.Fatalf("NewClientConn: %v", err)
    77  	}
    78  	client := ssh.NewClient(ctx, conn, chans, reqs, halt)
    79  
    80  	if err := ForwardToRemote(ctx, client, socket); err != nil {
    81  		t.Fatalf("SetupForwardAgent: %v", err)
    82  	}
    83  
    84  	server := <-incoming
    85  	ch, reqs, err := server.OpenChannel(ctx, channelType, nil)
    86  	if err != nil {
    87  		t.Fatalf("OpenChannel(%q): %v", channelType, err)
    88  	}
    89  	go ssh.DiscardRequests(ctx, reqs, halt)
    90  
    91  	agentClient := NewClient(ch)
    92  	testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0)
    93  	conn.Close()
    94  }
    95  
    96  func TestV1ProtocolMessages(t *testing.T) {
    97  	c1, c2, err := netPipe()
    98  	if err != nil {
    99  		t.Fatalf("netPipe: %v", err)
   100  	}
   101  	defer c1.Close()
   102  	defer c2.Close()
   103  	c := NewClient(c1)
   104  
   105  	go ServeAgent(NewKeyring(), c2)
   106  
   107  	testV1ProtocolMessages(t, c.(*client))
   108  }
   109  
   110  func testV1ProtocolMessages(t *testing.T, c *client) {
   111  	reply, err := c.call([]byte{agentRequestV1Identities})
   112  	if err != nil {
   113  		t.Fatalf("v1 request all failed: %v", err)
   114  	}
   115  	if msg, ok := reply.(*agentV1IdentityMsg); !ok || msg.Numkeys != 0 {
   116  		t.Fatalf("invalid request all response: %#v", reply)
   117  	}
   118  
   119  	reply, err = c.call([]byte{agentRemoveAllV1Identities})
   120  	if err != nil {
   121  		t.Fatalf("v1 remove all failed: %v", err)
   122  	}
   123  	if _, ok := reply.(*successAgentMsg); !ok {
   124  		t.Fatalf("invalid remove all response: %#v", reply)
   125  	}
   126  }
   127  
   128  func verifyKey(sshAgent Agent) error {
   129  	keys, err := sshAgent.List()
   130  	if err != nil {
   131  		return fmt.Errorf("listing keys: %v", err)
   132  	}
   133  
   134  	if len(keys) != 1 {
   135  		return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys))
   136  	}
   137  
   138  	buf := make([]byte, 128)
   139  	if _, err := rand.Read(buf); err != nil {
   140  		return fmt.Errorf("rand: %v", err)
   141  	}
   142  
   143  	sig, err := sshAgent.Sign(keys[0], buf)
   144  	if err != nil {
   145  		return fmt.Errorf("sign: %v", err)
   146  	}
   147  
   148  	if err := keys[0].Verify(buf, sig); err != nil {
   149  		return fmt.Errorf("verify: %v", err)
   150  	}
   151  	return nil
   152  }
   153  
   154  func addKeyToAgent(key crypto.PrivateKey) error {
   155  	sshAgent := NewKeyring()
   156  	if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil {
   157  		return fmt.Errorf("add: %v", err)
   158  	}
   159  	return verifyKey(sshAgent)
   160  }
   161  
   162  func TestKeyTypes(t *testing.T) {
   163  	for k, v := range testPrivateKeys {
   164  		if err := addKeyToAgent(v); err != nil {
   165  			t.Errorf("error adding key type %s, %v", k, err)
   166  		}
   167  		if err := addCertToAgentSock(v, nil); err != nil {
   168  			t.Errorf("error adding key type %s, %v", k, err)
   169  		}
   170  	}
   171  }
   172  
   173  func addCertToAgentSock(key crypto.PrivateKey, cert *ssh.Certificate) error {
   174  	a, b, err := netPipe()
   175  	if err != nil {
   176  		return err
   177  	}
   178  	agentServer := NewKeyring()
   179  	go ServeAgent(agentServer, a)
   180  
   181  	agentClient := NewClient(b)
   182  	if err := agentClient.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
   183  		return fmt.Errorf("add: %v", err)
   184  	}
   185  	return verifyKey(agentClient)
   186  }
   187  
   188  func addCertToAgent(key crypto.PrivateKey, cert *ssh.Certificate) error {
   189  	sshAgent := NewKeyring()
   190  	if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
   191  		return fmt.Errorf("add: %v", err)
   192  	}
   193  	return verifyKey(sshAgent)
   194  }
   195  
   196  func TestCertTypes(t *testing.T) {
   197  	for keyType, key := range testPublicKeys {
   198  		cert := &ssh.Certificate{
   199  			ValidPrincipals: []string{"gopher1"},
   200  			ValidAfter:      0,
   201  			ValidBefore:     ssh.CertTimeInfinity,
   202  			Key:             key,
   203  			Serial:          1,
   204  			CertType:        ssh.UserCert,
   205  			SignatureKey:    testPublicKeys["rsa"],
   206  			Permissions: ssh.Permissions{
   207  				CriticalOptions: map[string]string{},
   208  				Extensions:      map[string]string{},
   209  			},
   210  		}
   211  		if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
   212  			t.Fatalf("signcert: %v", err)
   213  		}
   214  		if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil {
   215  			t.Fatalf("%v", err)
   216  		}
   217  		if err := addCertToAgentSock(testPrivateKeys[keyType], cert); err != nil {
   218  			t.Fatalf("%v", err)
   219  		}
   220  	}
   221  }
   222  
   223  func TestParseConstraints(t *testing.T) {
   224  	// Test LifetimeSecs
   225  	var msg = constrainLifetimeAgentMsg{pseudorand.Uint32()}
   226  	lifetimeSecs, _, _, err := parseConstraints(ssh.Marshal(msg))
   227  	if err != nil {
   228  		t.Fatalf("parseConstraints: %v", err)
   229  	}
   230  	if lifetimeSecs != msg.LifetimeSecs {
   231  		t.Errorf("got lifetime %v, want %v", lifetimeSecs, msg.LifetimeSecs)
   232  	}
   233  
   234  	// Test ConfirmBeforeUse
   235  	_, confirmBeforeUse, _, err := parseConstraints([]byte{agentConstrainConfirm})
   236  	if err != nil {
   237  		t.Fatalf("%v", err)
   238  	}
   239  	if !confirmBeforeUse {
   240  		t.Error("got comfirmBeforeUse == false")
   241  	}
   242  
   243  	// Test ConstraintExtensions
   244  	var data []byte
   245  	var expect []ConstraintExtension
   246  	for i := 0; i < 10; i++ {
   247  		var ext = ConstraintExtension{
   248  			ExtensionName:    fmt.Sprintf("name%d", i),
   249  			ExtensionDetails: []byte(fmt.Sprintf("details: %d", i)),
   250  		}
   251  		expect = append(expect, ext)
   252  		data = append(data, agentConstrainExtension)
   253  		data = append(data, ssh.Marshal(ext)...)
   254  	}
   255  	_, _, extensions, err := parseConstraints(data)
   256  	if err != nil {
   257  		t.Fatalf("%v", err)
   258  	}
   259  	if !reflect.DeepEqual(expect, extensions) {
   260  		t.Errorf("got extension %v, want %v", extensions, expect)
   261  	}
   262  
   263  	// Test Unknown Constraint
   264  	_, _, _, err = parseConstraints([]byte{128})
   265  	if err == nil || !strings.Contains(err.Error(), "unknown constraint") {
   266  		t.Errorf("unexpected error: %v", err)
   267  	}
   268  }