github.com/bigcommerce/nomad@v0.9.3-bc/client/vaultclient/vaultclient_testing.go (about)

     1  package vaultclient
     2  
     3  import (
     4  	"sync"
     5  
     6  	"github.com/hashicorp/nomad/helper/uuid"
     7  	"github.com/hashicorp/nomad/nomad/structs"
     8  	vaultapi "github.com/hashicorp/vault/api"
     9  )
    10  
    11  // MockVaultClient is used for testing the vaultclient integration and is safe
    12  // for concurrent access.
    13  type MockVaultClient struct {
    14  	// stoppedTokens tracks the tokens that have stopped renewing
    15  	stoppedTokens []string
    16  
    17  	// renewTokens are the tokens that have been renewed and their error
    18  	// channels
    19  	renewTokens map[string]chan error
    20  
    21  	// renewTokenErrors is used to return an error when the RenewToken is called
    22  	// with the given token
    23  	renewTokenErrors map[string]error
    24  
    25  	// deriveTokenErrors maps an allocation ID and tasks to an error when the
    26  	// token is derived
    27  	deriveTokenErrors map[string]map[string]error
    28  
    29  	// DeriveTokenFn allows the caller to control the DeriveToken function. If
    30  	// not set an error is returned if found in DeriveTokenErrors and otherwise
    31  	// a token is generated and returned
    32  	DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error)
    33  
    34  	mu sync.Mutex
    35  }
    36  
    37  // NewMockVaultClient returns a MockVaultClient for testing
    38  func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} }
    39  
    40  func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
    41  	vc.mu.Lock()
    42  	defer vc.mu.Unlock()
    43  
    44  	if vc.DeriveTokenFn != nil {
    45  		return vc.DeriveTokenFn(a, tasks)
    46  	}
    47  
    48  	tokens := make(map[string]string, len(tasks))
    49  	for _, task := range tasks {
    50  		if tasks, ok := vc.deriveTokenErrors[a.ID]; ok {
    51  			if err, ok := tasks[task]; ok {
    52  				return nil, err
    53  			}
    54  		}
    55  
    56  		tokens[task] = uuid.Generate()
    57  	}
    58  
    59  	return tokens, nil
    60  }
    61  
    62  func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
    63  	vc.mu.Lock()
    64  	defer vc.mu.Unlock()
    65  
    66  	if vc.deriveTokenErrors == nil {
    67  		vc.deriveTokenErrors = make(map[string]map[string]error, 10)
    68  	}
    69  
    70  	if _, ok := vc.renewTokenErrors[allocID]; !ok {
    71  		vc.deriveTokenErrors[allocID] = make(map[string]error, 10)
    72  	}
    73  
    74  	for _, task := range tasks {
    75  		vc.deriveTokenErrors[allocID][task] = err
    76  	}
    77  }
    78  
    79  func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) {
    80  	vc.mu.Lock()
    81  	defer vc.mu.Unlock()
    82  
    83  	if err, ok := vc.renewTokenErrors[token]; ok {
    84  		return nil, err
    85  	}
    86  
    87  	renewCh := make(chan error)
    88  	if vc.renewTokens == nil {
    89  		vc.renewTokens = make(map[string]chan error, 10)
    90  	}
    91  	vc.renewTokens[token] = renewCh
    92  	return renewCh, nil
    93  }
    94  
    95  func (vc *MockVaultClient) SetRenewTokenError(token string, err error) {
    96  	vc.mu.Lock()
    97  	defer vc.mu.Unlock()
    98  
    99  	if vc.renewTokenErrors == nil {
   100  		vc.renewTokenErrors = make(map[string]error, 10)
   101  	}
   102  
   103  	vc.renewTokenErrors[token] = err
   104  }
   105  
   106  func (vc *MockVaultClient) StopRenewToken(token string) error {
   107  	vc.mu.Lock()
   108  	defer vc.mu.Unlock()
   109  
   110  	vc.stoppedTokens = append(vc.stoppedTokens, token)
   111  	return nil
   112  }
   113  
   114  func (vc *MockVaultClient) RenewLease(leaseId string, interval int) (<-chan error, error) {
   115  	return nil, nil
   116  }
   117  func (vc *MockVaultClient) StopRenewLease(leaseId string) error                   { return nil }
   118  func (vc *MockVaultClient) Start()                                                {}
   119  func (vc *MockVaultClient) Stop()                                                 {}
   120  func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil }
   121  
   122  // StoppedTokens tracks the tokens that have stopped renewing
   123  func (vc *MockVaultClient) StoppedTokens() []string {
   124  	vc.mu.Lock()
   125  	defer vc.mu.Unlock()
   126  	return vc.stoppedTokens
   127  }
   128  
   129  // RenewTokens are the tokens that have been renewed and their error
   130  // channels
   131  func (vc *MockVaultClient) RenewTokens() map[string]chan error {
   132  	vc.mu.Lock()
   133  	defer vc.mu.Unlock()
   134  	return vc.renewTokens
   135  }
   136  
   137  // RenewTokenErrors is used to return an error when the RenewToken is called
   138  // with the given token
   139  func (vc *MockVaultClient) RenewTokenErrors() map[string]error {
   140  	vc.mu.Lock()
   141  	defer vc.mu.Unlock()
   142  	return vc.renewTokenErrors
   143  }
   144  
   145  // DeriveTokenErrors maps an allocation ID and tasks to an error when the
   146  // token is derived
   147  func (vc *MockVaultClient) DeriveTokenErrors() map[string]map[string]error {
   148  	vc.mu.Lock()
   149  	defer vc.mu.Unlock()
   150  	return vc.deriveTokenErrors
   151  }