github.com/anuvu/nomad@v0.8.7-atom1/client/vaultclient/vaultclient_testing.go (about)

     1  package vaultclient
     2  
     3  import (
     4  	"github.com/hashicorp/nomad/helper/uuid"
     5  	"github.com/hashicorp/nomad/nomad/structs"
     6  	vaultapi "github.com/hashicorp/vault/api"
     7  )
     8  
// MockVaultClient is used for testing the vaultclient integration. It records
// the calls made to it and lets a test inject errors or replace token
// derivation entirely. All fields start at their zero values; the map fields
// are allocated on first use by the methods that write to them.
//
// The struct holds no mutex. NOTE(review): presumably not safe for concurrent
// use — confirm before sharing one instance across goroutines.
type MockVaultClient struct {
	// StoppedTokens tracks the tokens that have stopped renewing, appended in
	// the order StopRenewToken was called.
	StoppedTokens []string

	// RenewTokens are the tokens that have been renewed and their error
	// channels. RenewToken stores the bidirectional (unbuffered) channel here
	// and hands the caller only a receive-only view of it.
	RenewTokens map[string]chan error

	// RenewTokenErrors is used to return an error when the RenewToken is called
	// with the given token. In that case RenewToken returns a nil channel and
	// records nothing in RenewTokens.
	RenewTokenErrors map[string]error

	// DeriveTokenErrors maps an allocation ID and tasks to an error when the
	// token is derived: the outer key is the allocation ID, the inner key is
	// the task name.
	DeriveTokenErrors map[string]map[string]error

	// DeriveTokenFn allows the caller to control the DeriveToken function. If
	// not set an error is returned if found in DeriveTokenErrors and otherwise
	// a token is generated and returned. When set, DeriveTokenErrors is not
	// consulted at all.
	DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error)
}
    31  
    32  // NewMockVaultClient returns a MockVaultClient for testing
    33  func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} }
    34  
    35  func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
    36  	if vc.DeriveTokenFn != nil {
    37  		return vc.DeriveTokenFn(a, tasks)
    38  	}
    39  
    40  	tokens := make(map[string]string, len(tasks))
    41  	for _, task := range tasks {
    42  		if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok {
    43  			if err, ok := tasks[task]; ok {
    44  				return nil, err
    45  			}
    46  		}
    47  
    48  		tokens[task] = uuid.Generate()
    49  	}
    50  
    51  	return tokens, nil
    52  }
    53  
    54  func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
    55  	if vc.DeriveTokenErrors == nil {
    56  		vc.DeriveTokenErrors = make(map[string]map[string]error, 10)
    57  	}
    58  
    59  	if _, ok := vc.RenewTokenErrors[allocID]; !ok {
    60  		vc.DeriveTokenErrors[allocID] = make(map[string]error, 10)
    61  	}
    62  
    63  	for _, task := range tasks {
    64  		vc.DeriveTokenErrors[allocID][task] = err
    65  	}
    66  }
    67  
    68  func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) {
    69  	if err, ok := vc.RenewTokenErrors[token]; ok {
    70  		return nil, err
    71  	}
    72  
    73  	renewCh := make(chan error)
    74  	if vc.RenewTokens == nil {
    75  		vc.RenewTokens = make(map[string]chan error, 10)
    76  	}
    77  	vc.RenewTokens[token] = renewCh
    78  	return renewCh, nil
    79  }
    80  
    81  func (vc *MockVaultClient) SetRenewTokenError(token string, err error) {
    82  	if vc.RenewTokenErrors == nil {
    83  		vc.RenewTokenErrors = make(map[string]error, 10)
    84  	}
    85  
    86  	vc.RenewTokenErrors[token] = err
    87  }
    88  
    89  func (vc *MockVaultClient) StopRenewToken(token string) error {
    90  	vc.StoppedTokens = append(vc.StoppedTokens, token)
    91  	return nil
    92  }
    93  
    94  func (vc *MockVaultClient) RenewLease(leaseId string, interval int) (<-chan error, error) {
    95  	return nil, nil
    96  }
    97  func (vc *MockVaultClient) StopRenewLease(leaseId string) error                   { return nil }
    98  func (vc *MockVaultClient) Start()                                                {}
    99  func (vc *MockVaultClient) Stop()                                                 {}
   100  func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil }