github.com/mattyr/nomad@v0.3.3-0.20160919021406-3485a065154a/client/vaultclient/vaultclient_test.go (about)

     1  package vaultclient
     2  
     3  import (
     4  	"log"
     5  	"os"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/hashicorp/nomad/client/config"
    10  	"github.com/hashicorp/nomad/testutil"
    11  	vaultapi "github.com/hashicorp/vault/api"
    12  )
    13  
    14  func TestVaultClient_EstablishConnection(t *testing.T) {
    15  	v := testutil.NewTestVault(t)
    16  
    17  	logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
    18  	v.Config.ConnectionRetryIntv = 100 * time.Millisecond
    19  	v.Config.TaskTokenTTL = "10s"
    20  	c, err := NewVaultClient(v.Config, logger, nil)
    21  	if err != nil {
    22  		t.Fatalf("failed to build vault client: %v", err)
    23  	}
    24  
    25  	c.Start()
    26  	defer c.Stop()
    27  
    28  	// Sleep a little while and check that no connection has been established.
    29  	time.Sleep(100 * time.Duration(testutil.TestMultiplier()) * time.Millisecond)
    30  
    31  	if c.ConnectionEstablished() {
    32  		t.Fatalf("ConnectionEstablished() returned true before Vault server started")
    33  	}
    34  
    35  	// Start Vault
    36  	v.Start()
    37  	defer v.Stop()
    38  
    39  	testutil.WaitForResult(func() (bool, error) {
    40  		return c.ConnectionEstablished(), nil
    41  	}, func(err error) {
    42  		t.Fatalf("Connection not established")
    43  	})
    44  }
    45  
    46  func TestVaultClient_TokenRenewals(t *testing.T) {
    47  	v := testutil.NewTestVault(t).Start()
    48  	defer v.Stop()
    49  
    50  	logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
    51  	v.Config.ConnectionRetryIntv = 100 * time.Millisecond
    52  	v.Config.TaskTokenTTL = "10s"
    53  	c, err := NewVaultClient(v.Config, logger, nil)
    54  	if err != nil {
    55  		t.Fatalf("failed to build vault client: %v", err)
    56  	}
    57  
    58  	c.Start()
    59  	defer c.Stop()
    60  
    61  	// Sleep a little while to ensure that the renewal loop is active
    62  	time.Sleep(3 * time.Second)
    63  
    64  	tcr := &vaultapi.TokenCreateRequest{
    65  		Policies:    []string{"foo", "bar"},
    66  		TTL:         "2s",
    67  		DisplayName: "derived-for-task",
    68  		Renewable:   new(bool),
    69  	}
    70  	*tcr.Renewable = true
    71  
    72  	num := 5
    73  	tokens := make([]string, num)
    74  	for i := 0; i < num; i++ {
    75  		c.client.SetToken(v.Config.Token)
    76  
    77  		if err := c.client.SetAddress(v.Config.Addr); err != nil {
    78  			t.Fatal(err)
    79  		}
    80  
    81  		secret, err := c.client.Auth().Token().Create(tcr)
    82  		if err != nil {
    83  			t.Fatalf("failed to create vault token: %v", err)
    84  		}
    85  
    86  		if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
    87  			t.Fatal("failed to derive a wrapped vault token")
    88  		}
    89  
    90  		tokens[i] = secret.Auth.ClientToken
    91  
    92  		errCh := c.RenewToken(tokens[i], secret.Auth.LeaseDuration)
    93  		go func(errCh <-chan error) {
    94  			var err error
    95  			for {
    96  				select {
    97  				case err = <-errCh:
    98  					t.Fatalf("error while renewing the token: %v", err)
    99  				}
   100  			}
   101  		}(errCh)
   102  	}
   103  
   104  	if c.heap.Length() != num {
   105  		t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length())
   106  	}
   107  
   108  	time.Sleep(5 * time.Second)
   109  
   110  	for i := 0; i < num; i++ {
   111  		if err := c.StopRenewToken(tokens[i]); err != nil {
   112  			t.Fatal(err)
   113  		}
   114  	}
   115  
   116  	if c.heap.Length() != 0 {
   117  		t.Fatal("bad: heap length: expected: 0, actual: %d", c.heap.Length())
   118  	}
   119  }
   120  
   121  func TestVaultClient_Heap(t *testing.T) {
   122  	conf := config.DefaultConfig()
   123  	conf.VaultConfig.Enabled = true
   124  	conf.VaultConfig.Token = "testvaulttoken"
   125  	conf.VaultConfig.TaskTokenTTL = "10s"
   126  
   127  	logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
   128  	c, err := NewVaultClient(conf.VaultConfig, logger, nil)
   129  	if err != nil {
   130  		t.Fatal(err)
   131  	}
   132  	if c == nil {
   133  		t.Fatal("failed to create vault client")
   134  	}
   135  
   136  	now := time.Now()
   137  
   138  	renewalReq1 := &vaultClientRenewalRequest{
   139  		errCh:     make(chan error, 1),
   140  		id:        "id1",
   141  		increment: 10,
   142  	}
   143  	if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil {
   144  		t.Fatal(err)
   145  	}
   146  	if !c.isTracked("id1") {
   147  		t.Fatalf("id1 should have been tracked")
   148  	}
   149  
   150  	renewalReq2 := &vaultClientRenewalRequest{
   151  		errCh:     make(chan error, 1),
   152  		id:        "id2",
   153  		increment: 10,
   154  	}
   155  	if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil {
   156  		t.Fatal(err)
   157  	}
   158  	if !c.isTracked("id2") {
   159  		t.Fatalf("id2 should have been tracked")
   160  	}
   161  
   162  	renewalReq3 := &vaultClientRenewalRequest{
   163  		errCh:     make(chan error, 1),
   164  		id:        "id3",
   165  		increment: 10,
   166  	}
   167  	if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil {
   168  		t.Fatal(err)
   169  	}
   170  	if !c.isTracked("id3") {
   171  		t.Fatalf("id3 should have been tracked")
   172  	}
   173  
   174  	// Reading elements should yield id2, id1 and id3 in order
   175  	req, _ := c.nextRenewal()
   176  	if req != renewalReq2 {
   177  		t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq2, req)
   178  	}
   179  	if err := c.heap.Update(req, now.Add(70*time.Second)); err != nil {
   180  		t.Fatal(err)
   181  	}
   182  
   183  	req, _ = c.nextRenewal()
   184  	if req != renewalReq1 {
   185  		t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq1, req)
   186  	}
   187  	if err := c.heap.Update(req, now.Add(80*time.Second)); err != nil {
   188  		t.Fatal(err)
   189  	}
   190  
   191  	req, _ = c.nextRenewal()
   192  	if req != renewalReq3 {
   193  		t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq3, req)
   194  	}
   195  	if err := c.heap.Update(req, now.Add(90*time.Second)); err != nil {
   196  		t.Fatal(err)
   197  	}
   198  
   199  	if err := c.StopRenewToken("id1"); err != nil {
   200  		t.Fatal(err)
   201  	}
   202  
   203  	if err := c.StopRenewToken("id2"); err != nil {
   204  		t.Fatal(err)
   205  	}
   206  
   207  	if err := c.StopRenewToken("id3"); err != nil {
   208  		t.Fatal(err)
   209  	}
   210  
   211  	if c.isTracked("id1") {
   212  		t.Fatalf("id1 should not have been tracked")
   213  	}
   214  
   215  	if c.isTracked("id1") {
   216  		t.Fatalf("id1 should not have been tracked")
   217  	}
   218  
   219  	if c.isTracked("id1") {
   220  		t.Fatalf("id1 should not have been tracked")
   221  	}
   222  
   223  }