github.com/wallyworld/juju@v0.0.0-20161013125918-6cf1bc9d917a/payload/context/base_test.go (about) 1 // Copyright 2015 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package context_test 5 6 import ( 7 "reflect" 8 9 "github.com/juju/errors" 10 "github.com/juju/testing" 11 gc "gopkg.in/check.v1" 12 "gopkg.in/juju/charm.v6-unstable" 13 14 "github.com/juju/juju/payload" 15 "github.com/juju/juju/payload/context" 16 jujuctesting "github.com/juju/juju/worker/uniter/runner/jujuc/testing" 17 ) 18 19 type baseSuite struct { 20 jujuctesting.ContextSuite 21 payload payload.Payload 22 } 23 24 func (s *baseSuite) SetUpTest(c *gc.C) { 25 s.ContextSuite.SetUpTest(c) 26 27 s.payload = s.newPayload("payload A", "docker", "", "") 28 } 29 30 func (s *baseSuite) newPayload(name, ptype, id, status string) payload.Payload { 31 pl := payload.Payload{ 32 PayloadClass: charm.PayloadClass{ 33 Name: name, 34 Type: ptype, 35 }, 36 ID: id, 37 Status: status, 38 Unit: "a-application/0", 39 } 40 return pl 41 } 42 43 func (s *baseSuite) NewHookContext() (*stubHookContext, *jujuctesting.ContextInfo) { 44 ctx, info := s.ContextSuite.NewHookContext() 45 return &stubHookContext{ctx}, info 46 } 47 48 func checkPayloads(c *gc.C, payloads, expected []payload.Payload) { 49 if !c.Check(payloads, gc.HasLen, len(expected)) { 50 return 51 } 52 for _, wl := range payloads { 53 matched := false 54 for _, expPayload := range expected { 55 if reflect.DeepEqual(wl, expPayload) { 56 matched = true 57 break 58 } 59 } 60 if !matched { 61 c.Errorf("%#v != %#v", payloads, expected) 62 return 63 } 64 } 65 } 66 67 type stubHookContext struct { 68 *jujuctesting.Context 69 } 70 71 func (c stubHookContext) Component(name string) (context.Component, error) { 72 found, err := c.Context.Component(name) 73 if err != nil { 74 return nil, errors.Trace(err) 75 } 76 compCtx, ok := found.(context.Component) 77 if !ok && found != nil { 78 return nil, errors.Errorf("wrong component context type registered: %T", found) 79 } 80 return compCtx, nil 81 } 82 83 var _ context.Component = (*stubContextComponent)(nil) 84 85 type stubContextComponent struct { 86 stub *testing.Stub 87 payloads map[string]payload.Payload 88 untracks map[string]struct{} 89 } 90 91 func newStubContextComponent(stub *testing.Stub) *stubContextComponent { 92 return &stubContextComponent{ 93 stub: stub, 94 payloads: make(map[string]payload.Payload), 95 untracks: make(map[string]struct{}), 96 } 97 } 98 99 func (c *stubContextComponent) Get(class, id string) (*payload.Payload, error) { 100 c.stub.AddCall("Get", class, id) 101 if err := c.stub.NextErr(); err != nil { 102 return nil, errors.Trace(err) 103 } 104 105 fullID := payload.BuildID(class, id) 106 info, ok := c.payloads[fullID] 107 if !ok { 108 return nil, errors.NotFoundf(id) 109 } 110 return &info, nil 111 } 112 113 func (c *stubContextComponent) List() ([]string, error) { 114 c.stub.AddCall("List") 115 if err := c.stub.NextErr(); err != nil { 116 return nil, errors.Trace(err) 117 } 118 119 var fullIDs []string 120 for k := range c.payloads { 121 fullIDs = append(fullIDs, k) 122 } 123 return fullIDs, nil 124 } 125 126 func (c *stubContextComponent) Track(pl payload.Payload) error { 127 c.stub.AddCall("Track", pl) 128 if err := c.stub.NextErr(); err != nil { 129 return errors.Trace(err) 130 } 131 132 c.payloads[pl.FullID()] = pl 133 return nil 134 } 135 136 func (c *stubContextComponent) Untrack(class, id string) error { 137 c.stub.AddCall("Untrack", class, id) 138 139 if err := c.stub.NextErr(); err != nil { 140 return errors.Trace(err) 141 } 142 143 fullID := payload.BuildID(class, id) 144 c.untracks[fullID] = struct{}{} 145 return nil 146 } 147 148 func (c *stubContextComponent) SetStatus(class, id, status string) error { 149 c.stub.AddCall("SetStatus", class, id, status) 150 if err := c.stub.NextErr(); err != nil { 151 return errors.Trace(err) 152 } 153 154 fullID := payload.BuildID(class, id) 155 pl := c.payloads[fullID] 156 pl.Status = status 157 return nil 158 } 159 160 func (c *stubContextComponent) Flush() error { 161 c.stub.AddCall("Flush") 162 if err := c.stub.NextErr(); err != nil { 163 return errors.Trace(err) 164 } 165 166 return nil 167 } 168 169 type stubAPIClient struct { 170 stub *testing.Stub 171 // TODO(ericsnow) Use id for the key rather than Info.ID(). 172 payloads map[string]payload.Payload 173 } 174 175 func newStubAPIClient(stub *testing.Stub) *stubAPIClient { 176 return &stubAPIClient{ 177 stub: stub, 178 payloads: make(map[string]payload.Payload), 179 } 180 } 181 182 func (c *stubAPIClient) setNew(fullIDs ...string) []payload.Payload { 183 var payloads []payload.Payload 184 for _, id := range fullIDs { 185 name, pluginID := payload.ParseID(id) 186 if name == "" { 187 panic("missing name") 188 } 189 if pluginID == "" { 190 panic("missing id") 191 } 192 wl := payload.Payload{ 193 PayloadClass: charm.PayloadClass{ 194 Name: name, 195 Type: "myplugin", 196 }, 197 ID: pluginID, 198 Status: payload.StateRunning, 199 } 200 c.payloads[id] = wl 201 payloads = append(payloads, wl) 202 } 203 return payloads 204 } 205 206 func (c *stubAPIClient) List(fullIDs ...string) ([]payload.Result, error) { 207 c.stub.AddCall("List", fullIDs) 208 if err := c.stub.NextErr(); err != nil { 209 return nil, errors.Trace(err) 210 } 211 212 var results []payload.Result 213 if fullIDs == nil { 214 for id, pl := range c.payloads { 215 results = append(results, payload.Result{ 216 ID: id, 217 Payload: &payload.FullPayloadInfo{Payload: pl}, 218 }) 219 } 220 } else { 221 for _, id := range fullIDs { 222 pl, ok := c.payloads[id] 223 if !ok { 224 return nil, errors.NotFoundf("pl %q", id) 225 } 226 results = append(results, payload.Result{ 227 ID: id, 228 Payload: &payload.FullPayloadInfo{Payload: pl}, 229 }) 230 } 231 } 232 return results, nil 233 } 234 235 func (c *stubAPIClient) Track(payloads ...payload.Payload) ([]payload.Result, error) { 236 c.stub.AddCall("Track", payloads) 237 if err := c.stub.NextErr(); err != nil { 238 return nil, errors.Trace(err) 239 } 240 241 var results []payload.Result 242 for _, pl := range payloads { 243 id := pl.FullID() 244 c.payloads[id] = pl 245 results = append(results, payload.Result{ 246 ID: id, 247 Payload: &payload.FullPayloadInfo{Payload: pl}, 248 }) 249 } 250 return results, nil 251 } 252 253 func (c *stubAPIClient) Untrack(fullIDs ...string) ([]payload.Result, error) { 254 c.stub.AddCall("Untrack", fullIDs) 255 if err := c.stub.NextErr(); err != nil { 256 return nil, errors.Trace(err) 257 } 258 259 errs := []payload.Result{} 260 for _, id := range fullIDs { 261 delete(c.payloads, id) 262 errs = append(errs, payload.Result{ID: id}) 263 } 264 return errs, nil 265 } 266 267 func (c *stubAPIClient) SetStatus(status string, fullIDs ...string) ([]payload.Result, error) { 268 c.stub.AddCall("SetStatus", status, fullIDs) 269 if err := c.stub.NextErr(); err != nil { 270 return nil, errors.Trace(err) 271 } 272 273 errs := []payload.Result{} 274 for _, id := range fullIDs { 275 pl := c.payloads[id] 276 pl.Status = status 277 errs = append(errs, payload.Result{ID: id}) 278 } 279 280 return errs, nil 281 }