github.com/mwhudson/juju@v0.0.0-20160512215208-90ff01f3497f/payload/state/base_test.go (about)

     1  // Copyright 2015 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package state_test
     5  
import (
	"fmt"
	"sort"

	"github.com/juju/errors"
	gitjujutesting "github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	"github.com/juju/utils"
	gc "gopkg.in/check.v1"
	"gopkg.in/juju/charm.v6-unstable"

	"github.com/juju/juju/payload"
	"github.com/juju/juju/testing"
)
    19  
    20  type basePayloadsSuite struct {
    21  	testing.BaseSuite
    22  
    23  	stub    *gitjujutesting.Stub
    24  	persist *fakePayloadsPersistence
    25  }
    26  
    27  func (s *basePayloadsSuite) SetUpTest(c *gc.C) {
    28  	s.BaseSuite.SetUpTest(c)
    29  
    30  	s.stub = &gitjujutesting.Stub{}
    31  	s.persist = &fakePayloadsPersistence{Stub: s.stub}
    32  }
    33  
    34  func (s *basePayloadsSuite) newPayload(pType string, id string) payload.Payload {
    35  	name, rawID := payload.ParseID(id)
    36  	if rawID == "" {
    37  		rawID = fmt.Sprintf("%s-%s", name, utils.MustNewUUID())
    38  	}
    39  
    40  	return payload.Payload{
    41  		PayloadClass: charm.PayloadClass{
    42  			Name: name,
    43  			Type: pType,
    44  		},
    45  		Status: payload.StateRunning,
    46  		ID:     rawID,
    47  		Unit:   "a-service/0",
    48  	}
    49  }
    50  
    51  type fakePayloadsPersistence struct {
    52  	*gitjujutesting.Stub
    53  	payloads map[string]*payload.Payload
    54  }
    55  
    56  func (s *fakePayloadsPersistence) checkPayload(c *gc.C, id string, expected payload.Payload) {
    57  	pl, ok := s.payloads[id]
    58  	if !ok {
    59  		c.Errorf("payload %q not found", id)
    60  	} else {
    61  		c.Check(pl, jc.DeepEquals, &expected)
    62  	}
    63  }
    64  
    65  func (s *fakePayloadsPersistence) setPayload(id string, pl *payload.Payload) {
    66  	if s.payloads == nil {
    67  		s.payloads = make(map[string]*payload.Payload)
    68  	}
    69  	s.payloads[id] = pl
    70  }
    71  
    72  func (s *fakePayloadsPersistence) Track(id string, pl payload.Payload) (bool, error) {
    73  	s.AddCall("Track", id, pl)
    74  	if err := s.NextErr(); err != nil {
    75  		return false, errors.Trace(err)
    76  	}
    77  
    78  	if _, ok := s.payloads[id]; ok {
    79  		return false, nil
    80  	}
    81  	s.setPayload(id, &pl)
    82  	return true, nil
    83  }
    84  
    85  func (s *fakePayloadsPersistence) SetStatus(id, status string) (bool, error) {
    86  	s.AddCall("SetStatus", id, status)
    87  	if err := s.NextErr(); err != nil {
    88  		return false, errors.Trace(err)
    89  	}
    90  
    91  	pl, ok := s.payloads[id]
    92  	if !ok {
    93  		return false, nil
    94  	}
    95  	pl.Status = status
    96  	return true, nil
    97  }
    98  
    99  func (s *fakePayloadsPersistence) List(ids ...string) ([]payload.Payload, []string, error) {
   100  	s.AddCall("List", ids)
   101  	if err := s.NextErr(); err != nil {
   102  		return nil, nil, errors.Trace(err)
   103  	}
   104  
   105  	var payloads []payload.Payload
   106  	var missing []string
   107  	for _, id := range ids {
   108  		if pl, ok := s.payloads[id]; !ok {
   109  			missing = append(missing, id)
   110  		} else {
   111  			payloads = append(payloads, *pl)
   112  		}
   113  	}
   114  	return payloads, missing, nil
   115  }
   116  
   117  func (s *fakePayloadsPersistence) ListAll() ([]payload.Payload, error) {
   118  	s.AddCall("ListAll")
   119  	if err := s.NextErr(); err != nil {
   120  		return nil, errors.Trace(err)
   121  	}
   122  
   123  	var payloads []payload.Payload
   124  	for _, pl := range s.payloads {
   125  		payloads = append(payloads, *pl)
   126  	}
   127  	return payloads, nil
   128  }
   129  
   130  func (s *fakePayloadsPersistence) LookUp(name, rawID string) (string, error) {
   131  	s.AddCall("LookUp", name, rawID)
   132  	if err := s.NextErr(); err != nil {
   133  		return "", errors.Trace(err)
   134  	}
   135  
   136  	for id, pl := range s.payloads {
   137  		if pl.Name == name && pl.ID == rawID {
   138  			return id, nil
   139  		}
   140  	}
   141  	return "", errors.NotFoundf("doc ID")
   142  }
   143  
   144  func (s *fakePayloadsPersistence) Untrack(id string) (bool, error) {
   145  	s.AddCall("Untrack", id)
   146  	if err := s.NextErr(); err != nil {
   147  		return false, errors.Trace(err)
   148  	}
   149  
   150  	if _, ok := s.payloads[id]; !ok {
   151  		return false, nil
   152  	}
   153  	delete(s.payloads, id)
   154  	return true, nil
   155  }