github.com/DerekStrickland/consul@v1.4.5/agent/consul/session_ttl_test.go (about)

     1  package consul
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"strings"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/hashicorp/consul/agent/structs"
    11  	"github.com/hashicorp/consul/testrpc"
    12  	"github.com/hashicorp/consul/testutil/retry"
    13  	"github.com/hashicorp/go-uuid"
    14  	"github.com/hashicorp/net-rpc-msgpackrpc"
    15  )
    16  
    17  func generateUUID() (ret string) {
    18  	var err error
    19  	if ret, err = uuid.GenerateUUID(); err != nil {
    20  		panic(fmt.Sprintf("Unable to generate a UUID, %v", err))
    21  	}
    22  	return ret
    23  }
    24  
    25  func TestInitializeSessionTimers(t *testing.T) {
    26  	t.Parallel()
    27  	dir1, s1 := testServer(t)
    28  	defer os.RemoveAll(dir1)
    29  	defer s1.Shutdown()
    30  
    31  	testrpc.WaitForLeader(t, s1.RPC, "dc1")
    32  
    33  	state := s1.fsm.State()
    34  	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
    35  		t.Fatalf("err: %s", err)
    36  	}
    37  	session := &structs.Session{
    38  		ID:   generateUUID(),
    39  		Node: "foo",
    40  		TTL:  "10s",
    41  	}
    42  	if err := state.SessionCreate(100, session); err != nil {
    43  		t.Fatalf("err: %v", err)
    44  	}
    45  
    46  	// Reset the session timers
    47  	err := s1.initializeSessionTimers()
    48  	if err != nil {
    49  		t.Fatalf("err: %v", err)
    50  	}
    51  
    52  	// Check that we have a timer
    53  	if s1.sessionTimers.Get(session.ID) == nil {
    54  		t.Fatalf("missing session timer")
    55  	}
    56  }
    57  
    58  func TestResetSessionTimer_Fault(t *testing.T) {
    59  	t.Parallel()
    60  	dir1, s1 := testServer(t)
    61  	defer os.RemoveAll(dir1)
    62  	defer s1.Shutdown()
    63  
    64  	testrpc.WaitForLeader(t, s1.RPC, "dc1")
    65  
    66  	// Should not exist
    67  	err := s1.resetSessionTimer(generateUUID(), nil)
    68  	if err == nil || !strings.Contains(err.Error(), "not found") {
    69  		t.Fatalf("err: %v", err)
    70  	}
    71  
    72  	// Create a session
    73  	state := s1.fsm.State()
    74  	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
    75  		t.Fatalf("err: %s", err)
    76  	}
    77  	session := &structs.Session{
    78  		ID:   generateUUID(),
    79  		Node: "foo",
    80  		TTL:  "10s",
    81  	}
    82  	if err := state.SessionCreate(100, session); err != nil {
    83  		t.Fatalf("err: %v", err)
    84  	}
    85  
    86  	// Reset the session timer
    87  	err = s1.resetSessionTimer(session.ID, nil)
    88  	if err != nil {
    89  		t.Fatalf("err: %v", err)
    90  	}
    91  
    92  	// Check that we have a timer
    93  	if s1.sessionTimers.Get(session.ID) == nil {
    94  		t.Fatalf("missing session timer")
    95  	}
    96  }
    97  
    98  func TestResetSessionTimer_NoTTL(t *testing.T) {
    99  	t.Parallel()
   100  	dir1, s1 := testServer(t)
   101  	defer os.RemoveAll(dir1)
   102  	defer s1.Shutdown()
   103  
   104  	testrpc.WaitForLeader(t, s1.RPC, "dc1")
   105  
   106  	// Create a session
   107  	state := s1.fsm.State()
   108  	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
   109  		t.Fatalf("err: %s", err)
   110  	}
   111  	session := &structs.Session{
   112  		ID:   generateUUID(),
   113  		Node: "foo",
   114  		TTL:  "0000s",
   115  	}
   116  	if err := state.SessionCreate(100, session); err != nil {
   117  		t.Fatalf("err: %v", err)
   118  	}
   119  
   120  	// Reset the session timer
   121  	err := s1.resetSessionTimer(session.ID, session)
   122  	if err != nil {
   123  		t.Fatalf("err: %v", err)
   124  	}
   125  
   126  	// Check that we have a timer
   127  	if s1.sessionTimers.Get(session.ID) != nil {
   128  		t.Fatalf("should not have session timer")
   129  	}
   130  }
   131  
   132  func TestResetSessionTimer_InvalidTTL(t *testing.T) {
   133  	t.Parallel()
   134  	dir1, s1 := testServer(t)
   135  	defer os.RemoveAll(dir1)
   136  	defer s1.Shutdown()
   137  
   138  	// Create a session
   139  	session := &structs.Session{
   140  		ID:   generateUUID(),
   141  		Node: "foo",
   142  		TTL:  "foo",
   143  	}
   144  
   145  	// Reset the session timer
   146  	err := s1.resetSessionTimer(session.ID, session)
   147  	if err == nil || !strings.Contains(err.Error(), "Invalid Session TTL") {
   148  		t.Fatalf("err: %v", err)
   149  	}
   150  }
   151  
   152  func TestResetSessionTimerLocked(t *testing.T) {
   153  	t.Parallel()
   154  	dir1, s1 := testServer(t)
   155  	defer os.RemoveAll(dir1)
   156  	defer s1.Shutdown()
   157  
   158  	testrpc.WaitForLeader(t, s1.RPC, "dc1")
   159  
   160  	s1.createSessionTimer("foo", 5*time.Millisecond)
   161  	if s1.sessionTimers.Get("foo") == nil {
   162  		t.Fatalf("missing timer")
   163  	}
   164  
   165  	time.Sleep(10 * time.Millisecond * structs.SessionTTLMultiplier)
   166  	if s1.sessionTimers.Get("foo") != nil {
   167  		t.Fatalf("timer should be gone")
   168  	}
   169  }
   170  
   171  func TestResetSessionTimerLocked_Renew(t *testing.T) {
   172  	t.Parallel()
   173  	dir1, s1 := testServer(t)
   174  	defer os.RemoveAll(dir1)
   175  	defer s1.Shutdown()
   176  
   177  	ttl := 100 * time.Millisecond
   178  
   179  	// create the timer
   180  	s1.createSessionTimer("foo", ttl)
   181  	if s1.sessionTimers.Get("foo") == nil {
   182  		t.Fatalf("missing timer")
   183  	}
   184  
   185  	// wait until it is "expired" but at this point
   186  	// the session still exists.
   187  	time.Sleep(ttl)
   188  	if s1.sessionTimers.Get("foo") == nil {
   189  		t.Fatal("missing timer")
   190  	}
   191  
   192  	// renew the session which will reset the TTL to 2*ttl
   193  	// since that is the current SessionTTLMultiplier
   194  	s1.createSessionTimer("foo", ttl)
   195  
   196  	// Watch for invalidation
   197  	renew := time.Now()
   198  	deadline := renew.Add(2 * structs.SessionTTLMultiplier * ttl)
   199  	for {
   200  		now := time.Now()
   201  		if now.After(deadline) {
   202  			t.Fatal("should have expired by now")
   203  		}
   204  
   205  		// timer still exists
   206  		if s1.sessionTimers.Get("foo") != nil {
   207  			time.Sleep(time.Millisecond)
   208  			continue
   209  		}
   210  
   211  		// timer gone
   212  		if now.Sub(renew) < ttl {
   213  			t.Fatalf("early invalidate")
   214  		}
   215  		break
   216  	}
   217  }
   218  
   219  func TestInvalidateSession(t *testing.T) {
   220  	t.Parallel()
   221  	dir1, s1 := testServer(t)
   222  	defer os.RemoveAll(dir1)
   223  	defer s1.Shutdown()
   224  
   225  	testrpc.WaitForLeader(t, s1.RPC, "dc1")
   226  
   227  	// Create a session
   228  	state := s1.fsm.State()
   229  	if err := state.EnsureNode(1, &structs.Node{Node: "foo", Address: "127.0.0.1"}); err != nil {
   230  		t.Fatalf("err: %s", err)
   231  	}
   232  	session := &structs.Session{
   233  		ID:   generateUUID(),
   234  		Node: "foo",
   235  		TTL:  "10s",
   236  	}
   237  	if err := state.SessionCreate(100, session); err != nil {
   238  		t.Fatalf("err: %v", err)
   239  	}
   240  
   241  	// This should cause a destroy
   242  	s1.invalidateSession(session.ID)
   243  
   244  	// Check it is gone
   245  	_, sess, err := state.SessionGet(nil, session.ID)
   246  	if err != nil {
   247  		t.Fatalf("err: %v", err)
   248  	}
   249  	if sess != nil {
   250  		t.Fatalf("should destroy session")
   251  	}
   252  }
   253  
   254  func TestClearSessionTimer(t *testing.T) {
   255  	t.Parallel()
   256  	dir1, s1 := testServer(t)
   257  	defer os.RemoveAll(dir1)
   258  	defer s1.Shutdown()
   259  
   260  	s1.createSessionTimer("foo", 5*time.Millisecond)
   261  
   262  	err := s1.clearSessionTimer("foo")
   263  	if err != nil {
   264  		t.Fatalf("err: %v", err)
   265  	}
   266  
   267  	if s1.sessionTimers.Get("foo") != nil {
   268  		t.Fatalf("timer should be gone")
   269  	}
   270  }
   271  
   272  func TestClearAllSessionTimers(t *testing.T) {
   273  	t.Parallel()
   274  	dir1, s1 := testServer(t)
   275  	defer os.RemoveAll(dir1)
   276  	defer s1.Shutdown()
   277  
   278  	s1.createSessionTimer("foo", 10*time.Millisecond)
   279  	s1.createSessionTimer("bar", 10*time.Millisecond)
   280  	s1.createSessionTimer("baz", 10*time.Millisecond)
   281  
   282  	err := s1.clearAllSessionTimers()
   283  	if err != nil {
   284  		t.Fatalf("err: %v", err)
   285  	}
   286  
   287  	// sessionTimers is guarded by the lock
   288  	if s1.sessionTimers.Len() != 0 {
   289  		t.Fatalf("timers should be gone")
   290  	}
   291  }
   292  
   293  func TestServer_SessionTTL_Failover(t *testing.T) {
   294  	t.Parallel()
   295  	dir1, s1 := testServer(t)
   296  	defer os.RemoveAll(dir1)
   297  	defer s1.Shutdown()
   298  	testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
   299  
   300  	dir2, s2 := testServerDCBootstrap(t, "dc1", false)
   301  	defer os.RemoveAll(dir2)
   302  	defer s2.Shutdown()
   303  
   304  	dir3, s3 := testServerDCBootstrap(t, "dc1", false)
   305  	defer os.RemoveAll(dir3)
   306  	defer s3.Shutdown()
   307  	servers := []*Server{s1, s2, s3}
   308  
   309  	// Try to join
   310  	joinLAN(t, s2, s1)
   311  	joinLAN(t, s3, s1)
   312  	retry.Run(t, func(r *retry.R) { r.Check(wantPeers(s1, 3)) })
   313  
   314  	// Find the leader
   315  	var leader *Server
   316  	for _, s := range servers {
   317  		// Check that s.sessionTimers is empty
   318  		if s.sessionTimers.Len() != 0 {
   319  			t.Fatalf("should have no sessionTimers")
   320  		}
   321  		// Find the leader too
   322  		if s.IsLeader() {
   323  			leader = s
   324  		}
   325  	}
   326  	if leader == nil {
   327  		t.Fatalf("Should have a leader")
   328  	}
   329  
   330  	codec := rpcClient(t, leader)
   331  	defer codec.Close()
   332  
   333  	// Register a node
   334  	node := structs.RegisterRequest{
   335  		Datacenter: s1.config.Datacenter,
   336  		Node:       "foo",
   337  		Address:    "127.0.0.1",
   338  	}
   339  	var out struct{}
   340  	if err := s1.RPC("Catalog.Register", &node, &out); err != nil {
   341  		t.Fatalf("err: %v", err)
   342  	}
   343  
   344  	// Create a TTL session
   345  	arg := structs.SessionRequest{
   346  		Datacenter: "dc1",
   347  		Op:         structs.SessionCreate,
   348  		Session: structs.Session{
   349  			Node: "foo",
   350  			TTL:  "10s",
   351  		},
   352  	}
   353  	var id1 string
   354  	if err := msgpackrpc.CallWithCodec(codec, "Session.Apply", &arg, &id1); err != nil {
   355  		t.Fatalf("err: %v", err)
   356  	}
   357  
   358  	// Check that sessionTimers has the session ID
   359  	if leader.sessionTimers.Get(id1) == nil {
   360  		t.Fatalf("missing session timer")
   361  	}
   362  
   363  	// Shutdown the leader!
   364  	leader.Shutdown()
   365  
   366  	// sessionTimers should be cleared on leader shutdown
   367  	if leader.sessionTimers.Len() != 0 {
   368  		t.Fatalf("session timers should be empty on the shutdown leader")
   369  	}
   370  	// Find the new leader
   371  	retry.Run(t, func(r *retry.R) {
   372  		leader = nil
   373  		for _, s := range servers {
   374  			if s.IsLeader() {
   375  				leader = s
   376  			}
   377  		}
   378  		if leader == nil {
   379  			r.Fatal("Should have a new leader")
   380  		}
   381  
   382  		// Ensure session timer is restored
   383  		if leader.sessionTimers.Get(id1) == nil {
   384  			r.Fatal("missing session timer")
   385  		}
   386  	})
   387  }