github.com/google/trillian-examples@v0.0.0-20240520080811-0d40d35cef0e/clone/logdb/database_test.go (about) 1 // Copyright 2020 Google LLC. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package logdb 16 17 import ( 18 "bytes" 19 "context" 20 "crypto/sha256" 21 "database/sql" 22 "fmt" 23 "reflect" 24 "testing" 25 26 "github.com/google/go-cmp/cmp" 27 _ "github.com/mattn/go-sqlite3" // Load drivers for sqlite3 28 ) 29 30 func TestHeadIncremented(t *testing.T) { 31 for _, test := range []struct { 32 desc string 33 leaves [][]byte 34 want int64 35 wantErr error 36 }{ 37 { 38 desc: "no data", 39 wantErr: ErrNoDataFound, 40 }, { 41 desc: "one leaf", 42 leaves: [][]byte{[]byte("first!")}, 43 want: 0, 44 }, { 45 desc: "many leaves", 46 leaves: [][]byte{[]byte("a"), []byte("b"), []byte("c")}, 47 want: 2, 48 }, 49 } { 50 t.Run(test.desc, func(t *testing.T) { 51 db, close, err := NewInMemoryDatabase() 52 if err != nil { 53 t.Fatal("failed to init DB", err) 54 } 55 defer close() 56 if err := db.WriteLeaves(context.Background(), 0, test.leaves); err != nil { 57 t.Fatal("failed to write leaves", err) 58 } 59 60 head, err := db.Head() 61 if test.wantErr != err { 62 t.Errorf("expected err %q but got %q", test.wantErr, err) 63 } 64 if test.wantErr != nil { 65 if head != test.want { 66 t.Errorf("expected %d but got %d", test.want, head) 67 } 68 } 69 }) 70 } 71 } 72 73 func TestRoundTrip(t *testing.T) { 74 leaves := [][]byte{ 75 []byte("aa"), 76 []byte("bb"), 77 []byte("cc"), 78 []byte("dd"), 79 } 80 for _, test := range []struct { 81 desc string 82 leaves [][]byte 83 start, end uint64 84 wantLeaves [][]byte 85 }{{ 86 desc: "one leaf", 87 leaves: leaves[:1], 88 start: 0, 89 end: 1, 90 wantLeaves: leaves[:1], 91 }, { 92 desc: "many leaves pick all", 93 leaves: leaves, 94 start: 0, 95 end: uint64(len(leaves)), 96 wantLeaves: leaves, 97 }, { 98 desc: "many leaves select middle", 99 leaves: leaves, 100 start: 1, 101 end: 3, 102 wantLeaves: leaves[1:3], 103 }, { 104 desc: "many leaves start past the end", 105 leaves: leaves, 106 start: 100, 107 wantLeaves: [][]byte{}, 108 }, 109 } { 110 t.Run(test.desc, func(t *testing.T) { 111 db, close, err := NewInMemoryDatabase() 112 if err != nil { 113 t.Fatal("failed to init DB", err) 114 } 115 defer close() 116 if err := db.WriteLeaves(context.Background(), 0, test.leaves); err != nil { 117 t.Fatal("failed to write leaves", err) 118 } 119 120 results := make(chan StreamResult, 1) 121 go db.StreamLeaves(context.Background(), test.start, test.end, results) 122 123 got := make([][]byte, 0) 124 Receive: 125 for result := range results { 126 if result.Err != nil { 127 break Receive 128 } 129 got = append(got, result.Leaf) 130 } 131 132 if err != nil { 133 t.Fatalf("unexpected error: %q", err) 134 } 135 136 if diff := cmp.Diff(got, test.wantLeaves); len(diff) > 0 { 137 t.Errorf("diff in leaves: %q", diff) 138 } 139 }) 140 } 141 } 142 143 func TestCheckpointRoundTrip(t *testing.T) { 144 hashes := make([][]byte, 5) 145 for i := range hashes { 146 h := sha256.Sum256([]byte(fmt.Sprintf("hash %d", i))) 147 hashes[i] = h[:] 148 } 149 for _, test := range []struct { 150 desc string 151 checkpoint []byte 152 compactRange [][]byte 153 size uint64 154 err error 155 }{{ 156 desc: "no previous", 157 err: ErrNoDataFound, 158 }, { 159 desc: "single compact range", 160 size: 256, 161 checkpoint: hashes[0], 162 compactRange: hashes[1:2], 163 }, { 164 desc: "longer compact range", 165 size: 277, 166 checkpoint: hashes[0], 167 compactRange: hashes[1:], 168 }, 169 } { 170 t.Run(test.desc, func(t *testing.T) { 171 db, close, err := NewInMemoryDatabase() 172 if err != nil { 173 t.Fatal("failed to init DB", err) 174 } 175 defer close() 176 if test.checkpoint != nil { 177 if err := db.WriteCheckpoint(context.Background(), test.size, test.checkpoint, test.compactRange); err != nil { 178 t.Fatal("failed to write checkpoint", err) 179 } 180 } 181 182 gotSize, gotCP, gotCR, gotErr := db.GetLatestCheckpoint(context.Background()) 183 184 if gotErr != test.err { 185 t.Fatalf("mismatched error: got=%v, want %v", gotErr, test.err) 186 } 187 188 if gotSize != test.size { 189 t.Errorf("size: got %d, want %d", gotSize, test.size) 190 } 191 if !bytes.Equal(gotCP, test.checkpoint) { 192 t.Errorf("checkpoint: got %x, want %x", gotCP, test.checkpoint) 193 } 194 if !reflect.DeepEqual(gotCR, test.compactRange) { 195 t.Errorf("compact range: got != want: \n%v\n%v", gotCR, test.compactRange) 196 } 197 }) 198 } 199 } 200 201 func TestCheckpoints(t *testing.T) { 202 for _, test := range []struct { 203 desc string 204 size1 uint64 205 checkpoint1 []byte 206 compactRange1 [][]byte 207 size2 uint64 208 checkpoint2 []byte 209 compactRange2 [][]byte 210 wantSize uint64 211 wantCP []byte 212 }{{ 213 desc: "small, big", 214 size1: 16, 215 checkpoint1: mustHash("root 1"), 216 compactRange1: [][]byte{mustHash("root 1")}, 217 size2: 32, 218 checkpoint2: mustHash("root 2"), 219 compactRange2: [][]byte{mustHash("root 2")}, 220 wantSize: 32, 221 wantCP: mustHash("root 2"), 222 }, { 223 desc: "big, small", 224 size1: 32, 225 checkpoint1: mustHash("root 1"), 226 compactRange1: [][]byte{mustHash("root 1")}, 227 size2: 16, 228 checkpoint2: mustHash("root 2"), 229 compactRange2: [][]byte{mustHash("root 2")}, 230 wantSize: 32, 231 wantCP: mustHash("root 1"), 232 }, { 233 desc: "same checkpoint twice", 234 size1: 16, 235 checkpoint1: mustHash("root 1"), 236 compactRange1: [][]byte{mustHash("root 1")}, 237 size2: 16, 238 checkpoint2: mustHash("root 1"), 239 compactRange2: [][]byte{mustHash("root 1")}, 240 wantSize: 16, 241 wantCP: mustHash("root 1"), 242 }, { 243 desc: "unequal checkpoints for same size", 244 size1: 16, 245 checkpoint1: mustHash("root 1"), 246 compactRange1: [][]byte{mustHash("root 1")}, 247 size2: 16, 248 checkpoint2: mustHash("root 2"), 249 compactRange2: [][]byte{mustHash("root 2")}, 250 wantSize: 16, 251 wantCP: mustHash("root 1"), 252 }, 253 } { 254 t.Run(test.desc, func(t *testing.T) { 255 db, close, err := NewInMemoryDatabase() 256 if err != nil { 257 t.Fatal("failed to init DB", err) 258 } 259 defer close() 260 261 if err := db.WriteCheckpoint(context.Background(), test.size1, test.checkpoint1, test.compactRange1); err != nil { 262 t.Fatal("failed to write checkpoint 1", err) 263 } 264 if err := db.WriteCheckpoint(context.Background(), test.size2, test.checkpoint2, test.compactRange2); err != nil { 265 t.Fatal("failed to write checkpoint 2", err) 266 } 267 268 gotSize, cp, _, err := db.GetLatestCheckpoint(context.Background()) 269 if err != nil { 270 t.Fatalf("GetLatestCheckpoint(): %v", err) 271 } 272 273 if gotSize != test.wantSize { 274 t.Errorf("size: got %d, want %d", gotSize, test.wantSize) 275 } 276 if !bytes.Equal(cp, test.wantCP) { 277 t.Errorf("checkpoint: got %x, want %x", cp, test.wantCP) 278 } 279 }) 280 } 281 } 282 283 func NewInMemoryDatabase() (*Database, func(), error) { 284 sqlitedb, err := sql.Open("sqlite3", ":memory:") 285 if err != nil { 286 return nil, nil, fmt.Errorf("failed to open temporary in-memory DB: %v", err) 287 } 288 db, err := NewDatabaseDirect(sqlitedb) 289 return db, func() { _ = sqlitedb.Close }, err 290 } 291 292 func mustHash(s string) []byte { 293 h := sha256.Sum256([]byte(s)) 294 return h[:] 295 }