github.com/hhrutter/nomad@v0.6.0-rc2.0.20170723054333-80c4b03f0705/client/vaultclient/vaultclient_test.go (about)

     1  package vaultclient
     2  
     3  import (
     4  	"log"
     5  	"os"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/hashicorp/nomad/client/config"
    10  	"github.com/hashicorp/nomad/testutil"
    11  	vaultapi "github.com/hashicorp/vault/api"
    12  )
    13  
    14  func TestVaultClient_TokenRenewals(t *testing.T) {
    15  	t.Parallel()
    16  	v := testutil.NewTestVault(t).Start()
    17  	defer v.Stop()
    18  
    19  	logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
    20  	v.Config.ConnectionRetryIntv = 100 * time.Millisecond
    21  	v.Config.TaskTokenTTL = "4s"
    22  	c, err := NewVaultClient(v.Config, logger, nil)
    23  	if err != nil {
    24  		t.Fatalf("failed to build vault client: %v", err)
    25  	}
    26  
    27  	c.Start()
    28  	defer c.Stop()
    29  
    30  	// Sleep a little while to ensure that the renewal loop is active
    31  	time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
    32  
    33  	tcr := &vaultapi.TokenCreateRequest{
    34  		Policies:    []string{"foo", "bar"},
    35  		TTL:         "2s",
    36  		DisplayName: "derived-for-task",
    37  		Renewable:   new(bool),
    38  	}
    39  	*tcr.Renewable = true
    40  
    41  	num := 5
    42  	tokens := make([]string, num)
    43  	for i := 0; i < num; i++ {
    44  		c.client.SetToken(v.Config.Token)
    45  
    46  		if err := c.client.SetAddress(v.Config.Addr); err != nil {
    47  			t.Fatal(err)
    48  		}
    49  
    50  		secret, err := c.client.Auth().Token().Create(tcr)
    51  		if err != nil {
    52  			t.Fatalf("failed to create vault token: %v", err)
    53  		}
    54  
    55  		if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
    56  			t.Fatal("failed to derive a wrapped vault token")
    57  		}
    58  
    59  		tokens[i] = secret.Auth.ClientToken
    60  
    61  		errCh, err := c.RenewToken(tokens[i], secret.Auth.LeaseDuration)
    62  		if err != nil {
    63  			t.Fatalf("Unexpected error: %v", err)
    64  		}
    65  
    66  		go func(errCh <-chan error) {
    67  			for {
    68  				select {
    69  				case err := <-errCh:
    70  					if err != nil {
    71  						t.Fatalf("error while renewing the token: %v", err)
    72  					}
    73  				}
    74  			}
    75  		}(errCh)
    76  	}
    77  
    78  	if c.heap.Length() != num {
    79  		t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length())
    80  	}
    81  
    82  	time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
    83  
    84  	for i := 0; i < num; i++ {
    85  		if err := c.StopRenewToken(tokens[i]); err != nil {
    86  			t.Fatal(err)
    87  		}
    88  	}
    89  
    90  	if c.heap.Length() != 0 {
    91  		t.Fatalf("bad: heap length: expected: 0, actual: %d", c.heap.Length())
    92  	}
    93  }
    94  
    95  func TestVaultClient_Heap(t *testing.T) {
    96  	t.Parallel()
    97  	tr := true
    98  	conf := config.DefaultConfig()
    99  	conf.VaultConfig.Enabled = &tr
   100  	conf.VaultConfig.Token = "testvaulttoken"
   101  	conf.VaultConfig.TaskTokenTTL = "10s"
   102  
   103  	logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
   104  	c, err := NewVaultClient(conf.VaultConfig, logger, nil)
   105  	if err != nil {
   106  		t.Fatal(err)
   107  	}
   108  	if c == nil {
   109  		t.Fatal("failed to create vault client")
   110  	}
   111  
   112  	now := time.Now()
   113  
   114  	renewalReq1 := &vaultClientRenewalRequest{
   115  		errCh:     make(chan error, 1),
   116  		id:        "id1",
   117  		increment: 10,
   118  	}
   119  	if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil {
   120  		t.Fatal(err)
   121  	}
   122  	if !c.isTracked("id1") {
   123  		t.Fatalf("id1 should have been tracked")
   124  	}
   125  
   126  	renewalReq2 := &vaultClientRenewalRequest{
   127  		errCh:     make(chan error, 1),
   128  		id:        "id2",
   129  		increment: 10,
   130  	}
   131  	if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil {
   132  		t.Fatal(err)
   133  	}
   134  	if !c.isTracked("id2") {
   135  		t.Fatalf("id2 should have been tracked")
   136  	}
   137  
   138  	renewalReq3 := &vaultClientRenewalRequest{
   139  		errCh:     make(chan error, 1),
   140  		id:        "id3",
   141  		increment: 10,
   142  	}
   143  	if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil {
   144  		t.Fatal(err)
   145  	}
   146  	if !c.isTracked("id3") {
   147  		t.Fatalf("id3 should have been tracked")
   148  	}
   149  
   150  	// Reading elements should yield id2, id1 and id3 in order
   151  	req, _ := c.nextRenewal()
   152  	if req != renewalReq2 {
   153  		t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq2, req)
   154  	}
   155  	if err := c.heap.Update(req, now.Add(70*time.Second)); err != nil {
   156  		t.Fatal(err)
   157  	}
   158  
   159  	req, _ = c.nextRenewal()
   160  	if req != renewalReq1 {
   161  		t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq1, req)
   162  	}
   163  	if err := c.heap.Update(req, now.Add(80*time.Second)); err != nil {
   164  		t.Fatal(err)
   165  	}
   166  
   167  	req, _ = c.nextRenewal()
   168  	if req != renewalReq3 {
   169  		t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq3, req)
   170  	}
   171  	if err := c.heap.Update(req, now.Add(90*time.Second)); err != nil {
   172  		t.Fatal(err)
   173  	}
   174  
   175  	if err := c.StopRenewToken("id1"); err != nil {
   176  		t.Fatal(err)
   177  	}
   178  
   179  	if err := c.StopRenewToken("id2"); err != nil {
   180  		t.Fatal(err)
   181  	}
   182  
   183  	if err := c.StopRenewToken("id3"); err != nil {
   184  		t.Fatal(err)
   185  	}
   186  
   187  	if c.isTracked("id1") {
   188  		t.Fatalf("id1 should not have been tracked")
   189  	}
   190  
   191  	if c.isTracked("id1") {
   192  		t.Fatalf("id1 should not have been tracked")
   193  	}
   194  
   195  	if c.isTracked("id1") {
   196  		t.Fatalf("id1 should not have been tracked")
   197  	}
   198  
   199  }