github.com/makyo/juju@v0.0.0-20160425123129-2608902037e9/worker/instancepoller/aggregate_test.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package instancepoller
     5  
     6  import (
     7  	"fmt"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/juju/errors"
    13  	jc "github.com/juju/testing/checkers"
    14  	gc "gopkg.in/check.v1"
    15  
    16  	"github.com/juju/juju/environs"
    17  	"github.com/juju/juju/instance"
    18  	"github.com/juju/juju/network"
    19  	"github.com/juju/juju/status"
    20  	"github.com/juju/juju/testing"
    21  )
    22  
    23  type aggregateSuite struct {
    24  	testing.BaseSuite
    25  }
    26  
    27  var _ = gc.Suite(&aggregateSuite{})
    28  
    29  type testInstance struct {
    30  	instance.Instance
    31  	id        instance.Id
    32  	addresses []network.Address
    33  	status    string
    34  	err       error
    35  }
    36  
    37  var _ instance.Instance = (*testInstance)(nil)
    38  
    39  func (t *testInstance) Id() instance.Id {
    40  	return t.id
    41  }
    42  
    43  func (t *testInstance) Addresses() ([]network.Address, error) {
    44  	if t.err != nil {
    45  		return nil, t.err
    46  	}
    47  	return t.addresses, nil
    48  }
    49  
    50  func (t *testInstance) Status() instance.InstanceStatus {
    51  	return instance.InstanceStatus{Status: status.StatusUnknown, Message: t.status}
    52  }
    53  
    54  type testInstanceGetter struct {
    55  	// ids is set when the Instances method is called.
    56  	ids     []instance.Id
    57  	results map[instance.Id]instance.Instance
    58  	err     error
    59  	counter int32
    60  }
    61  
    62  func (tig *testInstanceGetter) Instances(ids []instance.Id) (result []instance.Instance, err error) {
    63  	tig.ids = ids
    64  	atomic.AddInt32(&tig.counter, 1)
    65  	results := make([]instance.Instance, len(ids))
    66  	for i, id := range ids {
    67  		// We don't check 'ok' here, because we want the Instance{nil}
    68  		// response for those
    69  		results[i] = tig.results[id]
    70  	}
    71  	return results, tig.err
    72  }
    73  
    74  func (tig *testInstanceGetter) newTestInstance(id instance.Id, status string, addresses []string) *testInstance {
    75  	if tig.results == nil {
    76  		tig.results = make(map[instance.Id]instance.Instance)
    77  	}
    78  	thisInstance := &testInstance{
    79  		id:        id,
    80  		status:    status,
    81  		addresses: network.NewAddresses(addresses...),
    82  	}
    83  	tig.results[thisInstance.Id()] = thisInstance
    84  	return thisInstance
    85  }
    86  
    87  func (s *aggregateSuite) TestSingleRequest(c *gc.C) {
    88  	testGetter := new(testInstanceGetter)
    89  	instance1 := testGetter.newTestInstance("foo", "foobar", []string{"127.0.0.1", "192.168.1.1"})
    90  	aggregator := newAggregator(testGetter)
    91  
    92  	info, err := aggregator.instanceInfo("foo")
    93  	c.Assert(err, jc.ErrorIsNil)
    94  	c.Assert(info, gc.DeepEquals, instanceInfo{
    95  		status:    instance.InstanceStatus{Status: status.StatusUnknown, Message: "foobar"},
    96  		addresses: instance1.addresses,
    97  	})
    98  	c.Assert(testGetter.ids, gc.DeepEquals, []instance.Id{"foo"})
    99  }
   100  
   101  func (s *aggregateSuite) TestMultipleResponseHandling(c *gc.C) {
   102  	s.PatchValue(&gatherTime, 30*time.Millisecond)
   103  	testGetter := new(testInstanceGetter)
   104  
   105  	testGetter.newTestInstance("foo", "foobar", []string{"127.0.0.1", "192.168.1.1"})
   106  	aggregator := newAggregator(testGetter)
   107  
   108  	replyChan := make(chan instanceInfoReply)
   109  	req := instanceInfoReq{
   110  		reply:  replyChan,
   111  		instId: instance.Id("foo"),
   112  	}
   113  	aggregator.reqc <- req
   114  	reply := <-replyChan
   115  	c.Assert(reply.err, gc.IsNil)
   116  
   117  	testGetter.newTestInstance("foo2", "not foobar", []string{"192.168.1.2"})
   118  	testGetter.newTestInstance("foo3", "ok-ish", []string{"192.168.1.3"})
   119  
   120  	var wg sync.WaitGroup
   121  	checkInfo := func(id instance.Id, expectStatus string) {
   122  		info, err := aggregator.instanceInfo(id)
   123  		c.Check(err, jc.ErrorIsNil)
   124  		c.Check(info.status.Message, gc.Equals, expectStatus)
   125  		wg.Done()
   126  	}
   127  
   128  	wg.Add(2)
   129  	go checkInfo("foo2", "not foobar")
   130  	go checkInfo("foo3", "ok-ish")
   131  	wg.Wait()
   132  
   133  	c.Assert(len(testGetter.ids), gc.DeepEquals, 2)
   134  }
   135  
   136  // notifyingInstanceGetter wraps testInstanceGetter, notifying via
   137  // a channel when Instances() is called.
   138  type notifyingInstanceGetter struct {
   139  	testInstanceGetter
   140  	instancesc chan bool
   141  }
   142  
   143  func (g *notifyingInstanceGetter) Instances(ids []instance.Id) ([]instance.Instance, error) {
   144  	g.instancesc <- true
   145  	return g.testInstanceGetter.Instances(ids)
   146  }
   147  
   148  func (s *aggregateSuite) TestDyingWhileHandlingRequest(c *gc.C) {
   149  	// This tests a regression where the aggregator couldn't shut down
   150  	// if the the tomb was killed while a request was being handled,
   151  	// leaving the reply channel unread.
   152  
   153  	s.PatchValue(&gatherTime, 30*time.Millisecond)
   154  
   155  	// Set up the aggregator with the instance getter.
   156  	testGetter := &notifyingInstanceGetter{instancesc: make(chan bool)}
   157  	testGetter.newTestInstance("foo", "foobar", []string{"127.0.0.1", "192.168.1.1"})
   158  	aggregator := newAggregator(testGetter)
   159  
   160  	// Make a request with a reply channel that will never be read.
   161  	req := instanceInfoReq{
   162  		reply:  make(chan instanceInfoReply),
   163  		instId: instance.Id("foo"),
   164  	}
   165  	aggregator.reqc <- req
   166  
   167  	// Wait for Instances to be called.
   168  	select {
   169  	case <-testGetter.instancesc:
   170  	case <-time.After(testing.LongWait):
   171  		c.Fatal("Instances() not called")
   172  	}
   173  
   174  	// Now we know the request is being handled - kill the aggregator.
   175  	aggregator.Kill()
   176  	done := make(chan error)
   177  	go func() {
   178  		done <- aggregator.Wait()
   179  	}()
   180  
   181  	// The aggregator should stop.
   182  	select {
   183  	case err := <-done:
   184  		c.Assert(err, jc.ErrorIsNil)
   185  	case <-time.After(testing.LongWait):
   186  		c.Fatal("aggregator didn't stop")
   187  	}
   188  }
   189  
   190  type batchingInstanceGetter struct {
   191  	testInstanceGetter
   192  	wg         sync.WaitGroup
   193  	aggregator *aggregator
   194  	totalCount int
   195  	batchSize  int
   196  	started    int
   197  }
   198  
   199  func (g *batchingInstanceGetter) Instances(ids []instance.Id) ([]instance.Instance, error) {
   200  	insts, err := g.testInstanceGetter.Instances(ids)
   201  	g.startRequests()
   202  	return insts, err
   203  }
   204  
   205  func (g *batchingInstanceGetter) startRequests() {
   206  	n := g.totalCount - g.started
   207  	if n > g.batchSize {
   208  		n = g.batchSize
   209  	}
   210  	for i := 0; i < n; i++ {
   211  		g.startRequest()
   212  	}
   213  }
   214  
   215  func (g *batchingInstanceGetter) startRequest() {
   216  	g.started++
   217  	go func() {
   218  		_, err := g.aggregator.instanceInfo("foo")
   219  		if err != nil {
   220  			panic(err)
   221  		}
   222  		g.wg.Done()
   223  	}()
   224  }
   225  
   226  func (s *aggregateSuite) TestBatching(c *gc.C) {
   227  	s.PatchValue(&gatherTime, 10*time.Millisecond)
   228  	var testGetter batchingInstanceGetter
   229  	testGetter.aggregator = newAggregator(&testGetter)
   230  	// We only need to inform the system about 1 instance, because all the
   231  	// requests are for the same instance.
   232  	testGetter.newTestInstance("foo", "foobar", []string{"127.0.0.1", "192.168.1.1"})
   233  	testGetter.totalCount = 100
   234  	testGetter.batchSize = 10
   235  	testGetter.wg.Add(testGetter.totalCount)
   236  	// startRequest will trigger one request, which ends up calling
   237  	// Instances, which will turn around and trigger batchSize requests,
   238  	// which should get aggregated into a single call to Instances, which
   239  	// then should trigger another round of batchSize requests.
   240  	testGetter.startRequest()
   241  	testGetter.wg.Wait()
   242  	c.Assert(testGetter.counter, gc.Equals, int32(testGetter.totalCount/testGetter.batchSize)+1)
   243  }
   244  
   245  func (s *aggregateSuite) TestError(c *gc.C) {
   246  	testGetter := new(testInstanceGetter)
   247  	ourError := fmt.Errorf("Some error")
   248  	testGetter.err = ourError
   249  
   250  	aggregator := newAggregator(testGetter)
   251  
   252  	_, err := aggregator.instanceInfo("foo")
   253  	c.Assert(err, gc.Equals, ourError)
   254  }
   255  
   256  func (s *aggregateSuite) TestPartialErrResponse(c *gc.C) {
   257  	testGetter := new(testInstanceGetter)
   258  	testGetter.err = environs.ErrPartialInstances
   259  
   260  	aggregator := newAggregator(testGetter)
   261  	_, err := aggregator.instanceInfo("foo")
   262  
   263  	c.Assert(err, gc.ErrorMatches, "instance foo not found")
   264  	c.Assert(err, jc.Satisfies, errors.IsNotFound)
   265  }
   266  
   267  func (s *aggregateSuite) TestAddressesError(c *gc.C) {
   268  	testGetter := new(testInstanceGetter)
   269  	instance1 := testGetter.newTestInstance("foo", "foobar", []string{"127.0.0.1", "192.168.1.1"})
   270  	ourError := fmt.Errorf("gotcha")
   271  	instance1.err = ourError
   272  
   273  	aggregator := newAggregator(testGetter)
   274  	_, err := aggregator.instanceInfo("foo")
   275  	c.Assert(err, gc.Equals, ourError)
   276  }
   277  
   278  func (s *aggregateSuite) TestKillAndWait(c *gc.C) {
   279  	testGetter := new(testInstanceGetter)
   280  	aggregator := newAggregator(testGetter)
   281  	aggregator.Kill()
   282  	err := aggregator.Wait()
   283  	c.Assert(err, jc.ErrorIsNil)
   284  }