github.com/transparency-dev/armored-witness-applet@v0.1.1/trusted_applet/internal/storage/slots/journal_test.go (about)

     1  // Copyright 2022 The Armored Witness Applet authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package slots
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/sha256"
    20  	"fmt"
    21  	"strings"
    22  	"testing"
    23  
    24  	"github.com/google/go-cmp/cmp"
    25  	"github.com/transparency-dev/armored-witness-applet/trusted_applet/internal/storage/testonly"
    26  )
    27  
    28  func magic0Hdr() [4]byte {
    29  	return [4]byte{magic0[0], magic0[1], magic0[2], magic0[3]}
    30  }
    31  
    32  func TestEntryMarshal(t *testing.T) {
    33  	for _, test := range []struct {
    34  		name    string
    35  		e       entry
    36  		want    []byte
    37  		wantErr bool
    38  	}{
    39  		{
    40  			name: "golden",
    41  			e: entry{
    42  				Magic:      magic0Hdr(),
    43  				Revision:   42,
    44  				DataLen:    uint64(len("hello")),
    45  				DataSHA256: sha256.Sum256([]byte("hello")),
    46  				Data:       []byte("hello"),
    47  			},
    48  			want: []byte{
    49  				'T', 'F', 'J', '0',
    50  				0x00, 0x00, 0x00, 0x2a,
    51  				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05,
    52  				0x2c, 0xf2, 0x4d, 0xba, 0x5f, 0xb0, 0xa3, 0x0e, 0x26, 0xe8, 0x3b, 0x2a, 0xc5, 0xb9, 0xe2, 0x9e, 0x1b, 0x16, 0x1e, 0x5c, 0x1f, 0xa7, 0x42, 0x5e, 0x73, 0x04, 0x33, 0x62, 0x93, 0x8b, 0x98, 0x24,
    53  				'h', 'e', 'l', 'l', 'o',
    54  			},
    55  		}, {
    56  			name: "bad magic",
    57  			e: entry{
    58  				Magic:      [4]byte{'n', 'O', 'p', 'E'},
    59  				Revision:   42,
    60  				DataLen:    uint64(len("hello")),
    61  				DataSHA256: sha256.Sum256([]byte("hello")),
    62  				Data:       []byte("hello"),
    63  			},
    64  			wantErr: true,
    65  		}, {
    66  			name: "bad hash",
    67  			e: entry{
    68  				Magic:      magic0Hdr(),
    69  				Revision:   42,
    70  				DataLen:    uint64(len("hello")),
    71  				DataSHA256: sha256.Sum256([]byte("nOpE")),
    72  				Data:       []byte("hello"),
    73  			},
    74  			wantErr: true,
    75  		},
    76  	} {
    77  		t.Run(test.name, func(t *testing.T) {
    78  			b := &bytes.Buffer{}
    79  			err := marshalEntry(test.e, b)
    80  			if gotErr := err != nil; gotErr != test.wantErr {
    81  				t.Fatalf("marshalEntry: %v, wantErr %t", err, test.wantErr)
    82  			}
    83  			if test.wantErr {
    84  				return
    85  			}
    86  
    87  			if diff := cmp.Diff(b.Bytes(), test.want); diff != "" {
    88  				t.Fatalf("Got diff: %v", diff)
    89  			}
    90  		})
    91  	}
    92  }
    93  
    94  func TestEntryUnmarshal(t *testing.T) {
    95  	for _, test := range []struct {
    96  		name    string
    97  		want    *entry
    98  		b       []byte
    99  		wantErr bool
   100  	}{
   101  		{
   102  			name: "golden",
   103  			want: &entry{
   104  				Magic:      magic0Hdr(),
   105  				Revision:   42,
   106  				DataLen:    uint64(len("hello")),
   107  				DataSHA256: sha256.Sum256([]byte("hello")),
   108  				Data:       []byte("hello"),
   109  			},
   110  			b: []byte{
   111  				'T', 'F', 'J', '0',
   112  				0x00, 0x00, 0x00, 0x2a,
   113  				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05,
   114  				0x2c, 0xf2, 0x4d, 0xba, 0x5f, 0xb0, 0xa3, 0x0e, 0x26, 0xe8, 0x3b, 0x2a, 0xc5, 0xb9, 0xe2, 0x9e, 0x1b, 0x16, 0x1e, 0x5c, 0x1f, 0xa7, 0x42, 0x5e, 0x73, 0x04, 0x33, 0x62, 0x93, 0x8b, 0x98, 0x24,
   115  				'h', 'e', 'l', 'l', 'o',
   116  			},
   117  		}, {
   118  			name: "bad magic",
   119  			b: []byte{
   120  				'N', 'o', 'P', 'e',
   121  				0x00, 0x00, 0x00, 0x2a,
   122  				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05,
   123  				0x2c, 0xf2, 0x4d, 0xba, 0x5f, 0xb0, 0xa3, 0x0e, 0x26, 0xe8, 0x3b, 0x2a, 0xc5, 0xb9, 0xe2, 0x9e, 0x1b, 0x16, 0x1e, 0x5c, 0x1f, 0xa7, 0x42, 0x5e, 0x73, 0x04, 0x33, 0x62, 0x93, 0x8b, 0x98, 0x24,
   124  				'h', 'e', 'l', 'l', 'o',
   125  			},
   126  			wantErr: true,
   127  		}, {
   128  			name: "bad hash",
   129  			b: []byte{
   130  				'T', 'F', 'J', '0',
   131  				0x00, 0x00, 0x00, 0x2a,
   132  				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05,
   133  				0x2c, 0xf2, 0x4d, 0xba, 0x5f, 0xb0, 0xa3, 0x0e, 0x26, 0xe8, 0x3b, 0x2a, 0xc5, 0xb9, 0xe2, 0x9e, 0x1b, 0x16, 0x1e, 0x5c, 0x1f, 0xa7, 0x42, 0x5e, 0x73, 0x04, 0x33, 0x62, 0x93, 0x8b, 0x98, 0x24,
   134  				'W', 'R', 'O', 'N', 'G',
   135  			},
   136  			wantErr: true,
   137  		},
   138  	} {
   139  		t.Run(test.name, func(t *testing.T) {
   140  			got, err := unmarshalEntry(bytes.NewReader(test.b))
   141  			if gotErr := err != nil; gotErr != test.wantErr {
   142  				t.Fatalf("unmarshalEntry: %v, wantErr %t", err, test.wantErr)
   143  			}
   144  			if test.wantErr {
   145  				return
   146  			}
   147  
   148  			if diff := cmp.Diff(got, test.want); diff != "" {
   149  				t.Fatalf("Got diff: %v", diff)
   150  			}
   151  		})
   152  	}
   153  }
   154  
   155  func TestOpenJournal(t *testing.T) {
   156  	md := testonly.NewMemDev(t, 2)
   157  	start, length := uint(1), uint(len(md.Storage)-1)
   158  	j, err := OpenJournal(md, start, length)
   159  	if err != nil {
   160  		t.Fatalf("OpenJournal: %v", err)
   161  	}
   162  	if j.start != start {
   163  		t.Errorf("Journal.start=%d, want %d", j.start, start)
   164  	}
   165  	if j.length != length {
   166  		t.Errorf("Journal.length=%d, want %d", j.length, length)
   167  	}
   168  }
   169  
   170  func TestWriteSizeLimit(t *testing.T) {
   171  	storageBlocks := uint(20)
   172  	md := testonly.NewMemDev(t, storageBlocks)
   173  	start, length := uint(1), storageBlocks
   174  
   175  	j, err := OpenJournal(md, start, length)
   176  	if err != nil {
   177  		t.Fatalf("OpenJournal: %v", err)
   178  	}
   179  
   180  	limit := int((storageBlocks * md.BlockSize() / 3) - entryHeaderSize)
   181  	if err := j.Update(fill(limit, "ok...")); err != nil {
   182  		t.Fatalf("Update: %q, but expected write to succeed", err)
   183  	}
   184  	if err := j.Update(fill(limit+1, "BOOM")); err == nil {
   185  		t.Fatal("Update succeeded, but expected write fail")
   186  	}
   187  }
   188  
   189  func TestPerfectlyFullJournal(t *testing.T) {
   190  	storageBlocks := uint(10)
   191  	md := testonly.NewMemDev(t, storageBlocks)
   192  	start, length := uint(1), storageBlocks-1
   193  
   194  	var prevData []byte
   195  	for i, test := range []struct {
   196  		data               []byte
   197  		expectedWriteBlock uint
   198  	}{
   199  		{data: []byte{}, expectedWriteBlock: start},                          // 1 block
   200  		{data: fill(256, "One"), expectedWriteBlock: start + 1},              // 1 block
   201  		{data: fill(1000, "Two"), expectedWriteBlock: start + 1 + 1},         // 3 blocks
   202  		{data: fill(1000, "Three"), expectedWriteBlock: start + 1 + 1 + 3},   // 3 blocks
   203  		{data: fill(433, "Four"), expectedWriteBlock: start + 1 + 1 + 3 + 3}, // 1 block - we're full!
   204  	} {
   205  		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   206  			j, err := OpenJournal(md, start, length)
   207  			if err != nil {
   208  				t.Fatalf("OpenJournal: %v", err)
   209  			}
   210  
   211  			// Check the current record is as we expect:
   212  			// - Revision should match the iteration count
   213  			// - The next block to write to is correct
   214  			// - The stored data matches what we last wrote
   215  			curData, rev := j.Data()
   216  			if rev != uint32(i) {
   217  				t.Errorf("Got revision %d, want %d", rev, i)
   218  			}
   219  			if got, want := j.nextBlock, test.expectedWriteBlock; got != want {
   220  				t.Errorf("nextBlock = %d, want %d", got, want)
   221  			}
   222  			// Ensure we see the data written in the last iteration, if any
   223  			if prevData != nil && !bytes.Equal(curData, prevData) {
   224  				t.Errorf("Got data %q, want %q", string(curData), string(prevData))
   225  			}
   226  			prevData = test.data
   227  
   228  			// Write some updated data
   229  			if err := j.Update(test.data); err != nil {
   230  				t.Fatalf("Update: %v", err)
   231  			}
   232  		})
   233  	}
   234  
   235  	t.Run("final", func(t *testing.T) {
   236  		// Now just ensure we can successfully read from the last entry of a full journal,
   237  		// and that the next write position is at the start of the journal.
   238  		j, err := OpenJournal(md, start, length)
   239  		if err != nil {
   240  			t.Fatalf("OpenJournal: %v", err)
   241  		}
   242  		if got, want := j.nextBlock, start; got != want {
   243  			t.Fatalf("nextBlock didn't wrap to first block of journal, got %d, want %d", got, want)
   244  		}
   245  	})
   246  }
   247  
   248  func TestRoundTrip(t *testing.T) {
   249  	storageBlocks := uint(20)
   250  	md := testonly.NewMemDev(t, storageBlocks)
   251  	start, length := uint(1), storageBlocks-1
   252  
   253  	var prevData []byte
   254  	for i, test := range []struct {
   255  		data               []byte
   256  		expectedWriteBlock uint
   257  	}{
   258  		{data: []byte{}, expectedWriteBlock: start},                                    // 1 block
   259  		{data: fill(256, "hello"), expectedWriteBlock: start + 1},                      // 1 block
   260  		{data: fill(1000, "there"), expectedWriteBlock: start + 1 + 1},                 // 3 blocks
   261  		{data: fill(30, "how"), expectedWriteBlock: start + 1 + 1 + 3},                 // 1 block
   262  		{data: fill(3000, "are"), expectedWriteBlock: start + 1 + 1 + 3 + 1},           // 6 blocks
   263  		{data: fill(59, "you"), expectedWriteBlock: start + 1 + 1 + 3 + 1 + 6},         // 1 block
   264  		{data: fill(3000, "doing"), expectedWriteBlock: start + 1 + 1 + 3 + 1 + 6 + 1}, // 6 blocks,
   265  		{data: fill(1000, "doing"), expectedWriteBlock: start},                         // 3 blocks, won't fit so will end up writing at `start`
   266  		{data: fill(3000, "today?"), expectedWriteBlock: start + 3},                    // 6 blocks
   267  		{data: fill(20, "All done!"), expectedWriteBlock: start + 9},                   // 1 blocks
   268  	} {
   269  		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   270  			j, err := OpenJournal(md, start, length)
   271  			if err != nil {
   272  				t.Fatalf("OpenJournal: %v", err)
   273  			}
   274  
   275  			// Check the current record is as we expect:
   276  			// - Revision should match the iteration count
   277  			// - The next block to write to is correct
   278  			// - The stored data matches what we last wrote
   279  			curData, rev := j.Data()
   280  			if rev != uint32(i) {
   281  				t.Errorf("Got revision %d, want %d", rev, i)
   282  			}
   283  			if got, want := j.nextBlock, test.expectedWriteBlock; got != want {
   284  				t.Errorf("nextBlock = %d, want %d", got, want)
   285  			}
   286  			// Ensure we see the data written in the last iteration, if any
   287  			if prevData != nil && !bytes.Equal(curData, prevData) {
   288  				t.Errorf("Got data %q, want %q", string(curData), string(prevData))
   289  			}
   290  			prevData = test.data
   291  
   292  			// Write some updated data
   293  			if err := j.Update(test.data); err != nil {
   294  				t.Fatalf("Update: %v", err)
   295  			}
   296  		})
   297  	}
   298  }
   299  
   300  func TestUpdateVerifies(t *testing.T) {
   301  	storageBlocks := uint(20)
   302  	md := testonly.NewMemDev(t, storageBlocks)
   303  	start, length := uint(1), storageBlocks-1
   304  
   305  	j, err := OpenJournal(md, start, length)
   306  	if err != nil {
   307  		t.Fatalf("OpenJournal: %v", err)
   308  	}
   309  
   310  	didCorrupt := false
   311  	// Set a hook to corrupt writes
   312  	md.OnBlockWritten = func(lba uint) {
   313  		md.Storage[lba][0] ^= 0x23
   314  		didCorrupt = true
   315  	}
   316  
   317  	// Write some updated data, the write itself should succeed but
   318  	if err := j.Update(fill(1000, "some data")); err == nil {
   319  		t.Fatal("Update want error, got nil")
   320  	}
   321  	if !didCorrupt {
   322  		t.Fatal("Update failed as expected, but we didn't corrupt data?!")
   323  	}
   324  }
   325  
   326  // fill returns a slice of length n containing as many repeats of s as necessary
   327  // to fill it (including partial at the end if needed).
   328  func fill(n int, s string) []byte {
   329  	return []byte(strings.Repeat(s, n/len(s)+1))[:n]
   330  }