github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/client/vaultclient/vaultclient_testing.go (about) 1 package vaultclient 2 3 import ( 4 "sync" 5 6 "github.com/hashicorp/nomad/helper/uuid" 7 "github.com/hashicorp/nomad/nomad/structs" 8 vaultapi "github.com/hashicorp/vault/api" 9 ) 10 11 // MockVaultClient is used for testing the vaultclient integration and is safe 12 // for concurrent access. 13 type MockVaultClient struct { 14 // stoppedTokens tracks the tokens that have stopped renewing 15 stoppedTokens []string 16 17 // renewTokens are the tokens that have been renewed and their error 18 // channels 19 renewTokens map[string]chan error 20 21 // renewTokenErrors is used to return an error when the RenewToken is called 22 // with the given token 23 renewTokenErrors map[string]error 24 25 // deriveTokenErrors maps an allocation ID and tasks to an error when the 26 // token is derived 27 deriveTokenErrors map[string]map[string]error 28 29 // DeriveTokenFn allows the caller to control the DeriveToken function. If 30 // not set an error is returned if found in DeriveTokenErrors and otherwise 31 // a token is generated and returned 32 DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error) 33 34 mu sync.Mutex 35 } 36 37 // NewMockVaultClient returns a MockVaultClient for testing 38 func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} } 39 40 func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) { 41 vc.mu.Lock() 42 defer vc.mu.Unlock() 43 44 if vc.DeriveTokenFn != nil { 45 return vc.DeriveTokenFn(a, tasks) 46 } 47 48 tokens := make(map[string]string, len(tasks)) 49 for _, task := range tasks { 50 if tasks, ok := vc.deriveTokenErrors[a.ID]; ok { 51 if err, ok := tasks[task]; ok { 52 return nil, err 53 } 54 } 55 56 tokens[task] = uuid.Generate() 57 } 58 59 return tokens, nil 60 } 61 62 func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) { 63 vc.mu.Lock() 64 defer vc.mu.Unlock() 65 66 if vc.deriveTokenErrors == nil { 67 vc.deriveTokenErrors = make(map[string]map[string]error, 10) 68 } 69 70 if _, ok := vc.deriveTokenErrors[allocID]; !ok { 71 vc.deriveTokenErrors[allocID] = make(map[string]error, 10) 72 } 73 74 for _, task := range tasks { 75 vc.deriveTokenErrors[allocID][task] = err 76 } 77 } 78 79 func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) { 80 vc.mu.Lock() 81 defer vc.mu.Unlock() 82 83 if err, ok := vc.renewTokenErrors[token]; ok { 84 return nil, err 85 } 86 87 renewCh := make(chan error) 88 if vc.renewTokens == nil { 89 vc.renewTokens = make(map[string]chan error, 10) 90 } 91 vc.renewTokens[token] = renewCh 92 return renewCh, nil 93 } 94 95 func (vc *MockVaultClient) SetRenewTokenError(token string, err error) { 96 vc.mu.Lock() 97 defer vc.mu.Unlock() 98 99 if vc.renewTokenErrors == nil { 100 vc.renewTokenErrors = make(map[string]error, 10) 101 } 102 103 vc.renewTokenErrors[token] = err 104 } 105 106 func (vc *MockVaultClient) StopRenewToken(token string) error { 107 vc.mu.Lock() 108 defer vc.mu.Unlock() 109 110 vc.stoppedTokens = append(vc.stoppedTokens, token) 111 return nil 112 } 113 114 func (vc *MockVaultClient) Start() {} 115 116 func (vc *MockVaultClient) Stop() {} 117 118 func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil } 119 120 // StoppedTokens tracks the tokens that have stopped renewing 121 func (vc *MockVaultClient) StoppedTokens() []string { 122 vc.mu.Lock() 123 defer vc.mu.Unlock() 124 return vc.stoppedTokens 125 } 126 127 // RenewTokens are the tokens that have been renewed and their error 128 // channels 129 func (vc *MockVaultClient) RenewTokens() map[string]chan error { 130 vc.mu.Lock() 131 defer vc.mu.Unlock() 132 return vc.renewTokens 133 } 134 135 // RenewTokenErrors is used to return an error when the RenewToken is called 136 // with the given token 137 func (vc *MockVaultClient) RenewTokenErrors() map[string]error { 138 vc.mu.Lock() 139 defer vc.mu.Unlock() 140 return vc.renewTokenErrors 141 } 142 143 // DeriveTokenErrors maps an allocation ID and tasks to an error when the 144 // token is derived 145 func (vc *MockVaultClient) DeriveTokenErrors() map[string]map[string]error { 146 vc.mu.Lock() 147 defer vc.mu.Unlock() 148 return vc.deriveTokenErrors 149 }