github.com/google/trillian-examples@v0.0.0-20240520080811-0d40d35cef0e/clone/logdb/database_test.go (about)

     1  // Copyright 2020 Google LLC. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package logdb
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto/sha256"
    21  	"database/sql"
    22  	"fmt"
    23  	"reflect"
    24  	"testing"
    25  
    26  	"github.com/google/go-cmp/cmp"
    27  	_ "github.com/mattn/go-sqlite3" // Load drivers for sqlite3
    28  )
    29  
    30  func TestHeadIncremented(t *testing.T) {
    31  	for _, test := range []struct {
    32  		desc    string
    33  		leaves  [][]byte
    34  		want    int64
    35  		wantErr error
    36  	}{
    37  		{
    38  			desc:    "no data",
    39  			wantErr: ErrNoDataFound,
    40  		}, {
    41  			desc:   "one leaf",
    42  			leaves: [][]byte{[]byte("first!")},
    43  			want:   0,
    44  		}, {
    45  			desc:   "many leaves",
    46  			leaves: [][]byte{[]byte("a"), []byte("b"), []byte("c")},
    47  			want:   2,
    48  		},
    49  	} {
    50  		t.Run(test.desc, func(t *testing.T) {
    51  			db, close, err := NewInMemoryDatabase()
    52  			if err != nil {
    53  				t.Fatal("failed to init DB", err)
    54  			}
    55  			defer close()
    56  			if err := db.WriteLeaves(context.Background(), 0, test.leaves); err != nil {
    57  				t.Fatal("failed to write leaves", err)
    58  			}
    59  
    60  			head, err := db.Head()
    61  			if test.wantErr != err {
    62  				t.Errorf("expected err %q but got %q", test.wantErr, err)
    63  			}
    64  			if test.wantErr != nil {
    65  				if head != test.want {
    66  					t.Errorf("expected %d but got %d", test.want, head)
    67  				}
    68  			}
    69  		})
    70  	}
    71  }
    72  
    73  func TestRoundTrip(t *testing.T) {
    74  	leaves := [][]byte{
    75  		[]byte("aa"),
    76  		[]byte("bb"),
    77  		[]byte("cc"),
    78  		[]byte("dd"),
    79  	}
    80  	for _, test := range []struct {
    81  		desc       string
    82  		leaves     [][]byte
    83  		start, end uint64
    84  		wantLeaves [][]byte
    85  	}{{
    86  		desc:       "one leaf",
    87  		leaves:     leaves[:1],
    88  		start:      0,
    89  		end:        1,
    90  		wantLeaves: leaves[:1],
    91  	}, {
    92  		desc:       "many leaves pick all",
    93  		leaves:     leaves,
    94  		start:      0,
    95  		end:        uint64(len(leaves)),
    96  		wantLeaves: leaves,
    97  	}, {
    98  		desc:       "many leaves select middle",
    99  		leaves:     leaves,
   100  		start:      1,
   101  		end:        3,
   102  		wantLeaves: leaves[1:3],
   103  	}, {
   104  		desc:       "many leaves start past the end",
   105  		leaves:     leaves,
   106  		start:      100,
   107  		wantLeaves: [][]byte{},
   108  	},
   109  	} {
   110  		t.Run(test.desc, func(t *testing.T) {
   111  			db, close, err := NewInMemoryDatabase()
   112  			if err != nil {
   113  				t.Fatal("failed to init DB", err)
   114  			}
   115  			defer close()
   116  			if err := db.WriteLeaves(context.Background(), 0, test.leaves); err != nil {
   117  				t.Fatal("failed to write leaves", err)
   118  			}
   119  
   120  			results := make(chan StreamResult, 1)
   121  			go db.StreamLeaves(context.Background(), test.start, test.end, results)
   122  
   123  			got := make([][]byte, 0)
   124  		Receive:
   125  			for result := range results {
   126  				if result.Err != nil {
   127  					break Receive
   128  				}
   129  				got = append(got, result.Leaf)
   130  			}
   131  
   132  			if err != nil {
   133  				t.Fatalf("unexpected error: %q", err)
   134  			}
   135  
   136  			if diff := cmp.Diff(got, test.wantLeaves); len(diff) > 0 {
   137  				t.Errorf("diff in leaves: %q", diff)
   138  			}
   139  		})
   140  	}
   141  }
   142  
   143  func TestCheckpointRoundTrip(t *testing.T) {
   144  	hashes := make([][]byte, 5)
   145  	for i := range hashes {
   146  		h := sha256.Sum256([]byte(fmt.Sprintf("hash %d", i)))
   147  		hashes[i] = h[:]
   148  	}
   149  	for _, test := range []struct {
   150  		desc         string
   151  		checkpoint   []byte
   152  		compactRange [][]byte
   153  		size         uint64
   154  		err          error
   155  	}{{
   156  		desc: "no previous",
   157  		err:  ErrNoDataFound,
   158  	}, {
   159  		desc:         "single compact range",
   160  		size:         256,
   161  		checkpoint:   hashes[0],
   162  		compactRange: hashes[1:2],
   163  	}, {
   164  		desc:         "longer compact range",
   165  		size:         277,
   166  		checkpoint:   hashes[0],
   167  		compactRange: hashes[1:],
   168  	},
   169  	} {
   170  		t.Run(test.desc, func(t *testing.T) {
   171  			db, close, err := NewInMemoryDatabase()
   172  			if err != nil {
   173  				t.Fatal("failed to init DB", err)
   174  			}
   175  			defer close()
   176  			if test.checkpoint != nil {
   177  				if err := db.WriteCheckpoint(context.Background(), test.size, test.checkpoint, test.compactRange); err != nil {
   178  					t.Fatal("failed to write checkpoint", err)
   179  				}
   180  			}
   181  
   182  			gotSize, gotCP, gotCR, gotErr := db.GetLatestCheckpoint(context.Background())
   183  
   184  			if gotErr != test.err {
   185  				t.Fatalf("mismatched error: got=%v, want %v", gotErr, test.err)
   186  			}
   187  
   188  			if gotSize != test.size {
   189  				t.Errorf("size: got %d, want %d", gotSize, test.size)
   190  			}
   191  			if !bytes.Equal(gotCP, test.checkpoint) {
   192  				t.Errorf("checkpoint: got %x, want %x", gotCP, test.checkpoint)
   193  			}
   194  			if !reflect.DeepEqual(gotCR, test.compactRange) {
   195  				t.Errorf("compact range: got != want: \n%v\n%v", gotCR, test.compactRange)
   196  			}
   197  		})
   198  	}
   199  }
   200  
   201  func TestCheckpoints(t *testing.T) {
   202  	for _, test := range []struct {
   203  		desc          string
   204  		size1         uint64
   205  		checkpoint1   []byte
   206  		compactRange1 [][]byte
   207  		size2         uint64
   208  		checkpoint2   []byte
   209  		compactRange2 [][]byte
   210  		wantSize      uint64
   211  		wantCP        []byte
   212  	}{{
   213  		desc:          "small, big",
   214  		size1:         16,
   215  		checkpoint1:   mustHash("root 1"),
   216  		compactRange1: [][]byte{mustHash("root 1")},
   217  		size2:         32,
   218  		checkpoint2:   mustHash("root 2"),
   219  		compactRange2: [][]byte{mustHash("root 2")},
   220  		wantSize:      32,
   221  		wantCP:        mustHash("root 2"),
   222  	}, {
   223  		desc:          "big, small",
   224  		size1:         32,
   225  		checkpoint1:   mustHash("root 1"),
   226  		compactRange1: [][]byte{mustHash("root 1")},
   227  		size2:         16,
   228  		checkpoint2:   mustHash("root 2"),
   229  		compactRange2: [][]byte{mustHash("root 2")},
   230  		wantSize:      32,
   231  		wantCP:        mustHash("root 1"),
   232  	}, {
   233  		desc:          "same checkpoint twice",
   234  		size1:         16,
   235  		checkpoint1:   mustHash("root 1"),
   236  		compactRange1: [][]byte{mustHash("root 1")},
   237  		size2:         16,
   238  		checkpoint2:   mustHash("root 1"),
   239  		compactRange2: [][]byte{mustHash("root 1")},
   240  		wantSize:      16,
   241  		wantCP:        mustHash("root 1"),
   242  	}, {
   243  		desc:          "unequal checkpoints for same size",
   244  		size1:         16,
   245  		checkpoint1:   mustHash("root 1"),
   246  		compactRange1: [][]byte{mustHash("root 1")},
   247  		size2:         16,
   248  		checkpoint2:   mustHash("root 2"),
   249  		compactRange2: [][]byte{mustHash("root 2")},
   250  		wantSize:      16,
   251  		wantCP:        mustHash("root 1"),
   252  	},
   253  	} {
   254  		t.Run(test.desc, func(t *testing.T) {
   255  			db, close, err := NewInMemoryDatabase()
   256  			if err != nil {
   257  				t.Fatal("failed to init DB", err)
   258  			}
   259  			defer close()
   260  
   261  			if err := db.WriteCheckpoint(context.Background(), test.size1, test.checkpoint1, test.compactRange1); err != nil {
   262  				t.Fatal("failed to write checkpoint 1", err)
   263  			}
   264  			if err := db.WriteCheckpoint(context.Background(), test.size2, test.checkpoint2, test.compactRange2); err != nil {
   265  				t.Fatal("failed to write checkpoint 2", err)
   266  			}
   267  
   268  			gotSize, cp, _, err := db.GetLatestCheckpoint(context.Background())
   269  			if err != nil {
   270  				t.Fatalf("GetLatestCheckpoint(): %v", err)
   271  			}
   272  
   273  			if gotSize != test.wantSize {
   274  				t.Errorf("size: got %d, want %d", gotSize, test.wantSize)
   275  			}
   276  			if !bytes.Equal(cp, test.wantCP) {
   277  				t.Errorf("checkpoint: got %x, want %x", cp, test.wantCP)
   278  			}
   279  		})
   280  	}
   281  }
   282  
   283  func NewInMemoryDatabase() (*Database, func(), error) {
   284  	sqlitedb, err := sql.Open("sqlite3", ":memory:")
   285  	if err != nil {
   286  		return nil, nil, fmt.Errorf("failed to open temporary in-memory DB: %v", err)
   287  	}
   288  	db, err := NewDatabaseDirect(sqlitedb)
   289  	return db, func() { _ = sqlitedb.Close }, err
   290  }
   291  
   292  func mustHash(s string) []byte {
   293  	h := sha256.Sum256([]byte(s))
   294  	return h[:]
   295  }