github.com/bigcommerce/nomad@v0.9.3-bc/client/vaultclient/vaultclient_testing.go (about) 1 package vaultclient 2 3 import ( 4 "sync" 5 6 "github.com/hashicorp/nomad/helper/uuid" 7 "github.com/hashicorp/nomad/nomad/structs" 8 vaultapi "github.com/hashicorp/vault/api" 9 ) 10 11 // MockVaultClient is used for testing the vaultclient integration and is safe 12 // for concurrent access. 13 type MockVaultClient struct { 14 // stoppedTokens tracks the tokens that have stopped renewing 15 stoppedTokens []string 16 17 // renewTokens are the tokens that have been renewed and their error 18 // channels 19 renewTokens map[string]chan error 20 21 // renewTokenErrors is used to return an error when the RenewToken is called 22 // with the given token 23 renewTokenErrors map[string]error 24 25 // deriveTokenErrors maps an allocation ID and tasks to an error when the 26 // token is derived 27 deriveTokenErrors map[string]map[string]error 28 29 // DeriveTokenFn allows the caller to control the DeriveToken function. If 30 // not set an error is returned if found in DeriveTokenErrors and otherwise 31 // a token is generated and returned 32 DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error) 33 34 mu sync.Mutex 35 } 36 37 // NewMockVaultClient returns a MockVaultClient for testing 38 func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} } 39 40 func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) { 41 vc.mu.Lock() 42 defer vc.mu.Unlock() 43 44 if vc.DeriveTokenFn != nil { 45 return vc.DeriveTokenFn(a, tasks) 46 } 47 48 tokens := make(map[string]string, len(tasks)) 49 for _, task := range tasks { 50 if tasks, ok := vc.deriveTokenErrors[a.ID]; ok { 51 if err, ok := tasks[task]; ok { 52 return nil, err 53 } 54 } 55 56 tokens[task] = uuid.Generate() 57 } 58 59 return tokens, nil 60 } 61 62 func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) { 63 vc.mu.Lock() 64 defer vc.mu.Unlock() 65 66 if vc.deriveTokenErrors == nil { 67 vc.deriveTokenErrors = make(map[string]map[string]error, 10) 68 } 69 70 if _, ok := vc.renewTokenErrors[allocID]; !ok { 71 vc.deriveTokenErrors[allocID] = make(map[string]error, 10) 72 } 73 74 for _, task := range tasks { 75 vc.deriveTokenErrors[allocID][task] = err 76 } 77 } 78 79 func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) { 80 vc.mu.Lock() 81 defer vc.mu.Unlock() 82 83 if err, ok := vc.renewTokenErrors[token]; ok { 84 return nil, err 85 } 86 87 renewCh := make(chan error) 88 if vc.renewTokens == nil { 89 vc.renewTokens = make(map[string]chan error, 10) 90 } 91 vc.renewTokens[token] = renewCh 92 return renewCh, nil 93 } 94 95 func (vc *MockVaultClient) SetRenewTokenError(token string, err error) { 96 vc.mu.Lock() 97 defer vc.mu.Unlock() 98 99 if vc.renewTokenErrors == nil { 100 vc.renewTokenErrors = make(map[string]error, 10) 101 } 102 103 vc.renewTokenErrors[token] = err 104 } 105 106 func (vc *MockVaultClient) StopRenewToken(token string) error { 107 vc.mu.Lock() 108 defer vc.mu.Unlock() 109 110 vc.stoppedTokens = append(vc.stoppedTokens, token) 111 return nil 112 } 113 114 func (vc *MockVaultClient) RenewLease(leaseId string, interval int) (<-chan error, error) { 115 return nil, nil 116 } 117 func (vc *MockVaultClient) StopRenewLease(leaseId string) error { return nil } 118 func (vc *MockVaultClient) Start() {} 119 func (vc *MockVaultClient) Stop() {} 120 func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil } 121 122 // StoppedTokens tracks the tokens that have stopped renewing 123 func (vc *MockVaultClient) StoppedTokens() []string { 124 vc.mu.Lock() 125 defer vc.mu.Unlock() 126 return vc.stoppedTokens 127 } 128 129 // RenewTokens are the tokens that have been renewed and their error 130 // channels 131 func (vc *MockVaultClient) RenewTokens() map[string]chan error { 132 vc.mu.Lock() 133 defer vc.mu.Unlock() 134 return vc.renewTokens 135 } 136 137 // RenewTokenErrors is used to return an error when the RenewToken is called 138 // with the given token 139 func (vc *MockVaultClient) RenewTokenErrors() map[string]error { 140 vc.mu.Lock() 141 defer vc.mu.Unlock() 142 return vc.renewTokenErrors 143 } 144 145 // DeriveTokenErrors maps an allocation ID and tasks to an error when the 146 // token is derived 147 func (vc *MockVaultClient) DeriveTokenErrors() map[string]map[string]error { 148 vc.mu.Lock() 149 defer vc.mu.Unlock() 150 return vc.deriveTokenErrors 151 }