github.com/wallyworld/juju@v0.0.0-20161013125918-6cf1bc9d917a/payload/context/base_test.go (about)

     1  // Copyright 2015 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package context_test
     5  
     6  import (
     7  	"reflect"
     8  
     9  	"github.com/juju/errors"
    10  	"github.com/juju/testing"
    11  	gc "gopkg.in/check.v1"
    12  	"gopkg.in/juju/charm.v6-unstable"
    13  
    14  	"github.com/juju/juju/payload"
    15  	"github.com/juju/juju/payload/context"
    16  	jujuctesting "github.com/juju/juju/worker/uniter/runner/jujuc/testing"
    17  )
    18  
    19  type baseSuite struct {
    20  	jujuctesting.ContextSuite
    21  	payload payload.Payload
    22  }
    23  
    24  func (s *baseSuite) SetUpTest(c *gc.C) {
    25  	s.ContextSuite.SetUpTest(c)
    26  
    27  	s.payload = s.newPayload("payload A", "docker", "", "")
    28  }
    29  
    30  func (s *baseSuite) newPayload(name, ptype, id, status string) payload.Payload {
    31  	pl := payload.Payload{
    32  		PayloadClass: charm.PayloadClass{
    33  			Name: name,
    34  			Type: ptype,
    35  		},
    36  		ID:     id,
    37  		Status: status,
    38  		Unit:   "a-application/0",
    39  	}
    40  	return pl
    41  }
    42  
    43  func (s *baseSuite) NewHookContext() (*stubHookContext, *jujuctesting.ContextInfo) {
    44  	ctx, info := s.ContextSuite.NewHookContext()
    45  	return &stubHookContext{ctx}, info
    46  }
    47  
    48  func checkPayloads(c *gc.C, payloads, expected []payload.Payload) {
    49  	if !c.Check(payloads, gc.HasLen, len(expected)) {
    50  		return
    51  	}
    52  	for _, wl := range payloads {
    53  		matched := false
    54  		for _, expPayload := range expected {
    55  			if reflect.DeepEqual(wl, expPayload) {
    56  				matched = true
    57  				break
    58  			}
    59  		}
    60  		if !matched {
    61  			c.Errorf("%#v != %#v", payloads, expected)
    62  			return
    63  		}
    64  	}
    65  }
    66  
    67  type stubHookContext struct {
    68  	*jujuctesting.Context
    69  }
    70  
    71  func (c stubHookContext) Component(name string) (context.Component, error) {
    72  	found, err := c.Context.Component(name)
    73  	if err != nil {
    74  		return nil, errors.Trace(err)
    75  	}
    76  	compCtx, ok := found.(context.Component)
    77  	if !ok && found != nil {
    78  		return nil, errors.Errorf("wrong component context type registered: %T", found)
    79  	}
    80  	return compCtx, nil
    81  }
    82  
    83  var _ context.Component = (*stubContextComponent)(nil)
    84  
    85  type stubContextComponent struct {
    86  	stub     *testing.Stub
    87  	payloads map[string]payload.Payload
    88  	untracks map[string]struct{}
    89  }
    90  
    91  func newStubContextComponent(stub *testing.Stub) *stubContextComponent {
    92  	return &stubContextComponent{
    93  		stub:     stub,
    94  		payloads: make(map[string]payload.Payload),
    95  		untracks: make(map[string]struct{}),
    96  	}
    97  }
    98  
    99  func (c *stubContextComponent) Get(class, id string) (*payload.Payload, error) {
   100  	c.stub.AddCall("Get", class, id)
   101  	if err := c.stub.NextErr(); err != nil {
   102  		return nil, errors.Trace(err)
   103  	}
   104  
   105  	fullID := payload.BuildID(class, id)
   106  	info, ok := c.payloads[fullID]
   107  	if !ok {
   108  		return nil, errors.NotFoundf(id)
   109  	}
   110  	return &info, nil
   111  }
   112  
   113  func (c *stubContextComponent) List() ([]string, error) {
   114  	c.stub.AddCall("List")
   115  	if err := c.stub.NextErr(); err != nil {
   116  		return nil, errors.Trace(err)
   117  	}
   118  
   119  	var fullIDs []string
   120  	for k := range c.payloads {
   121  		fullIDs = append(fullIDs, k)
   122  	}
   123  	return fullIDs, nil
   124  }
   125  
   126  func (c *stubContextComponent) Track(pl payload.Payload) error {
   127  	c.stub.AddCall("Track", pl)
   128  	if err := c.stub.NextErr(); err != nil {
   129  		return errors.Trace(err)
   130  	}
   131  
   132  	c.payloads[pl.FullID()] = pl
   133  	return nil
   134  }
   135  
   136  func (c *stubContextComponent) Untrack(class, id string) error {
   137  	c.stub.AddCall("Untrack", class, id)
   138  
   139  	if err := c.stub.NextErr(); err != nil {
   140  		return errors.Trace(err)
   141  	}
   142  
   143  	fullID := payload.BuildID(class, id)
   144  	c.untracks[fullID] = struct{}{}
   145  	return nil
   146  }
   147  
   148  func (c *stubContextComponent) SetStatus(class, id, status string) error {
   149  	c.stub.AddCall("SetStatus", class, id, status)
   150  	if err := c.stub.NextErr(); err != nil {
   151  		return errors.Trace(err)
   152  	}
   153  
   154  	fullID := payload.BuildID(class, id)
   155  	pl := c.payloads[fullID]
   156  	pl.Status = status
   157  	return nil
   158  }
   159  
   160  func (c *stubContextComponent) Flush() error {
   161  	c.stub.AddCall("Flush")
   162  	if err := c.stub.NextErr(); err != nil {
   163  		return errors.Trace(err)
   164  	}
   165  
   166  	return nil
   167  }
   168  
   169  type stubAPIClient struct {
   170  	stub *testing.Stub
   171  	// TODO(ericsnow) Use id for the key rather than Info.ID().
   172  	payloads map[string]payload.Payload
   173  }
   174  
   175  func newStubAPIClient(stub *testing.Stub) *stubAPIClient {
   176  	return &stubAPIClient{
   177  		stub:     stub,
   178  		payloads: make(map[string]payload.Payload),
   179  	}
   180  }
   181  
   182  func (c *stubAPIClient) setNew(fullIDs ...string) []payload.Payload {
   183  	var payloads []payload.Payload
   184  	for _, id := range fullIDs {
   185  		name, pluginID := payload.ParseID(id)
   186  		if name == "" {
   187  			panic("missing name")
   188  		}
   189  		if pluginID == "" {
   190  			panic("missing id")
   191  		}
   192  		wl := payload.Payload{
   193  			PayloadClass: charm.PayloadClass{
   194  				Name: name,
   195  				Type: "myplugin",
   196  			},
   197  			ID:     pluginID,
   198  			Status: payload.StateRunning,
   199  		}
   200  		c.payloads[id] = wl
   201  		payloads = append(payloads, wl)
   202  	}
   203  	return payloads
   204  }
   205  
   206  func (c *stubAPIClient) List(fullIDs ...string) ([]payload.Result, error) {
   207  	c.stub.AddCall("List", fullIDs)
   208  	if err := c.stub.NextErr(); err != nil {
   209  		return nil, errors.Trace(err)
   210  	}
   211  
   212  	var results []payload.Result
   213  	if fullIDs == nil {
   214  		for id, pl := range c.payloads {
   215  			results = append(results, payload.Result{
   216  				ID:      id,
   217  				Payload: &payload.FullPayloadInfo{Payload: pl},
   218  			})
   219  		}
   220  	} else {
   221  		for _, id := range fullIDs {
   222  			pl, ok := c.payloads[id]
   223  			if !ok {
   224  				return nil, errors.NotFoundf("pl %q", id)
   225  			}
   226  			results = append(results, payload.Result{
   227  				ID:      id,
   228  				Payload: &payload.FullPayloadInfo{Payload: pl},
   229  			})
   230  		}
   231  	}
   232  	return results, nil
   233  }
   234  
   235  func (c *stubAPIClient) Track(payloads ...payload.Payload) ([]payload.Result, error) {
   236  	c.stub.AddCall("Track", payloads)
   237  	if err := c.stub.NextErr(); err != nil {
   238  		return nil, errors.Trace(err)
   239  	}
   240  
   241  	var results []payload.Result
   242  	for _, pl := range payloads {
   243  		id := pl.FullID()
   244  		c.payloads[id] = pl
   245  		results = append(results, payload.Result{
   246  			ID:      id,
   247  			Payload: &payload.FullPayloadInfo{Payload: pl},
   248  		})
   249  	}
   250  	return results, nil
   251  }
   252  
   253  func (c *stubAPIClient) Untrack(fullIDs ...string) ([]payload.Result, error) {
   254  	c.stub.AddCall("Untrack", fullIDs)
   255  	if err := c.stub.NextErr(); err != nil {
   256  		return nil, errors.Trace(err)
   257  	}
   258  
   259  	errs := []payload.Result{}
   260  	for _, id := range fullIDs {
   261  		delete(c.payloads, id)
   262  		errs = append(errs, payload.Result{ID: id})
   263  	}
   264  	return errs, nil
   265  }
   266  
   267  func (c *stubAPIClient) SetStatus(status string, fullIDs ...string) ([]payload.Result, error) {
   268  	c.stub.AddCall("SetStatus", status, fullIDs)
   269  	if err := c.stub.NextErr(); err != nil {
   270  		return nil, errors.Trace(err)
   271  	}
   272  
   273  	errs := []payload.Result{}
   274  	for _, id := range fullIDs {
   275  		pl := c.payloads[id]
   276  		pl.Status = status
   277  		errs = append(errs, payload.Result{ID: id})
   278  	}
   279  
   280  	return errs, nil
   281  }