github.com/hhrutter/nomad@v0.6.0-rc2.0.20170723054333-80c4b03f0705/client/vaultclient/vaultclient_testing.go (about) 1 package vaultclient 2 3 import ( 4 "github.com/hashicorp/nomad/nomad/structs" 5 vaultapi "github.com/hashicorp/vault/api" 6 ) 7 8 // MockVaultClient is used for testing the vaultclient integration 9 type MockVaultClient struct { 10 // StoppedTokens tracks the tokens that have stopped renewing 11 StoppedTokens []string 12 13 // RenewTokens are the tokens that have been renewed and their error 14 // channels 15 RenewTokens map[string]chan error 16 17 // RenewTokenErrors is used to return an error when the RenewToken is called 18 // with the given token 19 RenewTokenErrors map[string]error 20 21 // DeriveTokenErrors maps an allocation ID and tasks to an error when the 22 // token is derived 23 DeriveTokenErrors map[string]map[string]error 24 25 // DeriveTokenFn allows the caller to control the DeriveToken function. If 26 // not set an error is returned if found in DeriveTokenErrors and otherwise 27 // a token is generated and returned 28 DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error) 29 } 30 31 // NewMockVaultClient returns a MockVaultClient for testing 32 func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} } 33 34 func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) { 35 if vc.DeriveTokenFn != nil { 36 return vc.DeriveTokenFn(a, tasks) 37 } 38 39 tokens := make(map[string]string, len(tasks)) 40 for _, task := range tasks { 41 if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok { 42 if err, ok := tasks[task]; ok { 43 return nil, err 44 } 45 } 46 47 tokens[task] = structs.GenerateUUID() 48 } 49 50 return tokens, nil 51 } 52 53 func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) { 54 if vc.DeriveTokenErrors == nil { 55 vc.DeriveTokenErrors = make(map[string]map[string]error, 10) 56 } 57 58 if _, ok := vc.RenewTokenErrors[allocID]; !ok { 59 vc.DeriveTokenErrors[allocID] = make(map[string]error, 10) 60 } 61 62 for _, task := range tasks { 63 vc.DeriveTokenErrors[allocID][task] = err 64 } 65 } 66 67 func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) { 68 if err, ok := vc.RenewTokenErrors[token]; ok { 69 return nil, err 70 } 71 72 renewCh := make(chan error) 73 if vc.RenewTokens == nil { 74 vc.RenewTokens = make(map[string]chan error, 10) 75 } 76 vc.RenewTokens[token] = renewCh 77 return renewCh, nil 78 } 79 80 func (vc *MockVaultClient) SetRenewTokenError(token string, err error) { 81 if vc.RenewTokenErrors == nil { 82 vc.RenewTokenErrors = make(map[string]error, 10) 83 } 84 85 vc.RenewTokenErrors[token] = err 86 } 87 88 func (vc *MockVaultClient) StopRenewToken(token string) error { 89 vc.StoppedTokens = append(vc.StoppedTokens, token) 90 return nil 91 } 92 93 func (vc *MockVaultClient) RenewLease(leaseId string, interval int) (<-chan error, error) { 94 return nil, nil 95 } 96 func (vc *MockVaultClient) StopRenewLease(leaseId string) error { return nil } 97 func (vc *MockVaultClient) Start() {} 98 func (vc *MockVaultClient) Stop() {} 99 func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil }