github.com/makyo/juju@v0.0.0-20160425123129-2608902037e9/payload/persistence/env_test.go (about)

     1  // Copyright 2015 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package persistence
     5  
     6  import (
     7  	"reflect"
     8  
     9  	"github.com/juju/errors"
    10  	"github.com/juju/testing"
    11  	jc "github.com/juju/testing/checkers"
    12  	gc "gopkg.in/check.v1"
    13  	"gopkg.in/juju/charm.v6-unstable"
    14  
    15  	"github.com/juju/juju/payload"
    16  )
    17  
    18  var _ = gc.Suite(&envPersistenceSuite{})
    19  
    20  type envPersistenceSuite struct {
    21  	BaseSuite
    22  
    23  	base *stubEnvPersistenceBase
    24  }
    25  
    26  func (s *envPersistenceSuite) SetUpTest(c *gc.C) {
    27  	s.BaseSuite.SetUpTest(c)
    28  
    29  	s.base = &stubEnvPersistenceBase{
    30  		PersistenceBase: s.State,
    31  		stub:            s.Stub,
    32  	}
    33  }
    34  
    35  func (s *envPersistenceSuite) newPayload(name string) payload.FullPayloadInfo {
    36  	return payload.FullPayloadInfo{
    37  		Payload: payload.Payload{
    38  			PayloadClass: charm.PayloadClass{
    39  				Name: name,
    40  				Type: "docker",
    41  			},
    42  			ID:     "id" + name,
    43  			Status: payload.StateRunning,
    44  			Labels: []string{"a-tag"},
    45  			Unit:   "a-service/0",
    46  		},
    47  		Machine: "1",
    48  	}
    49  }
    50  
    51  func (s *envPersistenceSuite) TestListAllOkay(c *gc.C) {
    52  	s.base.setUnits("0")
    53  	s.base.setUnits("1", "a-service/0", "a-service/1")
    54  	s.base.setUnits("2", "a-service/2")
    55  	p1 := s.newPayload("spam")
    56  	p2 := s.newPayload("eggs")
    57  	s.base.setPayloads(p1, p2)
    58  
    59  	persist := NewEnvPersistence(s.base)
    60  	persist.newUnitPersist = s.base.newUnitPersistence
    61  
    62  	payloads, err := persist.ListAll()
    63  	c.Assert(err, jc.ErrorIsNil)
    64  
    65  	checkPayloads(c, payloads, p1, p2)
    66  	s.Stub.CheckCallNames(c,
    67  		"Machines",
    68  
    69  		"MachineUnits",
    70  
    71  		"MachineUnits",
    72  		"newUnitPersistence",
    73  		"ListAll",
    74  		"newUnitPersistence",
    75  		"ListAll",
    76  
    77  		"MachineUnits",
    78  		"newUnitPersistence",
    79  		"ListAll",
    80  	)
    81  }
    82  
    83  func (s *envPersistenceSuite) TestListAllEmpty(c *gc.C) {
    84  	s.base.setUnits("0")
    85  	s.base.setUnits("1", "a-service/0", "a-service/1")
    86  	persist := NewEnvPersistence(s.base)
    87  	persist.newUnitPersist = s.base.newUnitPersistence
    88  
    89  	payloads, err := persist.ListAll()
    90  	c.Assert(err, jc.ErrorIsNil)
    91  
    92  	c.Check(payloads, gc.HasLen, 0)
    93  	s.Stub.CheckCallNames(c,
    94  		"Machines",
    95  
    96  		"MachineUnits",
    97  
    98  		"MachineUnits",
    99  		"newUnitPersistence",
   100  		"ListAll",
   101  		"newUnitPersistence",
   102  		"ListAll",
   103  	)
   104  }
   105  
   106  func (s *envPersistenceSuite) TestListAllFailed(c *gc.C) {
   107  	failure := errors.Errorf("<failed!>")
   108  	s.Stub.SetErrors(failure)
   109  
   110  	persist := NewEnvPersistence(s.base)
   111  	persist.newUnitPersist = s.base.newUnitPersistence
   112  
   113  	_, err := persist.ListAll()
   114  
   115  	c.Check(errors.Cause(err), gc.Equals, failure)
   116  }
   117  
   118  // TODO(ericsnow) Factor this out to a testing package.
   119  
   120  func checkPayloads(c *gc.C, payloads []payload.FullPayloadInfo, expectedList ...payload.FullPayloadInfo) {
   121  	remainder := make([]payload.FullPayloadInfo, len(payloads))
   122  	copy(remainder, payloads)
   123  	var noMatch []payload.FullPayloadInfo
   124  	for _, expected := range expectedList {
   125  		found := false
   126  		for i, payload := range remainder {
   127  			if reflect.DeepEqual(payload, expected) {
   128  				remainder = append(remainder[:i], remainder[i+1:]...)
   129  				found = true
   130  				break
   131  			}
   132  		}
   133  		if !found {
   134  			noMatch = append(noMatch, expected)
   135  		}
   136  	}
   137  
   138  	ok1 := c.Check(noMatch, gc.HasLen, 0)
   139  	ok2 := c.Check(remainder, gc.HasLen, 0)
   140  	if !ok1 || !ok2 {
   141  		c.Logf("<<<<<<<<\nexpected:")
   142  		for _, payload := range expectedList {
   143  			c.Logf("%#v", payload)
   144  		}
   145  		c.Logf("--------\ngot:")
   146  		for _, payload := range payloads {
   147  			c.Logf("%#v", payload)
   148  		}
   149  		c.Logf(">>>>>>>>")
   150  	}
   151  }
   152  
   153  type stubEnvPersistenceBase struct {
   154  	PersistenceBase
   155  	stub         *testing.Stub
   156  	machines     []string
   157  	units        map[string]map[string]bool
   158  	unitPersists map[string]*stubUnitPersistence
   159  }
   160  
   161  func (s *stubEnvPersistenceBase) setPayloads(payloads ...payload.FullPayloadInfo) {
   162  	if s.unitPersists == nil && len(payloads) > 0 {
   163  		s.unitPersists = make(map[string]*stubUnitPersistence)
   164  	}
   165  
   166  	for _, pl := range payloads {
   167  		s.setUnits(pl.Machine, pl.Unit)
   168  
   169  		unitPayloads := s.unitPersists[pl.Unit]
   170  		if unitPayloads == nil {
   171  			unitPayloads = &stubUnitPersistence{stub: s.stub}
   172  			s.unitPersists[pl.Unit] = unitPayloads
   173  		}
   174  
   175  		unitPayloads.setPayloads(pl.Payload)
   176  	}
   177  }
   178  
   179  func (s *stubEnvPersistenceBase) setUnits(machine string, units ...string) {
   180  	if s.units == nil {
   181  		s.units = make(map[string]map[string]bool)
   182  	}
   183  	if _, ok := s.units[machine]; !ok {
   184  		s.machines = append(s.machines, machine)
   185  		s.units[machine] = make(map[string]bool)
   186  	}
   187  
   188  	for _, unit := range units {
   189  		s.units[machine][unit] = true
   190  	}
   191  }
   192  
   193  func (s *stubEnvPersistenceBase) newUnitPersistence(base PersistenceBase, unit string) unitPersistence {
   194  	s.stub.AddCall("newUnitPersistence", base, unit)
   195  	s.stub.NextErr() // pop one off
   196  
   197  	persist, ok := s.unitPersists[unit]
   198  	if !ok {
   199  		if s.unitPersists == nil {
   200  			s.unitPersists = make(map[string]*stubUnitPersistence)
   201  		}
   202  		persist = &stubUnitPersistence{stub: s.stub}
   203  		s.unitPersists[unit] = persist
   204  	}
   205  	return persist
   206  }
   207  
   208  func (s *stubEnvPersistenceBase) Machines() ([]string, error) {
   209  	s.stub.AddCall("Machines")
   210  	if err := s.stub.NextErr(); err != nil {
   211  		return nil, errors.Trace(err)
   212  	}
   213  
   214  	var names []string
   215  	for _, name := range s.machines {
   216  		names = append(names, name)
   217  	}
   218  	return names, nil
   219  }
   220  
   221  func (s *stubEnvPersistenceBase) MachineUnits(machine string) ([]string, error) {
   222  	s.stub.AddCall("MachineUnits", machine)
   223  	if err := s.stub.NextErr(); err != nil {
   224  		return nil, errors.Trace(err)
   225  	}
   226  
   227  	var units []string
   228  	for unit := range s.units[machine] {
   229  		units = append(units, unit)
   230  	}
   231  	return units, nil
   232  }
   233  
   234  type stubUnitPersistence struct {
   235  	stub *testing.Stub
   236  
   237  	payloads []payload.Payload
   238  }
   239  
   240  func (s *stubUnitPersistence) setPayloads(payloads ...payload.Payload) {
   241  	s.payloads = append(s.payloads, payloads...)
   242  }
   243  
   244  func (s *stubUnitPersistence) ListAll() ([]payload.Payload, error) {
   245  	s.stub.AddCall("ListAll")
   246  	if err := s.stub.NextErr(); err != nil {
   247  		return nil, errors.Trace(err)
   248  	}
   249  
   250  	return s.payloads, nil
   251  }