github.com/hhrutter/nomad@v0.6.0-rc2.0.20170723054333-80c4b03f0705/client/vaultclient/vaultclient_testing.go (about)

     1  package vaultclient
     2  
     3  import (
     4  	"github.com/hashicorp/nomad/nomad/structs"
     5  	vaultapi "github.com/hashicorp/vault/api"
     6  )
     7  
     8  // MockVaultClient is used for testing the vaultclient integration
     9  type MockVaultClient struct {
    10  	// StoppedTokens tracks the tokens that have stopped renewing
    11  	StoppedTokens []string
    12  
    13  	// RenewTokens are the tokens that have been renewed and their error
    14  	// channels
    15  	RenewTokens map[string]chan error
    16  
    17  	// RenewTokenErrors is used to return an error when the RenewToken is called
    18  	// with the given token
    19  	RenewTokenErrors map[string]error
    20  
    21  	// DeriveTokenErrors maps an allocation ID and tasks to an error when the
    22  	// token is derived
    23  	DeriveTokenErrors map[string]map[string]error
    24  
    25  	// DeriveTokenFn allows the caller to control the DeriveToken function. If
    26  	// not set an error is returned if found in DeriveTokenErrors and otherwise
    27  	// a token is generated and returned
    28  	DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error)
    29  }
    30  
    31  // NewMockVaultClient returns a MockVaultClient for testing
    32  func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} }
    33  
    34  func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
    35  	if vc.DeriveTokenFn != nil {
    36  		return vc.DeriveTokenFn(a, tasks)
    37  	}
    38  
    39  	tokens := make(map[string]string, len(tasks))
    40  	for _, task := range tasks {
    41  		if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok {
    42  			if err, ok := tasks[task]; ok {
    43  				return nil, err
    44  			}
    45  		}
    46  
    47  		tokens[task] = structs.GenerateUUID()
    48  	}
    49  
    50  	return tokens, nil
    51  }
    52  
    53  func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
    54  	if vc.DeriveTokenErrors == nil {
    55  		vc.DeriveTokenErrors = make(map[string]map[string]error, 10)
    56  	}
    57  
    58  	if _, ok := vc.RenewTokenErrors[allocID]; !ok {
    59  		vc.DeriveTokenErrors[allocID] = make(map[string]error, 10)
    60  	}
    61  
    62  	for _, task := range tasks {
    63  		vc.DeriveTokenErrors[allocID][task] = err
    64  	}
    65  }
    66  
    67  func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) {
    68  	if err, ok := vc.RenewTokenErrors[token]; ok {
    69  		return nil, err
    70  	}
    71  
    72  	renewCh := make(chan error)
    73  	if vc.RenewTokens == nil {
    74  		vc.RenewTokens = make(map[string]chan error, 10)
    75  	}
    76  	vc.RenewTokens[token] = renewCh
    77  	return renewCh, nil
    78  }
    79  
    80  func (vc *MockVaultClient) SetRenewTokenError(token string, err error) {
    81  	if vc.RenewTokenErrors == nil {
    82  		vc.RenewTokenErrors = make(map[string]error, 10)
    83  	}
    84  
    85  	vc.RenewTokenErrors[token] = err
    86  }
    87  
    88  func (vc *MockVaultClient) StopRenewToken(token string) error {
    89  	vc.StoppedTokens = append(vc.StoppedTokens, token)
    90  	return nil
    91  }
    92  
    93  func (vc *MockVaultClient) RenewLease(leaseId string, interval int) (<-chan error, error) {
    94  	return nil, nil
    95  }
    96  func (vc *MockVaultClient) StopRenewLease(leaseId string) error                   { return nil }
    97  func (vc *MockVaultClient) Start()                                                {}
    98  func (vc *MockVaultClient) Stop()                                                 {}
    99  func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil }