github.com/anuvu/nomad@v0.8.7-atom1/client/vaultclient/vaultclient_testing.go (about) 1 package vaultclient 2 3 import ( 4 "github.com/hashicorp/nomad/helper/uuid" 5 "github.com/hashicorp/nomad/nomad/structs" 6 vaultapi "github.com/hashicorp/vault/api" 7 ) 8 9 // MockVaultClient is used for testing the vaultclient integration 10 type MockVaultClient struct { 11 // StoppedTokens tracks the tokens that have stopped renewing 12 StoppedTokens []string 13 14 // RenewTokens are the tokens that have been renewed and their error 15 // channels 16 RenewTokens map[string]chan error 17 18 // RenewTokenErrors is used to return an error when the RenewToken is called 19 // with the given token 20 RenewTokenErrors map[string]error 21 22 // DeriveTokenErrors maps an allocation ID and tasks to an error when the 23 // token is derived 24 DeriveTokenErrors map[string]map[string]error 25 26 // DeriveTokenFn allows the caller to control the DeriveToken function. If 27 // not set an error is returned if found in DeriveTokenErrors and otherwise 28 // a token is generated and returned 29 DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error) 30 } 31 32 // NewMockVaultClient returns a MockVaultClient for testing 33 func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} } 34 35 func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) { 36 if vc.DeriveTokenFn != nil { 37 return vc.DeriveTokenFn(a, tasks) 38 } 39 40 tokens := make(map[string]string, len(tasks)) 41 for _, task := range tasks { 42 if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok { 43 if err, ok := tasks[task]; ok { 44 return nil, err 45 } 46 } 47 48 tokens[task] = uuid.Generate() 49 } 50 51 return tokens, nil 52 } 53 54 func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) { 55 if vc.DeriveTokenErrors == nil { 56 vc.DeriveTokenErrors = make(map[string]map[string]error, 10) 57 } 58 59 if _, ok := vc.RenewTokenErrors[allocID]; !ok { 60 vc.DeriveTokenErrors[allocID] = make(map[string]error, 10) 61 } 62 63 for _, task := range tasks { 64 vc.DeriveTokenErrors[allocID][task] = err 65 } 66 } 67 68 func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) { 69 if err, ok := vc.RenewTokenErrors[token]; ok { 70 return nil, err 71 } 72 73 renewCh := make(chan error) 74 if vc.RenewTokens == nil { 75 vc.RenewTokens = make(map[string]chan error, 10) 76 } 77 vc.RenewTokens[token] = renewCh 78 return renewCh, nil 79 } 80 81 func (vc *MockVaultClient) SetRenewTokenError(token string, err error) { 82 if vc.RenewTokenErrors == nil { 83 vc.RenewTokenErrors = make(map[string]error, 10) 84 } 85 86 vc.RenewTokenErrors[token] = err 87 } 88 89 func (vc *MockVaultClient) StopRenewToken(token string) error { 90 vc.StoppedTokens = append(vc.StoppedTokens, token) 91 return nil 92 } 93 94 func (vc *MockVaultClient) RenewLease(leaseId string, interval int) (<-chan error, error) { 95 return nil, nil 96 } 97 func (vc *MockVaultClient) StopRenewLease(leaseId string) error { return nil } 98 func (vc *MockVaultClient) Start() {} 99 func (vc *MockVaultClient) Stop() {} 100 func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil }