github.com/hernad/nomad@v1.6.112/helper/snapshot/snapshot_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package snapshot
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/rand"
     9  	"fmt"
    10  	"io"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/hashicorp/consul/sdk/testutil"
    19  	"github.com/hashicorp/go-msgpack/codec"
    20  	"github.com/hernad/nomad/nomad/structs"
    21  	"github.com/hashicorp/raft"
    22  	"github.com/stretchr/testify/require"
    23  )
    24  
    25  // MockFSM is a simple FSM for testing that simply stores its logs in a slice of
    26  // byte slices.
    27  type MockFSM struct {
    28  	sync.Mutex
    29  	logs [][]byte
    30  }
    31  
    32  // MockSnapshot is a snapshot sink for testing that encodes the contents of a
    33  // MockFSM using msgpack.
    34  type MockSnapshot struct {
    35  	logs     [][]byte
    36  	maxIndex int
    37  }
    38  
    39  // See raft.FSM.
    40  func (m *MockFSM) Apply(log *raft.Log) interface{} {
    41  	m.Lock()
    42  	defer m.Unlock()
    43  	m.logs = append(m.logs, log.Data)
    44  	return len(m.logs)
    45  }
    46  
    47  // See raft.FSM.
    48  func (m *MockFSM) Snapshot() (raft.FSMSnapshot, error) {
    49  	m.Lock()
    50  	defer m.Unlock()
    51  	return &MockSnapshot{m.logs, len(m.logs)}, nil
    52  }
    53  
    54  // See raft.FSM.
    55  func (m *MockFSM) Restore(in io.ReadCloser) error {
    56  	m.Lock()
    57  	defer m.Unlock()
    58  	defer in.Close()
    59  	dec := codec.NewDecoder(in, structs.MsgpackHandle)
    60  
    61  	m.logs = nil
    62  	return dec.Decode(&m.logs)
    63  }
    64  
    65  // See raft.SnapshotSink.
    66  func (m *MockSnapshot) Persist(sink raft.SnapshotSink) error {
    67  	enc := codec.NewEncoder(sink, structs.MsgpackHandle)
    68  	if err := enc.Encode(m.logs[:m.maxIndex]); err != nil {
    69  		sink.Cancel()
    70  		return err
    71  	}
    72  	sink.Close()
    73  	return nil
    74  }
    75  
    76  // See raft.SnapshotSink.
    77  func (m *MockSnapshot) Release() {
    78  }
    79  
    80  // makeRaft returns a Raft and its FSM, with snapshots based in the given dir.
    81  func makeRaft(t *testing.T, dir string) (*raft.Raft, *MockFSM) {
    82  	snaps, err := raft.NewFileSnapshotStore(dir, 5, nil)
    83  	if err != nil {
    84  		t.Fatalf("err: %v", err)
    85  	}
    86  
    87  	fsm := &MockFSM{}
    88  	store := raft.NewInmemStore()
    89  	addr, trans := raft.NewInmemTransport("")
    90  
    91  	config := raft.DefaultConfig()
    92  	config.LocalID = raft.ServerID(fmt.Sprintf("server-%s", addr))
    93  
    94  	var members raft.Configuration
    95  	members.Servers = append(members.Servers, raft.Server{
    96  		Suffrage: raft.Voter,
    97  		ID:       config.LocalID,
    98  		Address:  addr,
    99  	})
   100  
   101  	err = raft.BootstrapCluster(config, store, store, snaps, trans, members)
   102  	if err != nil {
   103  		t.Fatalf("err: %v", err)
   104  	}
   105  
   106  	raft, err := raft.NewRaft(config, fsm, store, store, snaps, trans)
   107  	if err != nil {
   108  		t.Fatalf("err: %v", err)
   109  	}
   110  
   111  	timeout := time.After(10 * time.Second)
   112  	for {
   113  		if raft.Leader() != "" {
   114  			break
   115  		}
   116  
   117  		select {
   118  		case <-raft.LeaderCh():
   119  		case <-time.After(1 * time.Second):
   120  			// Need to poll because we might have missed the first
   121  			// go with the leader channel.
   122  		case <-timeout:
   123  			t.Fatalf("timed out waiting for leader")
   124  		}
   125  	}
   126  
   127  	return raft, fsm
   128  }
   129  
   130  func TestSnapshot(t *testing.T) {
   131  	dir := testutil.TempDir(t, "snapshot")
   132  	defer os.RemoveAll(dir)
   133  
   134  	// Make a Raft and populate it with some data. We tee everything we
   135  	// apply off to a buffer for checking post-snapshot.
   136  	var expected []bytes.Buffer
   137  	entries := 64 * 1024
   138  	before, _ := makeRaft(t, filepath.Join(dir, "before"))
   139  	defer before.Shutdown()
   140  	for i := 0; i < entries; i++ {
   141  		var log bytes.Buffer
   142  		var copy bytes.Buffer
   143  		both := io.MultiWriter(&log, &copy)
   144  		if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
   145  			t.Fatalf("err: %v", err)
   146  		}
   147  		future := before.Apply(log.Bytes(), time.Second)
   148  		if err := future.Error(); err != nil {
   149  			t.Fatalf("err: %v", err)
   150  		}
   151  		expected = append(expected, copy)
   152  	}
   153  
   154  	// Take a snapshot.
   155  	logger := testutil.Logger(t)
   156  	snap, err := New(logger, before)
   157  	if err != nil {
   158  		t.Fatalf("err: %v", err)
   159  	}
   160  	defer snap.Close()
   161  
   162  	// Verify the snapshot. We have to rewind it after for the restore.
   163  	metadata, err := Verify(snap)
   164  	if err != nil {
   165  		t.Fatalf("err: %v", err)
   166  	}
   167  	if _, err := snap.file.Seek(0, 0); err != nil {
   168  		t.Fatalf("err: %v", err)
   169  	}
   170  	if int(metadata.Index) != entries+2 {
   171  		t.Fatalf("bad: %d", metadata.Index)
   172  	}
   173  	if metadata.Term != 2 {
   174  		t.Fatalf("bad: %d", metadata.Index)
   175  	}
   176  	if metadata.Version != raft.SnapshotVersionMax {
   177  		t.Fatalf("bad: %d", metadata.Version)
   178  	}
   179  
   180  	// Make a new, independent Raft.
   181  	after, fsm := makeRaft(t, filepath.Join(dir, "after"))
   182  	defer after.Shutdown()
   183  
   184  	// Put some initial data in there that the snapshot should overwrite.
   185  	for i := 0; i < 16; i++ {
   186  		var log bytes.Buffer
   187  		if _, err := io.CopyN(&log, rand.Reader, 256); err != nil {
   188  			t.Fatalf("err: %v", err)
   189  		}
   190  		future := after.Apply(log.Bytes(), time.Second)
   191  		if err := future.Error(); err != nil {
   192  			t.Fatalf("err: %v", err)
   193  		}
   194  	}
   195  
   196  	// Restore the snapshot.
   197  	if err := Restore(logger, snap, after); err != nil {
   198  		t.Fatalf("err: %v", err)
   199  	}
   200  
   201  	// Compare the contents.
   202  	fsm.Lock()
   203  	defer fsm.Unlock()
   204  	if len(fsm.logs) != len(expected) {
   205  		t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
   206  	}
   207  	for i := range fsm.logs {
   208  		if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
   209  			t.Fatalf("bad: log %d doesn't match", i)
   210  		}
   211  	}
   212  }
   213  
   214  func TestSnapshot_Nil(t *testing.T) {
   215  	var snap *Snapshot
   216  
   217  	if idx := snap.Index(); idx != 0 {
   218  		t.Fatalf("bad: %d", idx)
   219  	}
   220  
   221  	n, err := snap.Read(make([]byte, 16))
   222  	if n != 0 || err != io.EOF {
   223  		t.Fatalf("bad: %d %v", n, err)
   224  	}
   225  
   226  	if err := snap.Close(); err != nil {
   227  		t.Fatalf("err: %v", err)
   228  	}
   229  }
   230  
   231  func TestSnapshot_BadVerify(t *testing.T) {
   232  	buf := bytes.NewBuffer([]byte("nope"))
   233  	_, err := Verify(buf)
   234  	if err == nil || !strings.Contains(err.Error(), "unexpected EOF") {
   235  		t.Fatalf("err: %v", err)
   236  	}
   237  }
   238  
   239  func TestSnapshot_TruncatedVerify(t *testing.T) {
   240  	dir := testutil.TempDir(t, "snapshot")
   241  	defer os.RemoveAll(dir)
   242  
   243  	// Make a Raft and populate it with some data. We tee everything we
   244  	// apply off to a buffer for checking post-snapshot.
   245  	var expected []bytes.Buffer
   246  	entries := 64 * 1024
   247  	before, _ := makeRaft(t, filepath.Join(dir, "before"))
   248  	defer before.Shutdown()
   249  	for i := 0; i < entries; i++ {
   250  		var log bytes.Buffer
   251  		var copy bytes.Buffer
   252  		both := io.MultiWriter(&log, &copy)
   253  
   254  		_, err := io.CopyN(both, rand.Reader, 256)
   255  		require.NoError(t, err)
   256  
   257  		future := before.Apply(log.Bytes(), time.Second)
   258  		require.NoError(t, future.Error())
   259  		expected = append(expected, copy)
   260  	}
   261  
   262  	// Take a snapshot.
   263  	logger := testutil.Logger(t)
   264  	snap, err := New(logger, before)
   265  	require.NoError(t, err)
   266  	defer snap.Close()
   267  
   268  	var data []byte
   269  	{
   270  		var buf bytes.Buffer
   271  		_, err = io.Copy(&buf, snap)
   272  		require.NoError(t, err)
   273  		data = buf.Bytes()
   274  	}
   275  
   276  	for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
   277  		t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
   278  			// Lop off part of the end.
   279  			buf := bytes.NewReader(data[0 : len(data)-removeBytes])
   280  
   281  			_, err = Verify(buf)
   282  			require.Error(t, err)
   283  		})
   284  	}
   285  }
   286  
   287  func TestSnapshot_BadRestore(t *testing.T) {
   288  	dir := testutil.TempDir(t, "snapshot")
   289  	defer os.RemoveAll(dir)
   290  
   291  	// Make a Raft and populate it with some data.
   292  	before, _ := makeRaft(t, filepath.Join(dir, "before"))
   293  	defer before.Shutdown()
   294  	for i := 0; i < 16*1024; i++ {
   295  		var log bytes.Buffer
   296  		if _, err := io.CopyN(&log, rand.Reader, 256); err != nil {
   297  			t.Fatalf("err: %v", err)
   298  		}
   299  		future := before.Apply(log.Bytes(), time.Second)
   300  		if err := future.Error(); err != nil {
   301  			t.Fatalf("err: %v", err)
   302  		}
   303  	}
   304  
   305  	// Take a snapshot.
   306  	logger := testutil.Logger(t)
   307  	snap, err := New(logger, before)
   308  	if err != nil {
   309  		t.Fatalf("err: %v", err)
   310  	}
   311  
   312  	// Make a new, independent Raft.
   313  	after, fsm := makeRaft(t, filepath.Join(dir, "after"))
   314  	defer after.Shutdown()
   315  
   316  	// Put some initial data in there that should not be harmed by the
   317  	// failed restore attempt.
   318  	var expected []bytes.Buffer
   319  	for i := 0; i < 16; i++ {
   320  		var log bytes.Buffer
   321  		var copy bytes.Buffer
   322  		both := io.MultiWriter(&log, &copy)
   323  		if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
   324  			t.Fatalf("err: %v", err)
   325  		}
   326  		future := after.Apply(log.Bytes(), time.Second)
   327  		if err := future.Error(); err != nil {
   328  			t.Fatalf("err: %v", err)
   329  		}
   330  		expected = append(expected, copy)
   331  	}
   332  
   333  	// Attempt to restore a truncated version of the snapshot. This is
   334  	// expected to fail.
   335  	err = Restore(logger, io.LimitReader(snap, 512), after)
   336  	if err == nil || !strings.Contains(err.Error(), "unexpected EOF") {
   337  		t.Fatalf("err: %v", err)
   338  	}
   339  
   340  	// Compare the contents to make sure the aborted restore didn't harm
   341  	// anything.
   342  	fsm.Lock()
   343  	defer fsm.Unlock()
   344  	if len(fsm.logs) != len(expected) {
   345  		t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
   346  	}
   347  	for i := range fsm.logs {
   348  		if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
   349  			t.Fatalf("bad: log %d doesn't match", i)
   350  		}
   351  	}
   352  }