github.com/transparency-dev/armored-witness-applet@v0.1.1/trusted_applet/internal/storage/slots/journal_test.go (about) 1 // Copyright 2022 The Armored Witness Applet authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package slots 16 17 import ( 18 "bytes" 19 "crypto/sha256" 20 "fmt" 21 "strings" 22 "testing" 23 24 "github.com/google/go-cmp/cmp" 25 "github.com/transparency-dev/armored-witness-applet/trusted_applet/internal/storage/testonly" 26 ) 27 28 func magic0Hdr() [4]byte { 29 return [4]byte{magic0[0], magic0[1], magic0[2], magic0[3]} 30 } 31 32 func TestEntryMarshal(t *testing.T) { 33 for _, test := range []struct { 34 name string 35 e entry 36 want []byte 37 wantErr bool 38 }{ 39 { 40 name: "golden", 41 e: entry{ 42 Magic: magic0Hdr(), 43 Revision: 42, 44 DataLen: uint64(len("hello")), 45 DataSHA256: sha256.Sum256([]byte("hello")), 46 Data: []byte("hello"), 47 }, 48 want: []byte{ 49 'T', 'F', 'J', '0', 50 0x00, 0x00, 0x00, 0x2a, 51 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 52 0x2c, 0xf2, 0x4d, 0xba, 0x5f, 0xb0, 0xa3, 0x0e, 0x26, 0xe8, 0x3b, 0x2a, 0xc5, 0xb9, 0xe2, 0x9e, 0x1b, 0x16, 0x1e, 0x5c, 0x1f, 0xa7, 0x42, 0x5e, 0x73, 0x04, 0x33, 0x62, 0x93, 0x8b, 0x98, 0x24, 53 'h', 'e', 'l', 'l', 'o', 54 }, 55 }, { 56 name: "bad magic", 57 e: entry{ 58 Magic: [4]byte{'n', 'O', 'p', 'E'}, 59 Revision: 42, 60 DataLen: uint64(len("hello")), 61 DataSHA256: sha256.Sum256([]byte("hello")), 62 Data: []byte("hello"), 63 }, 64 wantErr: true, 65 }, { 66 name: "bad hash", 67 e: entry{ 68 Magic: magic0Hdr(), 69 Revision: 42, 70 DataLen: uint64(len("hello")), 71 DataSHA256: sha256.Sum256([]byte("nOpE")), 72 Data: []byte("hello"), 73 }, 74 wantErr: true, 75 }, 76 } { 77 t.Run(test.name, func(t *testing.T) { 78 b := &bytes.Buffer{} 79 err := marshalEntry(test.e, b) 80 if gotErr := err != nil; gotErr != test.wantErr { 81 t.Fatalf("marshalEntry: %v, wantErr %t", err, test.wantErr) 82 } 83 if test.wantErr { 84 return 85 } 86 87 if diff := cmp.Diff(b.Bytes(), test.want); diff != "" { 88 t.Fatalf("Got diff: %v", diff) 89 } 90 }) 91 } 92 } 93 94 func TestEntryUnmarshal(t *testing.T) { 95 for _, test := range []struct { 96 name string 97 want *entry 98 b []byte 99 wantErr bool 100 }{ 101 { 102 name: "golden", 103 want: &entry{ 104 Magic: magic0Hdr(), 105 Revision: 42, 106 DataLen: uint64(len("hello")), 107 DataSHA256: sha256.Sum256([]byte("hello")), 108 Data: []byte("hello"), 109 }, 110 b: []byte{ 111 'T', 'F', 'J', '0', 112 0x00, 0x00, 0x00, 0x2a, 113 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 114 0x2c, 0xf2, 0x4d, 0xba, 0x5f, 0xb0, 0xa3, 0x0e, 0x26, 0xe8, 0x3b, 0x2a, 0xc5, 0xb9, 0xe2, 0x9e, 0x1b, 0x16, 0x1e, 0x5c, 0x1f, 0xa7, 0x42, 0x5e, 0x73, 0x04, 0x33, 0x62, 0x93, 0x8b, 0x98, 0x24, 115 'h', 'e', 'l', 'l', 'o', 116 }, 117 }, { 118 name: "bad magic", 119 b: []byte{ 120 'N', 'o', 'P', 'e', 121 0x00, 0x00, 0x00, 0x2a, 122 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 123 0x2c, 0xf2, 0x4d, 0xba, 0x5f, 0xb0, 0xa3, 0x0e, 0x26, 0xe8, 0x3b, 0x2a, 0xc5, 0xb9, 0xe2, 0x9e, 0x1b, 0x16, 0x1e, 0x5c, 0x1f, 0xa7, 0x42, 0x5e, 0x73, 0x04, 0x33, 0x62, 0x93, 0x8b, 0x98, 0x24, 124 'h', 'e', 'l', 'l', 'o', 125 }, 126 wantErr: true, 127 }, { 128 name: "bad hash", 129 b: []byte{ 130 'T', 'F', 'J', '0', 131 0x00, 0x00, 0x00, 0x2a, 132 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 133 0x2c, 0xf2, 0x4d, 0xba, 0x5f, 0xb0, 0xa3, 0x0e, 0x26, 0xe8, 0x3b, 0x2a, 0xc5, 0xb9, 0xe2, 0x9e, 0x1b, 0x16, 0x1e, 0x5c, 0x1f, 0xa7, 0x42, 0x5e, 0x73, 0x04, 0x33, 0x62, 0x93, 0x8b, 0x98, 0x24, 134 'W', 'R', 'O', 'N', 'G', 135 }, 136 wantErr: true, 137 }, 138 } { 139 t.Run(test.name, func(t *testing.T) { 140 got, err := unmarshalEntry(bytes.NewReader(test.b)) 141 if gotErr := err != nil; gotErr != test.wantErr { 142 t.Fatalf("unmarshalEntry: %v, wantErr %t", err, test.wantErr) 143 } 144 if test.wantErr { 145 return 146 } 147 148 if diff := cmp.Diff(got, test.want); diff != "" { 149 t.Fatalf("Got diff: %v", diff) 150 } 151 }) 152 } 153 } 154 155 func TestOpenJournal(t *testing.T) { 156 md := testonly.NewMemDev(t, 2) 157 start, length := uint(1), uint(len(md.Storage)-1) 158 j, err := OpenJournal(md, start, length) 159 if err != nil { 160 t.Fatalf("OpenJournal: %v", err) 161 } 162 if j.start != start { 163 t.Errorf("Journal.start=%d, want %d", j.start, start) 164 } 165 if j.length != length { 166 t.Errorf("Journal.length=%d, want %d", j.length, length) 167 } 168 } 169 170 func TestWriteSizeLimit(t *testing.T) { 171 storageBlocks := uint(20) 172 md := testonly.NewMemDev(t, storageBlocks) 173 start, length := uint(1), storageBlocks 174 175 j, err := OpenJournal(md, start, length) 176 if err != nil { 177 t.Fatalf("OpenJournal: %v", err) 178 } 179 180 limit := int((storageBlocks * md.BlockSize() / 3) - entryHeaderSize) 181 if err := j.Update(fill(limit, "ok...")); err != nil { 182 t.Fatalf("Update: %q, but expected write to succeed", err) 183 } 184 if err := j.Update(fill(limit+1, "BOOM")); err == nil { 185 t.Fatal("Update succeeded, but expected write fail") 186 } 187 } 188 189 func TestPerfectlyFullJournal(t *testing.T) { 190 storageBlocks := uint(10) 191 md := testonly.NewMemDev(t, storageBlocks) 192 start, length := uint(1), storageBlocks-1 193 194 var prevData []byte 195 for i, test := range []struct { 196 data []byte 197 expectedWriteBlock uint 198 }{ 199 {data: []byte{}, expectedWriteBlock: start}, // 1 block 200 {data: fill(256, "One"), expectedWriteBlock: start + 1}, // 1 block 201 {data: fill(1000, "Two"), expectedWriteBlock: start + 1 + 1}, // 3 blocks 202 {data: fill(1000, "Three"), expectedWriteBlock: start + 1 + 1 + 3}, // 3 blocks 203 {data: fill(433, "Four"), expectedWriteBlock: start + 1 + 1 + 3 + 3}, // 1 block - we're full! 204 } { 205 t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { 206 j, err := OpenJournal(md, start, length) 207 if err != nil { 208 t.Fatalf("OpenJournal: %v", err) 209 } 210 211 // Check the current record is as we expect: 212 // - Revision should match the iteration count 213 // - The next block to write to is correct 214 // - The stored data matches what we last wrote 215 curData, rev := j.Data() 216 if rev != uint32(i) { 217 t.Errorf("Got revision %d, want %d", rev, i) 218 } 219 if got, want := j.nextBlock, test.expectedWriteBlock; got != want { 220 t.Errorf("nextBlock = %d, want %d", got, want) 221 } 222 // Ensure we see the data written in the last iteration, if any 223 if prevData != nil && !bytes.Equal(curData, prevData) { 224 t.Errorf("Got data %q, want %q", string(curData), string(prevData)) 225 } 226 prevData = test.data 227 228 // Write some updated data 229 if err := j.Update(test.data); err != nil { 230 t.Fatalf("Update: %v", err) 231 } 232 }) 233 } 234 235 t.Run("final", func(t *testing.T) { 236 // Now just ensure we can successfully read from the last entry of a full journal, 237 // and that the next write position is at the start of the journal. 238 j, err := OpenJournal(md, start, length) 239 if err != nil { 240 t.Fatalf("OpenJournal: %v", err) 241 } 242 if got, want := j.nextBlock, start; got != want { 243 t.Fatalf("nextBlock didn't wrap to first block of journal, got %d, want %d", got, want) 244 } 245 }) 246 } 247 248 func TestRoundTrip(t *testing.T) { 249 storageBlocks := uint(20) 250 md := testonly.NewMemDev(t, storageBlocks) 251 start, length := uint(1), storageBlocks-1 252 253 var prevData []byte 254 for i, test := range []struct { 255 data []byte 256 expectedWriteBlock uint 257 }{ 258 {data: []byte{}, expectedWriteBlock: start}, // 1 block 259 {data: fill(256, "hello"), expectedWriteBlock: start + 1}, // 1 block 260 {data: fill(1000, "there"), expectedWriteBlock: start + 1 + 1}, // 3 blocks 261 {data: fill(30, "how"), expectedWriteBlock: start + 1 + 1 + 3}, // 1 block 262 {data: fill(3000, "are"), expectedWriteBlock: start + 1 + 1 + 3 + 1}, // 6 blocks 263 {data: fill(59, "you"), expectedWriteBlock: start + 1 + 1 + 3 + 1 + 6}, // 1 block 264 {data: fill(3000, "doing"), expectedWriteBlock: start + 1 + 1 + 3 + 1 + 6 + 1}, // 6 blocks, 265 {data: fill(1000, "doing"), expectedWriteBlock: start}, // 3 blocks, won't fit so will end up writing at `start` 266 {data: fill(3000, "today?"), expectedWriteBlock: start + 3}, // 6 blocks 267 {data: fill(20, "All done!"), expectedWriteBlock: start + 9}, // 1 blocks 268 } { 269 t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { 270 j, err := OpenJournal(md, start, length) 271 if err != nil { 272 t.Fatalf("OpenJournal: %v", err) 273 } 274 275 // Check the current record is as we expect: 276 // - Revision should match the iteration count 277 // - The next block to write to is correct 278 // - The stored data matches what we last wrote 279 curData, rev := j.Data() 280 if rev != uint32(i) { 281 t.Errorf("Got revision %d, want %d", rev, i) 282 } 283 if got, want := j.nextBlock, test.expectedWriteBlock; got != want { 284 t.Errorf("nextBlock = %d, want %d", got, want) 285 } 286 // Ensure we see the data written in the last iteration, if any 287 if prevData != nil && !bytes.Equal(curData, prevData) { 288 t.Errorf("Got data %q, want %q", string(curData), string(prevData)) 289 } 290 prevData = test.data 291 292 // Write some updated data 293 if err := j.Update(test.data); err != nil { 294 t.Fatalf("Update: %v", err) 295 } 296 }) 297 } 298 } 299 300 func TestUpdateVerifies(t *testing.T) { 301 storageBlocks := uint(20) 302 md := testonly.NewMemDev(t, storageBlocks) 303 start, length := uint(1), storageBlocks-1 304 305 j, err := OpenJournal(md, start, length) 306 if err != nil { 307 t.Fatalf("OpenJournal: %v", err) 308 } 309 310 didCorrupt := false 311 // Set a hook to corrupt writes 312 md.OnBlockWritten = func(lba uint) { 313 md.Storage[lba][0] ^= 0x23 314 didCorrupt = true 315 } 316 317 // Write some updated data, the write itself should succeed but 318 if err := j.Update(fill(1000, "some data")); err == nil { 319 t.Fatal("Update want error, got nil") 320 } 321 if !didCorrupt { 322 t.Fatal("Update failed as expected, but we didn't corrupt data?!") 323 } 324 } 325 326 // fill returns a slice of length n containing as many repeats of s as necessary 327 // to fill it (including partial at the end if needed). 328 func fill(n int, s string) []byte { 329 return []byte(strings.Repeat(s, n/len(s)+1))[:n] 330 }