github.com/mattyw/juju@v0.0.0-20140610034352-732aecd63861/worker/instancepoller/aggregate_test.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package instancepoller
     5  
     6  import (
     7  	"fmt"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/juju/errors"
    13  	jc "github.com/juju/testing/checkers"
    14  	gc "launchpad.net/gocheck"
    15  
    16  	"github.com/juju/juju/environs"
    17  	"github.com/juju/juju/instance"
    18  	"github.com/juju/juju/testing"
    19  )
    20  
    21  type aggregateSuite struct {
    22  	testing.BaseSuite
    23  }
    24  
    25  var _ = gc.Suite(&aggregateSuite{})
    26  
    27  type testInstance struct {
    28  	instance.Instance
    29  	addresses []instance.Address
    30  	status    string
    31  	err       error
    32  }
    33  
    34  var _ instance.Instance = (*testInstance)(nil)
    35  
    36  func (t *testInstance) Addresses() ([]instance.Address, error) {
    37  	if t.err != nil {
    38  		return nil, t.err
    39  	}
    40  	return t.addresses, nil
    41  }
    42  
    43  func (t *testInstance) Status() string {
    44  	return t.status
    45  }
    46  
    47  type testInstanceGetter struct {
    48  	// ids is set when the Instances method is called.
    49  	ids     []instance.Id
    50  	results []instance.Instance
    51  	err     error
    52  	counter int32
    53  }
    54  
    55  func (i *testInstanceGetter) Instances(ids []instance.Id) (result []instance.Instance, err error) {
    56  	i.ids = ids
    57  	atomic.AddInt32(&i.counter, 1)
    58  	return i.results, i.err
    59  }
    60  
    61  func newTestInstance(status string, addresses []string) *testInstance {
    62  	thisInstance := testInstance{status: status}
    63  	thisInstance.addresses = instance.NewAddresses(addresses...)
    64  	return &thisInstance
    65  }
    66  
    67  func (s *aggregateSuite) TestSingleRequest(c *gc.C) {
    68  	testGetter := new(testInstanceGetter)
    69  	instance1 := newTestInstance("foobar", []string{"127.0.0.1", "192.168.1.1"})
    70  	testGetter.results = []instance.Instance{instance1}
    71  	aggregator := newAggregator(testGetter)
    72  
    73  	info, err := aggregator.instanceInfo("foo")
    74  	c.Assert(err, gc.IsNil)
    75  	c.Assert(info, gc.DeepEquals, instanceInfo{
    76  		status:    "foobar",
    77  		addresses: instance1.addresses,
    78  	})
    79  	c.Assert(testGetter.ids, gc.DeepEquals, []instance.Id{"foo"})
    80  }
    81  
    82  func (s *aggregateSuite) TestMultipleResponseHandling(c *gc.C) {
    83  	s.PatchValue(&gatherTime, 30*time.Millisecond)
    84  	testGetter := new(testInstanceGetter)
    85  
    86  	instance1 := newTestInstance("foobar", []string{"127.0.0.1", "192.168.1.1"})
    87  	testGetter.results = []instance.Instance{instance1}
    88  	aggregator := newAggregator(testGetter)
    89  
    90  	replyChan := make(chan instanceInfoReply)
    91  	req := instanceInfoReq{
    92  		reply:  replyChan,
    93  		instId: instance.Id("foo"),
    94  	}
    95  	aggregator.reqc <- req
    96  	reply := <-replyChan
    97  	c.Assert(reply.err, gc.IsNil)
    98  
    99  	instance2 := newTestInstance("not foobar", []string{"192.168.1.2"})
   100  	instance3 := newTestInstance("ok-ish", []string{"192.168.1.3"})
   101  	testGetter.results = []instance.Instance{instance2, instance3}
   102  
   103  	var wg sync.WaitGroup
   104  	checkInfo := func(id instance.Id, expectStatus string) {
   105  		info, err := aggregator.instanceInfo(id)
   106  		c.Check(err, gc.IsNil)
   107  		c.Check(info.status, gc.Equals, expectStatus)
   108  		wg.Done()
   109  	}
   110  
   111  	wg.Add(2)
   112  	go checkInfo("foo2", "not foobar")
   113  	go checkInfo("foo3", "ok-ish")
   114  	wg.Wait()
   115  
   116  	c.Assert(len(testGetter.ids), gc.DeepEquals, 2)
   117  }
   118  
   119  type batchingInstanceGetter struct {
   120  	testInstanceGetter
   121  	wg         sync.WaitGroup
   122  	aggregator *aggregator
   123  	batchSize  int
   124  	started    int
   125  }
   126  
   127  func (g *batchingInstanceGetter) Instances(ids []instance.Id) ([]instance.Instance, error) {
   128  	insts, err := g.testInstanceGetter.Instances(ids)
   129  	g.startRequests()
   130  	return insts, err
   131  }
   132  
   133  func (g *batchingInstanceGetter) startRequests() {
   134  	n := len(g.results) - g.started
   135  	if n > g.batchSize {
   136  		n = g.batchSize
   137  	}
   138  	for i := 0; i < n; i++ {
   139  		g.startRequest()
   140  	}
   141  }
   142  
   143  func (g *batchingInstanceGetter) startRequest() {
   144  	g.started++
   145  	go func() {
   146  		_, err := g.aggregator.instanceInfo("foo")
   147  		if err != nil {
   148  			panic(err)
   149  		}
   150  		g.wg.Done()
   151  	}()
   152  }
   153  
   154  func (s *aggregateSuite) TestBatching(c *gc.C) {
   155  	s.PatchValue(&gatherTime, 10*time.Millisecond)
   156  	var testGetter batchingInstanceGetter
   157  	testGetter.aggregator = newAggregator(&testGetter)
   158  	testGetter.results = make([]instance.Instance, 100)
   159  	for i := range testGetter.results {
   160  		testGetter.results[i] = newTestInstance("foobar", []string{"127.0.0.1", "192.168.1.1"})
   161  	}
   162  	testGetter.batchSize = 10
   163  	testGetter.wg.Add(len(testGetter.results))
   164  	testGetter.startRequest()
   165  	testGetter.wg.Wait()
   166  	c.Assert(testGetter.counter, gc.Equals, int32(len(testGetter.results)/testGetter.batchSize)+1)
   167  }
   168  
   169  func (s *aggregateSuite) TestError(c *gc.C) {
   170  	testGetter := new(testInstanceGetter)
   171  	ourError := fmt.Errorf("Some error")
   172  	testGetter.err = ourError
   173  
   174  	aggregator := newAggregator(testGetter)
   175  
   176  	_, err := aggregator.instanceInfo("foo")
   177  	c.Assert(err, gc.Equals, ourError)
   178  }
   179  
   180  func (s *aggregateSuite) TestPartialErrResponse(c *gc.C) {
   181  	testGetter := new(testInstanceGetter)
   182  	testGetter.err = environs.ErrPartialInstances
   183  	testGetter.results = []instance.Instance{nil}
   184  
   185  	aggregator := newAggregator(testGetter)
   186  	_, err := aggregator.instanceInfo("foo")
   187  
   188  	c.Assert(err, gc.ErrorMatches, "instance foo not found")
   189  	c.Assert(err, jc.Satisfies, errors.IsNotFound)
   190  }
   191  
   192  func (s *aggregateSuite) TestAddressesError(c *gc.C) {
   193  	testGetter := new(testInstanceGetter)
   194  	instance1 := newTestInstance("foobar", []string{"127.0.0.1", "192.168.1.1"})
   195  	ourError := fmt.Errorf("gotcha")
   196  	instance1.err = ourError
   197  	testGetter.results = []instance.Instance{instance1}
   198  
   199  	aggregator := newAggregator(testGetter)
   200  	_, err := aggregator.instanceInfo("foo")
   201  	c.Assert(err, gc.Equals, ourError)
   202  }
   203  
   204  func (s *aggregateSuite) TestKillAndWait(c *gc.C) {
   205  	testGetter := new(testInstanceGetter)
   206  	aggregator := newAggregator(testGetter)
   207  	aggregator.Kill()
   208  	err := aggregator.Wait()
   209  	c.Assert(err, gc.IsNil)
   210  }