github.com/ethersphere/bee/v2@v2.2.0/pkg/bmt/proof_test.go (about)

     1  // Copyright 2022 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package bmt_test
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"encoding/hex"
    11  	"fmt"
    12  	"io"
    13  	"testing"
    14  
    15  	"github.com/ethersphere/bee/v2/pkg/bmt"
    16  	"github.com/ethersphere/bee/v2/pkg/swarm"
    17  )
    18  
    19  func TestProofCorrectness(t *testing.T) {
    20  	t.Parallel()
    21  
    22  	testData := []byte("hello world")
    23  	testDataPadded := make([]byte, swarm.ChunkSize)
    24  	copy(testDataPadded, testData)
    25  
    26  	verifySegments := func(t *testing.T, exp []string, found [][]byte) {
    27  		t.Helper()
    28  
    29  		var expSegments [][]byte
    30  		for _, v := range exp {
    31  			decoded, err := hex.DecodeString(v)
    32  			if err != nil {
    33  				t.Fatal(err)
    34  			}
    35  			expSegments = append(expSegments, decoded)
    36  		}
    37  
    38  		if len(expSegments) != len(found) {
    39  			t.Fatal("incorrect no of proof segments", len(found))
    40  		}
    41  
    42  		for idx, v := range expSegments {
    43  			if !bytes.Equal(v, found[idx]) {
    44  				t.Fatal("incorrect segment in proof")
    45  			}
    46  		}
    47  
    48  	}
    49  
    50  	pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, 128, 128))
    51  	hh := pool.Get()
    52  	t.Cleanup(func() {
    53  		pool.Put(hh)
    54  	})
    55  	hh.SetHeaderInt64(4096)
    56  
    57  	_, err := hh.Write(testData)
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  	pr := bmt.Prover{hh}
    62  	rh, err := pr.Hash(nil)
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  
    67  	t.Run("proof for left most", func(t *testing.T) {
    68  		t.Parallel()
    69  
    70  		proof := pr.Proof(0)
    71  
    72  		expSegmentStrings := []string{
    73  			"0000000000000000000000000000000000000000000000000000000000000000",
    74  			"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5",
    75  			"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30",
    76  			"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85",
    77  			"e58769b32a1beaf1ea27375a44095a0d1fb664ce2dd358e7fcbfb78c26a19344",
    78  			"0eb01ebfc9ed27500cd4dfc979272d1f0913cc9f66540d7e8005811109e1cf2d",
    79  			"887c22bd8750d34016ac3c66b5ff102dacdd73f6b014e710b51e8022af9a1968",
    80  		}
    81  
    82  		verifySegments(t, expSegmentStrings, proof.ProofSegments)
    83  
    84  		if !bytes.Equal(proof.ProveSegment, testDataPadded[:hh.Size()]) {
    85  			t.Fatal("section incorrect")
    86  		}
    87  
    88  		if !bytes.Equal(proof.Span, bmt.LengthToSpan(4096)) {
    89  			t.Fatal("incorrect span")
    90  		}
    91  	})
    92  
    93  	t.Run("proof for right most", func(t *testing.T) {
    94  		t.Parallel()
    95  
    96  		proof := pr.Proof(127)
    97  
    98  		expSegmentStrings := []string{
    99  			"0000000000000000000000000000000000000000000000000000000000000000",
   100  			"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5",
   101  			"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30",
   102  			"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85",
   103  			"e58769b32a1beaf1ea27375a44095a0d1fb664ce2dd358e7fcbfb78c26a19344",
   104  			"0eb01ebfc9ed27500cd4dfc979272d1f0913cc9f66540d7e8005811109e1cf2d",
   105  			"745bae095b6ff5416b4a351a167f731db6d6f5924f30cd88d48e74261795d27b",
   106  		}
   107  
   108  		verifySegments(t, expSegmentStrings, proof.ProofSegments)
   109  
   110  		if !bytes.Equal(proof.ProveSegment, testDataPadded[127*hh.Size():]) {
   111  			t.Fatal("section incorrect")
   112  		}
   113  
   114  		if !bytes.Equal(proof.Span, bmt.LengthToSpan(4096)) {
   115  			t.Fatal("incorrect span")
   116  		}
   117  	})
   118  
   119  	t.Run("proof for middle", func(t *testing.T) {
   120  		t.Parallel()
   121  
   122  		proof := pr.Proof(64)
   123  
   124  		expSegmentStrings := []string{
   125  			"0000000000000000000000000000000000000000000000000000000000000000",
   126  			"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5",
   127  			"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30",
   128  			"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85",
   129  			"e58769b32a1beaf1ea27375a44095a0d1fb664ce2dd358e7fcbfb78c26a19344",
   130  			"0eb01ebfc9ed27500cd4dfc979272d1f0913cc9f66540d7e8005811109e1cf2d",
   131  			"745bae095b6ff5416b4a351a167f731db6d6f5924f30cd88d48e74261795d27b",
   132  		}
   133  
   134  		verifySegments(t, expSegmentStrings, proof.ProofSegments)
   135  
   136  		if !bytes.Equal(proof.ProveSegment, testDataPadded[64*hh.Size():65*hh.Size()]) {
   137  			t.Fatal("section incorrect")
   138  		}
   139  
   140  		if !bytes.Equal(proof.Span, bmt.LengthToSpan(4096)) {
   141  			t.Fatal("incorrect span")
   142  		}
   143  	})
   144  
   145  	t.Run("root hash calculation", func(t *testing.T) {
   146  		t.Parallel()
   147  
   148  		segmentStrings := []string{
   149  			"0000000000000000000000000000000000000000000000000000000000000000",
   150  			"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5",
   151  			"b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30",
   152  			"21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85",
   153  			"e58769b32a1beaf1ea27375a44095a0d1fb664ce2dd358e7fcbfb78c26a19344",
   154  			"0eb01ebfc9ed27500cd4dfc979272d1f0913cc9f66540d7e8005811109e1cf2d",
   155  			"745bae095b6ff5416b4a351a167f731db6d6f5924f30cd88d48e74261795d27b",
   156  		}
   157  
   158  		var segments [][]byte
   159  		for _, v := range segmentStrings {
   160  			decoded, err := hex.DecodeString(v)
   161  			if err != nil {
   162  				t.Fatal(err)
   163  			}
   164  			segments = append(segments, decoded)
   165  		}
   166  
   167  		segment := testDataPadded[64*hh.Size() : 65*hh.Size()]
   168  
   169  		rootHash, err := pr.Verify(64, bmt.Proof{
   170  			ProveSegment:  segment,
   171  			ProofSegments: segments,
   172  			Span:          bmt.LengthToSpan(4096),
   173  		})
   174  		if err != nil {
   175  			t.Fatal(err)
   176  		}
   177  
   178  		if !bytes.Equal(rootHash, rh) {
   179  			t.Fatal("incorrect root hash obtained")
   180  		}
   181  	})
   182  }
   183  
   184  func TestProof(t *testing.T) {
   185  	t.Parallel()
   186  
   187  	// BMT segment inclusion proofs
   188  	// Usage
   189  	buf := make([]byte, 4096)
   190  	_, err := io.ReadFull(rand.Reader, buf)
   191  	if err != nil {
   192  		t.Fatal(err)
   193  	}
   194  
   195  	pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, 128, 128))
   196  	hh := pool.Get()
   197  	t.Cleanup(func() {
   198  		pool.Put(hh)
   199  	})
   200  	hh.SetHeaderInt64(4096)
   201  
   202  	_, err = hh.Write(buf)
   203  	if err != nil {
   204  		t.Fatal(err)
   205  	}
   206  
   207  	rh, err := hh.Hash(nil)
   208  	pr := bmt.Prover{hh}
   209  	if err != nil {
   210  		t.Fatal(err)
   211  	}
   212  
   213  	for i := 0; i < 128; i++ {
   214  		i := i
   215  		t.Run(fmt.Sprintf("segmentIndex %d", i), func(t *testing.T) {
   216  			t.Parallel()
   217  
   218  			proof := pr.Proof(i)
   219  
   220  			h := pool.Get()
   221  			defer pool.Put(h)
   222  
   223  			root, err := bmt.Prover{h}.Verify(i, proof)
   224  			if err != nil {
   225  				t.Fatal(err)
   226  			}
   227  			if !bytes.Equal(rh, root) {
   228  				t.Fatalf("incorrect hash. wanted %x, got %x.", rh, root)
   229  			}
   230  		})
   231  	}
   232  }