github.com/0chain/gosdk@v1.17.11/core/util/validation_tree_test.go (about)

     1  package util
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"errors"
     7  	"fmt"
     8  	"math"
     9  	"testing"
    10  
    11  	"github.com/minio/sha256-simd"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  const (
    16  	HashSize = 32
    17  )
    18  
    19  func TestValidationTreeWrite(t *testing.T) {
    20  	dataSizes := []int64{
    21  		MaxMerkleLeavesSize,
    22  		MaxMerkleLeavesSize - 24*KB,
    23  		MaxMerkleLeavesSize * 2,
    24  		MaxMerkleLeavesSize * 3,
    25  		MaxMerkleLeavesSize*10 - 1,
    26  	}
    27  
    28  	for _, s := range dataSizes {
    29  		data := make([]byte, s)
    30  		n, err := rand.Read(data)
    31  		require.NoError(t, err)
    32  		require.EqualValues(t, s, n)
    33  
    34  		root := calculateValidationMerkleRoot(data)
    35  
    36  		vt := NewValidationTree(s)
    37  		diff := 1
    38  		i := len(data) - diff
    39  
    40  		_, err = vt.Write(data[0:i])
    41  		require.NoError(t, err)
    42  		vt.calculateRoot()
    43  
    44  		require.False(t, bytes.Equal(root, vt.validationRoot))
    45  
    46  		_, err = vt.Write(data[i:])
    47  		require.NoError(t, err)
    48  
    49  		err = vt.Finalize()
    50  		require.NoError(t, err)
    51  
    52  		vt.calculateRoot()
    53  		require.True(t, bytes.Equal(root, vt.validationRoot))
    54  
    55  		require.Error(t, vt.Finalize())
    56  	}
    57  }
    58  
    59  func TestValidationTreeCalculateDepth(t *testing.T) {
    60  	in := map[int]int{
    61  		1:   1,
    62  		2:   2,
    63  		3:   3,
    64  		4:   3,
    65  		10:  5,
    66  		100: 8,
    67  	}
    68  
    69  	for k, d := range in {
    70  		v := ValidationTree{leaves: make([][]byte, k)}
    71  		require.Equal(t, v.CalculateDepth(), d)
    72  	}
    73  }
    74  
    75  func TestMerklePathVerificationForValidationTree(t *testing.T) {
    76  
    77  	type input struct {
    78  		dataSize int64
    79  		startInd int
    80  		endInd   int
    81  	}
    82  
    83  	tests := []*input{
    84  		{
    85  			dataSize: 24 * KB,
    86  			startInd: 0,
    87  			endInd:   0,
    88  		},
    89  		{
    90  			dataSize: 340 * KB,
    91  			startInd: 1,
    92  			endInd:   3,
    93  		},
    94  		{
    95  			dataSize: 640 * KB,
    96  			startInd: 1,
    97  			endInd:   4,
    98  		},
    99  		{
   100  			dataSize: 640*KB + 1,
   101  			startInd: 1,
   102  			endInd:   5,
   103  		},
   104  	}
   105  
   106  	for _, test := range tests {
   107  		t.Run(fmt.Sprintf("Data size: %d KB, startInd: %d, endInd:%d",
   108  			test.dataSize/KB,
   109  			test.startInd,
   110  			test.endInd,
   111  		), func(t *testing.T) {
   112  
   113  			b := make([]byte, test.dataSize)
   114  			n, err := rand.Read(b)
   115  
   116  			require.NoError(t, err)
   117  			require.EqualValues(t, test.dataSize, n)
   118  
   119  			root, nodes, indexes, data, err := calculateValidationRootAndNodes(b, test.startInd, test.endInd)
   120  			require.NoError(t, err)
   121  
   122  			t.Logf("nodes len: %d; index len: %d, indexes: %v", len(nodes), len(indexes), indexes)
   123  			vp := MerklePathForMultiLeafVerification{
   124  				RootHash: root,
   125  				Nodes:    nodes,
   126  				Index:    indexes,
   127  				DataSize: test.dataSize,
   128  			}
   129  
   130  			err = vp.VerifyMultipleBlocks(data)
   131  			require.NoError(t, err)
   132  
   133  			err = vp.VerifyMultipleBlocks(data[1:])
   134  			require.Error(t, err)
   135  		})
   136  
   137  	}
   138  }
   139  
   140  func calculateValidationMerkleRoot(data []byte) []byte {
   141  	hashes := make([][]byte, 0)
   142  	for i := 0; i < len(data); i += MaxMerkleLeavesSize {
   143  		j := i + MaxMerkleLeavesSize
   144  		if j > len(data) {
   145  			j = len(data)
   146  		}
   147  		h := sha256.New()
   148  		_, _ = h.Write(data[i:j])
   149  		hashes = append(hashes, h.Sum(nil))
   150  	}
   151  
   152  	if len(hashes) == 1 {
   153  		return hashes[0]
   154  	}
   155  	for len(hashes) != 1 {
   156  		newHashes := make([][]byte, 0)
   157  		if len(hashes)%2 == 0 {
   158  			for i := 0; i < len(hashes); i += 2 {
   159  				h := sha256.New()
   160  				_, _ = h.Write(hashes[i])
   161  				_, _ = h.Write(hashes[i+1])
   162  				newHashes = append(newHashes, h.Sum(nil))
   163  			}
   164  		} else {
   165  			for i := 0; i < len(hashes)-1; i += 2 {
   166  				h := sha256.New()
   167  				_, _ = h.Write(hashes[i])
   168  				_, _ = h.Write(hashes[i+1])
   169  				newHashes = append(newHashes, h.Sum(nil))
   170  			}
   171  			h := sha256.New()
   172  			_, _ = h.Write(hashes[len(hashes)-1])
   173  			newHashes = append(newHashes, h.Sum(nil))
   174  		}
   175  
   176  		hashes = newHashes
   177  	}
   178  	return hashes[0]
   179  }
   180  
   181  func calculateValidationRootAndNodes(b []byte, startInd, endInd int) (
   182  	root []byte, nodes [][][]byte, indexes [][]int, data []byte, err error,
   183  ) {
   184  
   185  	totalLeaves := int(math.Ceil(float64(len(b)) / float64(MaxMerkleLeavesSize)))
   186  	depth := int(math.Ceil(math.Log2(float64(totalLeaves)))) + 1
   187  
   188  	if endInd >= totalLeaves {
   189  		endInd = totalLeaves - 1
   190  	}
   191  
   192  	hashes := make([][]byte, 0)
   193  	nodesData := make([]byte, 0)
   194  	h := sha256.New()
   195  	for i := 0; i < len(b); i += MaxMerkleLeavesSize {
   196  		j := i + MaxMerkleLeavesSize
   197  		if j > len(b) {
   198  			j = len(b)
   199  		}
   200  
   201  		_, _ = h.Write(b[i:j])
   202  		leafHash := h.Sum(nil)
   203  		hashes = append(hashes, leafHash)
   204  		h.Reset()
   205  	}
   206  
   207  	if len(hashes) == 1 {
   208  		return hashes[0], nil, nil, b, nil
   209  	}
   210  
   211  	for len(hashes) != 1 {
   212  		newHashes := make([][]byte, 0)
   213  		if len(hashes)%2 == 0 {
   214  			for i := 0; i < len(hashes); i += 2 {
   215  				h := sha256.New()
   216  				_, _ = h.Write(hashes[i])
   217  				_, _ = h.Write(hashes[i+1])
   218  				nodesData = append(nodesData, hashes[i]...)
   219  				nodesData = append(nodesData, hashes[i+1]...)
   220  				newHashes = append(newHashes, h.Sum(nil))
   221  			}
   222  		} else {
   223  			for i := 0; i < len(hashes)-1; i += 2 {
   224  				h := sha256.New()
   225  				_, _ = h.Write(hashes[i])
   226  				_, _ = h.Write(hashes[i+1])
   227  				nodesData = append(nodesData, hashes[i]...)
   228  				nodesData = append(nodesData, hashes[i+1]...)
   229  				newHashes = append(newHashes, h.Sum(nil))
   230  			}
   231  			h := sha256.New()
   232  			_, _ = h.Write(hashes[len(hashes)-1])
   233  			nodesData = append(nodesData, hashes[len(hashes)-1]...)
   234  			newHashes = append(newHashes, h.Sum(nil))
   235  		}
   236  
   237  		hashes = newHashes
   238  	}
   239  
   240  	nodes, indexes, err = getMerkleProofOfMultipleIndexes(nodesData, totalLeaves, depth, startInd, endInd)
   241  	if err != nil {
   242  		return nil, nil, nil, nil, err
   243  	}
   244  
   245  	startOffset := startInd * 64 * KB
   246  	endOffset := startOffset + (endInd-startInd+1)*64*KB
   247  	if endOffset > len(b) {
   248  		endOffset = len(b)
   249  	}
   250  
   251  	return hashes[0], nodes, indexes, b[startOffset:endOffset], nil
   252  }
   253  
   254  func getMerkleProofOfMultipleIndexes(nodesData []byte, totalLeaves, depth, startInd, endInd int) (
   255  	[][][]byte, [][]int, error) {
   256  
   257  	if endInd >= totalLeaves {
   258  		endInd = totalLeaves - 1
   259  	}
   260  
   261  	if endInd < startInd {
   262  		return nil, nil, errors.New("end index cannot be lesser than start index")
   263  	}
   264  
   265  	offsets, leftRightIndexes := getFileOffsetsAndNodeIndexes(totalLeaves, depth, startInd, endInd)
   266  
   267  	offsetInd := 0
   268  	nodeHashes := make([][][]byte, len(leftRightIndexes))
   269  	for i, indexes := range leftRightIndexes {
   270  		for range indexes {
   271  			b := make([]byte, HashSize)
   272  			off := offsets[offsetInd]
   273  			n := copy(b, nodesData[off:off+HashSize])
   274  			if n != HashSize {
   275  				return nil, nil, errors.New("invalid hash length")
   276  			}
   277  			nodeHashes[i] = append(nodeHashes[i], b)
   278  			offsetInd++
   279  		}
   280  	}
   281  	return nodeHashes, leftRightIndexes, nil
   282  }
   283  
   284  func getFileOffsetsAndNodeIndexes(totalLeaves, depth, startInd, endInd int) ([]int, [][]int) {
   285  
   286  	nodeIndexes, leftRightIndexes := getNodeIndexes(totalLeaves, depth, startInd, endInd)
   287  	offsets := make([]int, 0)
   288  	totalNodes := 0
   289  	curNodesTot := totalLeaves
   290  	for i := 0; i < len(nodeIndexes); i++ {
   291  		for _, ind := range nodeIndexes[i] {
   292  			offsetInd := ind + totalNodes
   293  			offsets = append(offsets, offsetInd*HashSize)
   294  		}
   295  		totalNodes += curNodesTot
   296  		curNodesTot = (curNodesTot + 1) / 2
   297  	}
   298  
   299  	return offsets, leftRightIndexes
   300  }
   301  
   302  func getNodeIndexes(totalLeaves, depth, startInd, endInd int) ([][]int, [][]int) {
   303  
   304  	indexes := make([][]int, 0)
   305  	leftRightIndexes := make([][]int, 0)
   306  	totalNodes := totalLeaves
   307  	for i := depth - 1; i >= 0; i-- {
   308  		if startInd == 0 && endInd == totalNodes-1 {
   309  			break
   310  		}
   311  
   312  		nodeOffsets := make([]int, 0)
   313  		lftRtInd := make([]int, 0)
   314  		if startInd&1 == 1 {
   315  			nodeOffsets = append(nodeOffsets, startInd-1)
   316  			lftRtInd = append(lftRtInd, Left)
   317  		}
   318  
   319  		if endInd != totalNodes-1 && endInd&1 == 0 {
   320  			nodeOffsets = append(nodeOffsets, endInd+1)
   321  			lftRtInd = append(lftRtInd, Right)
   322  		}
   323  
   324  		indexes = append(indexes, nodeOffsets)
   325  		leftRightIndexes = append(leftRightIndexes, lftRtInd)
   326  		startInd = startInd / 2
   327  		endInd = endInd / 2
   328  		totalNodes = (totalNodes + 1) / 2
   329  	}
   330  	return indexes, leftRightIndexes
   331  }