github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/client/consul/identities_testing.go (about) 1 package consul 2 3 import ( 4 "sync" 5 6 "github.com/hashicorp/nomad/helper/uuid" 7 "github.com/hashicorp/nomad/nomad/structs" 8 ) 9 10 // MockServiceIdentitiesClient is used for testing the client for managing consul service 11 // identity tokens. 12 type MockServiceIdentitiesClient struct { 13 // deriveTokenErrors maps an allocation ID and tasks to an error when the 14 // token is derived 15 deriveTokenErrors map[string]map[string]error 16 17 // DeriveTokenFn allows the caller to control the DeriveToken function. If 18 // not set an error is returned if found in DeriveTokenErrors and otherwise 19 // a token is generated and returned 20 DeriveTokenFn TokenDeriverFunc 21 22 // lock around everything 23 lock sync.Mutex 24 } 25 26 var _ ServiceIdentityAPI = (*MockServiceIdentitiesClient)(nil) 27 28 // NewMockServiceIdentitiesClient returns a MockServiceIdentitiesClient for testing. 29 func NewMockServiceIdentitiesClient() *MockServiceIdentitiesClient { 30 return &MockServiceIdentitiesClient{ 31 deriveTokenErrors: make(map[string]map[string]error), 32 } 33 } 34 35 func (mtc *MockServiceIdentitiesClient) DeriveSITokens(alloc *structs.Allocation, tasks []string) (map[string]string, error) { 36 mtc.lock.Lock() 37 defer mtc.lock.Unlock() 38 39 // if the DeriveTokenFn is explicitly set, use that 40 if mtc.DeriveTokenFn != nil { 41 return mtc.DeriveTokenFn(alloc, tasks) 42 } 43 44 // generate a token for each task, unless the mock has an error ready for 45 // one or more of the tasks in which case return that 46 tokens := make(map[string]string, len(tasks)) 47 for _, task := range tasks { 48 if m, ok := mtc.deriveTokenErrors[alloc.ID]; ok { 49 if err, ok := m[task]; ok { 50 return nil, err 51 } 52 } 53 tokens[task] = uuid.Generate() 54 } 55 return tokens, nil 56 } 57 58 func (mtc *MockServiceIdentitiesClient) SetDeriveTokenError(allocID string, tasks []string, err error) { 59 mtc.lock.Lock() 60 defer mtc.lock.Unlock() 61 62 if _, ok := mtc.deriveTokenErrors[allocID]; !ok { 63 mtc.deriveTokenErrors[allocID] = make(map[string]error, 10) 64 } 65 66 for _, task := range tasks { 67 mtc.deriveTokenErrors[allocID][task] = err 68 } 69 } 70 71 func (mtc *MockServiceIdentitiesClient) DeriveTokenErrors() map[string]map[string]error { 72 mtc.lock.Lock() 73 defer mtc.lock.Unlock() 74 75 m := make(map[string]map[string]error) 76 for aID, tasks := range mtc.deriveTokenErrors { 77 for task, err := range tasks { 78 m[aID][task] = err 79 } 80 } 81 return m 82 }