github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/helper/snapshot/snapshot_test.go (about)

     1  package snapshot
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"fmt"
     7  	"io"
     8  	"os"
     9  	"path/filepath"
    10  	"strings"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/hashicorp/consul/sdk/testutil"
    16  	"github.com/hashicorp/go-msgpack/codec"
    17  	"github.com/hashicorp/nomad/nomad/structs"
    18  	"github.com/hashicorp/raft"
    19  	"github.com/stretchr/testify/require"
    20  )
    21  
// MockFSM is a simple FSM for testing that simply stores its logs in a slice of
// byte slices. The embedded mutex guards logs; Apply, Snapshot and Restore all
// take it before touching the slice.
type MockFSM struct {
	sync.Mutex
	logs [][]byte
}
    28  
// MockSnapshot is the raft.FSMSnapshot returned by MockFSM.Snapshot for
// testing. Persist encodes the first maxIndex entries of logs using msgpack.
type MockSnapshot struct {
	logs     [][]byte // log slice as it was when the snapshot was taken
	maxIndex int      // number of leading entries of logs to persist
}
    35  
    36  // See raft.FSM.
    37  func (m *MockFSM) Apply(log *raft.Log) interface{} {
    38  	m.Lock()
    39  	defer m.Unlock()
    40  	m.logs = append(m.logs, log.Data)
    41  	return len(m.logs)
    42  }
    43  
    44  // See raft.FSM.
    45  func (m *MockFSM) Snapshot() (raft.FSMSnapshot, error) {
    46  	m.Lock()
    47  	defer m.Unlock()
    48  	return &MockSnapshot{m.logs, len(m.logs)}, nil
    49  }
    50  
    51  // See raft.FSM.
    52  func (m *MockFSM) Restore(in io.ReadCloser) error {
    53  	m.Lock()
    54  	defer m.Unlock()
    55  	defer in.Close()
    56  	dec := codec.NewDecoder(in, structs.MsgpackHandle)
    57  
    58  	m.logs = nil
    59  	return dec.Decode(&m.logs)
    60  }
    61  
    62  // See raft.SnapshotSink.
    63  func (m *MockSnapshot) Persist(sink raft.SnapshotSink) error {
    64  	enc := codec.NewEncoder(sink, structs.MsgpackHandle)
    65  	if err := enc.Encode(m.logs[:m.maxIndex]); err != nil {
    66  		sink.Cancel()
    67  		return err
    68  	}
    69  	sink.Close()
    70  	return nil
    71  }
    72  
    73  // See raft.SnapshotSink.
    74  func (m *MockSnapshot) Release() {
    75  }
    76  
    77  // makeRaft returns a Raft and its FSM, with snapshots based in the given dir.
    78  func makeRaft(t *testing.T, dir string) (*raft.Raft, *MockFSM) {
    79  	snaps, err := raft.NewFileSnapshotStore(dir, 5, nil)
    80  	if err != nil {
    81  		t.Fatalf("err: %v", err)
    82  	}
    83  
    84  	fsm := &MockFSM{}
    85  	store := raft.NewInmemStore()
    86  	addr, trans := raft.NewInmemTransport("")
    87  
    88  	config := raft.DefaultConfig()
    89  	config.LocalID = raft.ServerID(fmt.Sprintf("server-%s", addr))
    90  
    91  	var members raft.Configuration
    92  	members.Servers = append(members.Servers, raft.Server{
    93  		Suffrage: raft.Voter,
    94  		ID:       config.LocalID,
    95  		Address:  addr,
    96  	})
    97  
    98  	err = raft.BootstrapCluster(config, store, store, snaps, trans, members)
    99  	if err != nil {
   100  		t.Fatalf("err: %v", err)
   101  	}
   102  
   103  	raft, err := raft.NewRaft(config, fsm, store, store, snaps, trans)
   104  	if err != nil {
   105  		t.Fatalf("err: %v", err)
   106  	}
   107  
   108  	timeout := time.After(10 * time.Second)
   109  	for {
   110  		if raft.Leader() != "" {
   111  			break
   112  		}
   113  
   114  		select {
   115  		case <-raft.LeaderCh():
   116  		case <-time.After(1 * time.Second):
   117  			// Need to poll because we might have missed the first
   118  			// go with the leader channel.
   119  		case <-timeout:
   120  			t.Fatalf("timed out waiting for leader")
   121  		}
   122  	}
   123  
   124  	return raft, fsm
   125  }
   126  
   127  func TestSnapshot(t *testing.T) {
   128  	dir := testutil.TempDir(t, "snapshot")
   129  	defer os.RemoveAll(dir)
   130  
   131  	// Make a Raft and populate it with some data. We tee everything we
   132  	// apply off to a buffer for checking post-snapshot.
   133  	var expected []bytes.Buffer
   134  	entries := 64 * 1024
   135  	before, _ := makeRaft(t, filepath.Join(dir, "before"))
   136  	defer before.Shutdown()
   137  	for i := 0; i < entries; i++ {
   138  		var log bytes.Buffer
   139  		var copy bytes.Buffer
   140  		both := io.MultiWriter(&log, &copy)
   141  		if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
   142  			t.Fatalf("err: %v", err)
   143  		}
   144  		future := before.Apply(log.Bytes(), time.Second)
   145  		if err := future.Error(); err != nil {
   146  			t.Fatalf("err: %v", err)
   147  		}
   148  		expected = append(expected, copy)
   149  	}
   150  
   151  	// Take a snapshot.
   152  	logger := testutil.Logger(t)
   153  	snap, err := New(logger, before)
   154  	if err != nil {
   155  		t.Fatalf("err: %v", err)
   156  	}
   157  	defer snap.Close()
   158  
   159  	// Verify the snapshot. We have to rewind it after for the restore.
   160  	metadata, err := Verify(snap)
   161  	if err != nil {
   162  		t.Fatalf("err: %v", err)
   163  	}
   164  	if _, err := snap.file.Seek(0, 0); err != nil {
   165  		t.Fatalf("err: %v", err)
   166  	}
   167  	if int(metadata.Index) != entries+2 {
   168  		t.Fatalf("bad: %d", metadata.Index)
   169  	}
   170  	if metadata.Term != 2 {
   171  		t.Fatalf("bad: %d", metadata.Index)
   172  	}
   173  	if metadata.Version != raft.SnapshotVersionMax {
   174  		t.Fatalf("bad: %d", metadata.Version)
   175  	}
   176  
   177  	// Make a new, independent Raft.
   178  	after, fsm := makeRaft(t, filepath.Join(dir, "after"))
   179  	defer after.Shutdown()
   180  
   181  	// Put some initial data in there that the snapshot should overwrite.
   182  	for i := 0; i < 16; i++ {
   183  		var log bytes.Buffer
   184  		if _, err := io.CopyN(&log, rand.Reader, 256); err != nil {
   185  			t.Fatalf("err: %v", err)
   186  		}
   187  		future := after.Apply(log.Bytes(), time.Second)
   188  		if err := future.Error(); err != nil {
   189  			t.Fatalf("err: %v", err)
   190  		}
   191  	}
   192  
   193  	// Restore the snapshot.
   194  	if err := Restore(logger, snap, after); err != nil {
   195  		t.Fatalf("err: %v", err)
   196  	}
   197  
   198  	// Compare the contents.
   199  	fsm.Lock()
   200  	defer fsm.Unlock()
   201  	if len(fsm.logs) != len(expected) {
   202  		t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
   203  	}
   204  	for i := range fsm.logs {
   205  		if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
   206  			t.Fatalf("bad: log %d doesn't match", i)
   207  		}
   208  	}
   209  }
   210  
   211  func TestSnapshot_Nil(t *testing.T) {
   212  	var snap *Snapshot
   213  
   214  	if idx := snap.Index(); idx != 0 {
   215  		t.Fatalf("bad: %d", idx)
   216  	}
   217  
   218  	n, err := snap.Read(make([]byte, 16))
   219  	if n != 0 || err != io.EOF {
   220  		t.Fatalf("bad: %d %v", n, err)
   221  	}
   222  
   223  	if err := snap.Close(); err != nil {
   224  		t.Fatalf("err: %v", err)
   225  	}
   226  }
   227  
   228  func TestSnapshot_BadVerify(t *testing.T) {
   229  	buf := bytes.NewBuffer([]byte("nope"))
   230  	_, err := Verify(buf)
   231  	if err == nil || !strings.Contains(err.Error(), "unexpected EOF") {
   232  		t.Fatalf("err: %v", err)
   233  	}
   234  }
   235  
   236  func TestSnapshot_TruncatedVerify(t *testing.T) {
   237  	dir := testutil.TempDir(t, "snapshot")
   238  	defer os.RemoveAll(dir)
   239  
   240  	// Make a Raft and populate it with some data. We tee everything we
   241  	// apply off to a buffer for checking post-snapshot.
   242  	var expected []bytes.Buffer
   243  	entries := 64 * 1024
   244  	before, _ := makeRaft(t, filepath.Join(dir, "before"))
   245  	defer before.Shutdown()
   246  	for i := 0; i < entries; i++ {
   247  		var log bytes.Buffer
   248  		var copy bytes.Buffer
   249  		both := io.MultiWriter(&log, &copy)
   250  
   251  		_, err := io.CopyN(both, rand.Reader, 256)
   252  		require.NoError(t, err)
   253  
   254  		future := before.Apply(log.Bytes(), time.Second)
   255  		require.NoError(t, future.Error())
   256  		expected = append(expected, copy)
   257  	}
   258  
   259  	// Take a snapshot.
   260  	logger := testutil.Logger(t)
   261  	snap, err := New(logger, before)
   262  	require.NoError(t, err)
   263  	defer snap.Close()
   264  
   265  	var data []byte
   266  	{
   267  		var buf bytes.Buffer
   268  		_, err = io.Copy(&buf, snap)
   269  		require.NoError(t, err)
   270  		data = buf.Bytes()
   271  	}
   272  
   273  	for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
   274  		t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
   275  			// Lop off part of the end.
   276  			buf := bytes.NewReader(data[0 : len(data)-removeBytes])
   277  
   278  			_, err = Verify(buf)
   279  			require.Error(t, err)
   280  		})
   281  	}
   282  }
   283  
   284  func TestSnapshot_BadRestore(t *testing.T) {
   285  	dir := testutil.TempDir(t, "snapshot")
   286  	defer os.RemoveAll(dir)
   287  
   288  	// Make a Raft and populate it with some data.
   289  	before, _ := makeRaft(t, filepath.Join(dir, "before"))
   290  	defer before.Shutdown()
   291  	for i := 0; i < 16*1024; i++ {
   292  		var log bytes.Buffer
   293  		if _, err := io.CopyN(&log, rand.Reader, 256); err != nil {
   294  			t.Fatalf("err: %v", err)
   295  		}
   296  		future := before.Apply(log.Bytes(), time.Second)
   297  		if err := future.Error(); err != nil {
   298  			t.Fatalf("err: %v", err)
   299  		}
   300  	}
   301  
   302  	// Take a snapshot.
   303  	logger := testutil.Logger(t)
   304  	snap, err := New(logger, before)
   305  	if err != nil {
   306  		t.Fatalf("err: %v", err)
   307  	}
   308  
   309  	// Make a new, independent Raft.
   310  	after, fsm := makeRaft(t, filepath.Join(dir, "after"))
   311  	defer after.Shutdown()
   312  
   313  	// Put some initial data in there that should not be harmed by the
   314  	// failed restore attempt.
   315  	var expected []bytes.Buffer
   316  	for i := 0; i < 16; i++ {
   317  		var log bytes.Buffer
   318  		var copy bytes.Buffer
   319  		both := io.MultiWriter(&log, &copy)
   320  		if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
   321  			t.Fatalf("err: %v", err)
   322  		}
   323  		future := after.Apply(log.Bytes(), time.Second)
   324  		if err := future.Error(); err != nil {
   325  			t.Fatalf("err: %v", err)
   326  		}
   327  		expected = append(expected, copy)
   328  	}
   329  
   330  	// Attempt to restore a truncated version of the snapshot. This is
   331  	// expected to fail.
   332  	err = Restore(logger, io.LimitReader(snap, 512), after)
   333  	if err == nil || !strings.Contains(err.Error(), "unexpected EOF") {
   334  		t.Fatalf("err: %v", err)
   335  	}
   336  
   337  	// Compare the contents to make sure the aborted restore didn't harm
   338  	// anything.
   339  	fsm.Lock()
   340  	defer fsm.Unlock()
   341  	if len(fsm.logs) != len(expected) {
   342  		t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
   343  	}
   344  	for i := range fsm.logs {
   345  		if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
   346  			t.Fatalf("bad: log %d doesn't match", i)
   347  		}
   348  	}
   349  }