github.com/mattyw/juju@v0.0.0-20140610034352-732aecd63861/cmd/juju/ssh_test.go (about)

     1  // Copyright 2012, 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package main
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"net/url"
    10  	"os"
    11  	"path/filepath"
    12  
    13  	gc "launchpad.net/gocheck"
    14  
    15  	"github.com/juju/juju/charm"
    16  	charmtesting "github.com/juju/juju/charm/testing"
    17  	"github.com/juju/juju/cmd"
    18  	"github.com/juju/juju/cmd/envcmd"
    19  	"github.com/juju/juju/instance"
    20  	"github.com/juju/juju/juju/testing"
    21  	"github.com/juju/juju/state"
    22  	coretesting "github.com/juju/juju/testing"
    23  )
    24  
    25  var _ = gc.Suite(&SSHSuite{})
    26  
    27  type SSHSuite struct {
    28  	SSHCommonSuite
    29  }
    30  
    31  type SSHCommonSuite struct {
    32  	testing.JujuConnSuite
    33  	bin string
    34  }
    35  
    36  // fakecommand outputs its arguments to stdout for verification
    37  var fakecommand = `#!/bin/bash
    38  
    39  echo $@ | tee $0.args
    40  `
    41  
    42  func (s *SSHCommonSuite) SetUpTest(c *gc.C) {
    43  	s.JujuConnSuite.SetUpTest(c)
    44  	s.PatchValue(&getJujuExecutable, func() (string, error) { return "juju", nil })
    45  
    46  	s.bin = c.MkDir()
    47  	s.PatchEnvPathPrepend(s.bin)
    48  	for _, name := range []string{"ssh", "scp"} {
    49  		f, err := os.OpenFile(filepath.Join(s.bin, name), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0777)
    50  		c.Assert(err, gc.IsNil)
    51  		_, err = f.Write([]byte(fakecommand))
    52  		c.Assert(err, gc.IsNil)
    53  		err = f.Close()
    54  		c.Assert(err, gc.IsNil)
    55  	}
    56  }
    57  
    58  const (
    59  	commonArgsNoProxy = `-o StrictHostKeyChecking no -o PasswordAuthentication no `
    60  	commonArgs        = `-o StrictHostKeyChecking no -o ProxyCommand juju ssh --proxy=false --pty=false 127.0.0.1 nc -q0 %h %p -o PasswordAuthentication no `
    61  	sshArgs           = commonArgs + `-t -t `
    62  	sshArgsNoProxy    = commonArgsNoProxy + `-t -t `
    63  )
    64  
    65  var sshTests = []struct {
    66  	about  string
    67  	args   []string
    68  	result string
    69  }{
    70  	{
    71  		"connect to machine 0",
    72  		[]string{"ssh", "0"},
    73  		sshArgs + "ubuntu@dummyenv-0.internal\n",
    74  	},
    75  	{
    76  		"connect to machine 0 and pass extra arguments",
    77  		[]string{"ssh", "0", "uname", "-a"},
    78  		sshArgs + "ubuntu@dummyenv-0.internal uname -a\n",
    79  	},
    80  	{
    81  		"connect to unit mysql/0",
    82  		[]string{"ssh", "mysql/0"},
    83  		sshArgs + "ubuntu@dummyenv-0.internal\n",
    84  	},
    85  	{
    86  		"connect to unit mongodb/1 and pass extra arguments",
    87  		[]string{"ssh", "mongodb/1", "ls", "/"},
    88  		sshArgs + "ubuntu@dummyenv-2.internal ls /\n",
    89  	},
    90  	{
    91  		"connect to unit mysql/0 without proxy",
    92  		[]string{"ssh", "--proxy=false", "mysql/0"},
    93  		sshArgsNoProxy + "ubuntu@dummyenv-0.dns\n",
    94  	},
    95  }
    96  
    97  func (s *SSHSuite) TestSSHCommand(c *gc.C) {
    98  	m := s.makeMachines(3, c, true)
    99  	ch := charmtesting.Charms.Dir("dummy")
   100  	curl := charm.MustParseURL(
   101  		fmt.Sprintf("local:quantal/%s-%d", ch.Meta().Name, ch.Revision()),
   102  	)
   103  	bundleURL, err := url.Parse("http://bundles.testing.invalid/dummy-1")
   104  	c.Assert(err, gc.IsNil)
   105  	dummy, err := s.State.AddCharm(ch, curl, bundleURL, "dummy-1-sha256")
   106  	c.Assert(err, gc.IsNil)
   107  	srv := s.AddTestingService(c, "mysql", dummy)
   108  	s.addUnit(srv, m[0], c)
   109  
   110  	srv = s.AddTestingService(c, "mongodb", dummy)
   111  	s.addUnit(srv, m[1], c)
   112  	s.addUnit(srv, m[2], c)
   113  
   114  	for i, t := range sshTests {
   115  		c.Logf("test %d: %s -> %s\n", i, t.about, t.args)
   116  		ctx := coretesting.Context(c)
   117  		jujucmd := cmd.NewSuperCommand(cmd.SuperCommandParams{})
   118  		jujucmd.Register(envcmd.Wrap(&SSHCommand{}))
   119  
   120  		code := cmd.Main(jujucmd, ctx, t.args)
   121  		c.Check(code, gc.Equals, 0)
   122  		c.Check(ctx.Stderr.(*bytes.Buffer).String(), gc.Equals, "")
   123  		c.Check(ctx.Stdout.(*bytes.Buffer).String(), gc.Equals, t.result)
   124  	}
   125  }
   126  
   127  func (s *SSHSuite) TestSSHCommandEnvironProxySSH(c *gc.C) {
   128  	s.makeMachines(1, c, true)
   129  	// Setting proxy-ssh=false in the environment overrides --proxy.
   130  	err := s.State.UpdateEnvironConfig(map[string]interface{}{"proxy-ssh": false}, nil, nil)
   131  	c.Assert(err, gc.IsNil)
   132  	ctx := coretesting.Context(c)
   133  	jujucmd := cmd.NewSuperCommand(cmd.SuperCommandParams{})
   134  	jujucmd.Register(&SSHCommand{})
   135  	code := cmd.Main(jujucmd, ctx, []string{"ssh", "0"})
   136  	c.Check(code, gc.Equals, 0)
   137  	c.Check(ctx.Stderr.(*bytes.Buffer).String(), gc.Equals, "")
   138  	c.Check(ctx.Stdout.(*bytes.Buffer).String(), gc.Equals, sshArgsNoProxy+"ubuntu@dummyenv-0.dns\n")
   139  }
   140  
   141  type callbackAttemptStarter struct {
   142  	next func() bool
   143  }
   144  
   145  func (s *callbackAttemptStarter) Start() attempt {
   146  	return callbackAttempt{next: s.next}
   147  }
   148  
   149  type callbackAttempt struct {
   150  	next func() bool
   151  }
   152  
   153  func (a callbackAttempt) Next() bool {
   154  	return a.next()
   155  }
   156  
   157  func (s *SSHSuite) TestSSHCommandHostAddressRetry(c *gc.C) {
   158  	s.testSSHCommandHostAddressRetry(c, false)
   159  }
   160  
   161  func (s *SSHSuite) TestSSHCommandHostAddressRetryProxy(c *gc.C) {
   162  	s.testSSHCommandHostAddressRetry(c, true)
   163  }
   164  
   165  func (s *SSHSuite) testSSHCommandHostAddressRetry(c *gc.C, proxy bool) {
   166  	m := s.makeMachines(1, c, false)
   167  	ctx := coretesting.Context(c)
   168  
   169  	var called int
   170  	next := func() bool {
   171  		called++
   172  		return called < 2
   173  	}
   174  	attemptStarter := &callbackAttemptStarter{next: next}
   175  	s.PatchValue(&sshHostFromTargetAttemptStrategy, attemptStarter)
   176  
   177  	// Ensure that the ssh command waits for a public address, or the attempt
   178  	// strategy's Done method returns false.
   179  	args := []string{"--proxy=" + fmt.Sprint(proxy), "0"}
   180  	code := cmd.Main(&SSHCommand{}, ctx, args)
   181  	c.Check(code, gc.Equals, 1)
   182  	c.Assert(called, gc.Equals, 2)
   183  	called = 0
   184  	attemptStarter.next = func() bool {
   185  		called++
   186  		if called > 1 {
   187  			s.setAddresses(m[0], c)
   188  		}
   189  		return true
   190  	}
   191  	code = cmd.Main(&SSHCommand{}, ctx, args)
   192  	c.Check(code, gc.Equals, 0)
   193  	c.Assert(called, gc.Equals, 2)
   194  }
   195  
   196  func (s *SSHCommonSuite) setAddresses(m *state.Machine, c *gc.C) {
   197  	addrPub := instance.NewAddress(fmt.Sprintf("dummyenv-%s.dns", m.Id()), instance.NetworkPublic)
   198  	addrPriv := instance.NewAddress(fmt.Sprintf("dummyenv-%s.internal", m.Id()), instance.NetworkCloudLocal)
   199  	err := m.SetAddresses(addrPub, addrPriv)
   200  	c.Assert(err, gc.IsNil)
   201  }
   202  
   203  func (s *SSHCommonSuite) makeMachines(n int, c *gc.C, setAddresses bool) []*state.Machine {
   204  	var machines = make([]*state.Machine, n)
   205  	for i := 0; i < n; i++ {
   206  		m, err := s.State.AddMachine("quantal", state.JobHostUnits)
   207  		c.Assert(err, gc.IsNil)
   208  		if setAddresses {
   209  			s.setAddresses(m, c)
   210  		}
   211  		// must set an instance id as the ssh command uses that as a signal the
   212  		// machine has been provisioned
   213  		inst, md := testing.AssertStartInstance(c, s.Conn.Environ, m.Id())
   214  		c.Assert(m.SetProvisioned(inst.Id(), "fake_nonce", md), gc.IsNil)
   215  		machines[i] = m
   216  	}
   217  	return machines
   218  }
   219  
   220  func (s *SSHCommonSuite) addUnit(srv *state.Service, m *state.Machine, c *gc.C) {
   221  	u, err := srv.AddUnit()
   222  	c.Assert(err, gc.IsNil)
   223  	err = u.AssignToMachine(m)
   224  	c.Assert(err, gc.IsNil)
   225  }