github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/client/consul/identities_testing.go (about)

     1  package consul
     2  
     3  import (
     4  	"sync"
     5  
     6  	"github.com/hashicorp/nomad/helper/uuid"
     7  	"github.com/hashicorp/nomad/nomad/structs"
     8  )
     9  
// MockServiceIdentitiesClient is used for testing the client for managing consul service
// identity tokens.
type MockServiceIdentitiesClient struct {
	// deriveTokenErrors maps an allocation ID to a map of task name to the
	// error that DeriveSITokens returns when a token is requested for that
	// task.
	deriveTokenErrors map[string]map[string]error

	// DeriveTokenFn allows the caller to control the DeriveSITokens method. If
	// not set, an error is returned if one is found in deriveTokenErrors for a
	// requested task, and otherwise a token is generated and returned for each
	// task.
	DeriveTokenFn TokenDeriverFunc

	// lock is acquired by every method defined on the mock in this file and
	// held for the duration of the call.
	lock sync.Mutex
}
    25  
// Compile-time check that *MockServiceIdentitiesClient satisfies the
// ServiceIdentityAPI interface.
var _ ServiceIdentityAPI = (*MockServiceIdentitiesClient)(nil)
    27  
    28  // NewMockServiceIdentitiesClient returns a MockServiceIdentitiesClient for testing.
    29  func NewMockServiceIdentitiesClient() *MockServiceIdentitiesClient {
    30  	return &MockServiceIdentitiesClient{
    31  		deriveTokenErrors: make(map[string]map[string]error),
    32  	}
    33  }
    34  
    35  func (mtc *MockServiceIdentitiesClient) DeriveSITokens(alloc *structs.Allocation, tasks []string) (map[string]string, error) {
    36  	mtc.lock.Lock()
    37  	defer mtc.lock.Unlock()
    38  
    39  	// if the DeriveTokenFn is explicitly set, use that
    40  	if mtc.DeriveTokenFn != nil {
    41  		return mtc.DeriveTokenFn(alloc, tasks)
    42  	}
    43  
    44  	// generate a token for each task, unless the mock has an error ready for
    45  	// one or more of the tasks in which case return that
    46  	tokens := make(map[string]string, len(tasks))
    47  	for _, task := range tasks {
    48  		if m, ok := mtc.deriveTokenErrors[alloc.ID]; ok {
    49  			if err, ok := m[task]; ok {
    50  				return nil, err
    51  			}
    52  		}
    53  		tokens[task] = uuid.Generate()
    54  	}
    55  	return tokens, nil
    56  }
    57  
    58  func (mtc *MockServiceIdentitiesClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
    59  	mtc.lock.Lock()
    60  	defer mtc.lock.Unlock()
    61  
    62  	if _, ok := mtc.deriveTokenErrors[allocID]; !ok {
    63  		mtc.deriveTokenErrors[allocID] = make(map[string]error, 10)
    64  	}
    65  
    66  	for _, task := range tasks {
    67  		mtc.deriveTokenErrors[allocID][task] = err
    68  	}
    69  }
    70  
    71  func (mtc *MockServiceIdentitiesClient) DeriveTokenErrors() map[string]map[string]error {
    72  	mtc.lock.Lock()
    73  	defer mtc.lock.Unlock()
    74  
    75  	m := make(map[string]map[string]error)
    76  	for aID, tasks := range mtc.deriveTokenErrors {
    77  		for task, err := range tasks {
    78  			m[aID][task] = err
    79  		}
    80  	}
    81  	return m
    82  }