github.com/0chain/gosdk@v1.17.11/core/util/validation_tree_test.go (about) 1 package util 2 3 import ( 4 "bytes" 5 "crypto/rand" 6 "errors" 7 "fmt" 8 "math" 9 "testing" 10 11 "github.com/minio/sha256-simd" 12 "github.com/stretchr/testify/require" 13 ) 14 15 const ( 16 HashSize = 32 17 ) 18 19 func TestValidationTreeWrite(t *testing.T) { 20 dataSizes := []int64{ 21 MaxMerkleLeavesSize, 22 MaxMerkleLeavesSize - 24*KB, 23 MaxMerkleLeavesSize * 2, 24 MaxMerkleLeavesSize * 3, 25 MaxMerkleLeavesSize*10 - 1, 26 } 27 28 for _, s := range dataSizes { 29 data := make([]byte, s) 30 n, err := rand.Read(data) 31 require.NoError(t, err) 32 require.EqualValues(t, s, n) 33 34 root := calculateValidationMerkleRoot(data) 35 36 vt := NewValidationTree(s) 37 diff := 1 38 i := len(data) - diff 39 40 _, err = vt.Write(data[0:i]) 41 require.NoError(t, err) 42 vt.calculateRoot() 43 44 require.False(t, bytes.Equal(root, vt.validationRoot)) 45 46 _, err = vt.Write(data[i:]) 47 require.NoError(t, err) 48 49 err = vt.Finalize() 50 require.NoError(t, err) 51 52 vt.calculateRoot() 53 require.True(t, bytes.Equal(root, vt.validationRoot)) 54 55 require.Error(t, vt.Finalize()) 56 } 57 } 58 59 func TestValidationTreeCalculateDepth(t *testing.T) { 60 in := map[int]int{ 61 1: 1, 62 2: 2, 63 3: 3, 64 4: 3, 65 10: 5, 66 100: 8, 67 } 68 69 for k, d := range in { 70 v := ValidationTree{leaves: make([][]byte, k)} 71 require.Equal(t, v.CalculateDepth(), d) 72 } 73 } 74 75 func TestMerklePathVerificationForValidationTree(t *testing.T) { 76 77 type input struct { 78 dataSize int64 79 startInd int 80 endInd int 81 } 82 83 tests := []*input{ 84 { 85 dataSize: 24 * KB, 86 startInd: 0, 87 endInd: 0, 88 }, 89 { 90 dataSize: 340 * KB, 91 startInd: 1, 92 endInd: 3, 93 }, 94 { 95 dataSize: 640 * KB, 96 startInd: 1, 97 endInd: 4, 98 }, 99 { 100 dataSize: 640*KB + 1, 101 startInd: 1, 102 endInd: 5, 103 }, 104 } 105 106 for _, test := range tests { 107 t.Run(fmt.Sprintf("Data size: %d KB, startInd: %d, endInd:%d", 108 test.dataSize/KB, 109 test.startInd, 110 test.endInd, 111 ), func(t *testing.T) { 112 113 b := make([]byte, test.dataSize) 114 n, err := rand.Read(b) 115 116 require.NoError(t, err) 117 require.EqualValues(t, test.dataSize, n) 118 119 root, nodes, indexes, data, err := calculateValidationRootAndNodes(b, test.startInd, test.endInd) 120 require.NoError(t, err) 121 122 t.Logf("nodes len: %d; index len: %d, indexes: %v", len(nodes), len(indexes), indexes) 123 vp := MerklePathForMultiLeafVerification{ 124 RootHash: root, 125 Nodes: nodes, 126 Index: indexes, 127 DataSize: test.dataSize, 128 } 129 130 err = vp.VerifyMultipleBlocks(data) 131 require.NoError(t, err) 132 133 err = vp.VerifyMultipleBlocks(data[1:]) 134 require.Error(t, err) 135 }) 136 137 } 138 } 139 140 func calculateValidationMerkleRoot(data []byte) []byte { 141 hashes := make([][]byte, 0) 142 for i := 0; i < len(data); i += MaxMerkleLeavesSize { 143 j := i + MaxMerkleLeavesSize 144 if j > len(data) { 145 j = len(data) 146 } 147 h := sha256.New() 148 _, _ = h.Write(data[i:j]) 149 hashes = append(hashes, h.Sum(nil)) 150 } 151 152 if len(hashes) == 1 { 153 return hashes[0] 154 } 155 for len(hashes) != 1 { 156 newHashes := make([][]byte, 0) 157 if len(hashes)%2 == 0 { 158 for i := 0; i < len(hashes); i += 2 { 159 h := sha256.New() 160 _, _ = h.Write(hashes[i]) 161 _, _ = h.Write(hashes[i+1]) 162 newHashes = append(newHashes, h.Sum(nil)) 163 } 164 } else { 165 for i := 0; i < len(hashes)-1; i += 2 { 166 h := sha256.New() 167 _, _ = h.Write(hashes[i]) 168 _, _ = h.Write(hashes[i+1]) 169 newHashes = append(newHashes, h.Sum(nil)) 170 } 171 h := sha256.New() 172 _, _ = h.Write(hashes[len(hashes)-1]) 173 newHashes = append(newHashes, h.Sum(nil)) 174 } 175 176 hashes = newHashes 177 } 178 return hashes[0] 179 } 180 181 func calculateValidationRootAndNodes(b []byte, startInd, endInd int) ( 182 root []byte, nodes [][][]byte, indexes [][]int, data []byte, err error, 183 ) { 184 185 totalLeaves := int(math.Ceil(float64(len(b)) / float64(MaxMerkleLeavesSize))) 186 depth := int(math.Ceil(math.Log2(float64(totalLeaves)))) + 1 187 188 if endInd >= totalLeaves { 189 endInd = totalLeaves - 1 190 } 191 192 hashes := make([][]byte, 0) 193 nodesData := make([]byte, 0) 194 h := sha256.New() 195 for i := 0; i < len(b); i += MaxMerkleLeavesSize { 196 j := i + MaxMerkleLeavesSize 197 if j > len(b) { 198 j = len(b) 199 } 200 201 _, _ = h.Write(b[i:j]) 202 leafHash := h.Sum(nil) 203 hashes = append(hashes, leafHash) 204 h.Reset() 205 } 206 207 if len(hashes) == 1 { 208 return hashes[0], nil, nil, b, nil 209 } 210 211 for len(hashes) != 1 { 212 newHashes := make([][]byte, 0) 213 if len(hashes)%2 == 0 { 214 for i := 0; i < len(hashes); i += 2 { 215 h := sha256.New() 216 _, _ = h.Write(hashes[i]) 217 _, _ = h.Write(hashes[i+1]) 218 nodesData = append(nodesData, hashes[i]...) 219 nodesData = append(nodesData, hashes[i+1]...) 220 newHashes = append(newHashes, h.Sum(nil)) 221 } 222 } else { 223 for i := 0; i < len(hashes)-1; i += 2 { 224 h := sha256.New() 225 _, _ = h.Write(hashes[i]) 226 _, _ = h.Write(hashes[i+1]) 227 nodesData = append(nodesData, hashes[i]...) 228 nodesData = append(nodesData, hashes[i+1]...) 229 newHashes = append(newHashes, h.Sum(nil)) 230 } 231 h := sha256.New() 232 _, _ = h.Write(hashes[len(hashes)-1]) 233 nodesData = append(nodesData, hashes[len(hashes)-1]...) 234 newHashes = append(newHashes, h.Sum(nil)) 235 } 236 237 hashes = newHashes 238 } 239 240 nodes, indexes, err = getMerkleProofOfMultipleIndexes(nodesData, totalLeaves, depth, startInd, endInd) 241 if err != nil { 242 return nil, nil, nil, nil, err 243 } 244 245 startOffset := startInd * 64 * KB 246 endOffset := startOffset + (endInd-startInd+1)*64*KB 247 if endOffset > len(b) { 248 endOffset = len(b) 249 } 250 251 return hashes[0], nodes, indexes, b[startOffset:endOffset], nil 252 } 253 254 func getMerkleProofOfMultipleIndexes(nodesData []byte, totalLeaves, depth, startInd, endInd int) ( 255 [][][]byte, [][]int, error) { 256 257 if endInd >= totalLeaves { 258 endInd = totalLeaves - 1 259 } 260 261 if endInd < startInd { 262 return nil, nil, errors.New("end index cannot be lesser than start index") 263 } 264 265 offsets, leftRightIndexes := getFileOffsetsAndNodeIndexes(totalLeaves, depth, startInd, endInd) 266 267 offsetInd := 0 268 nodeHashes := make([][][]byte, len(leftRightIndexes)) 269 for i, indexes := range leftRightIndexes { 270 for range indexes { 271 b := make([]byte, HashSize) 272 off := offsets[offsetInd] 273 n := copy(b, nodesData[off:off+HashSize]) 274 if n != HashSize { 275 return nil, nil, errors.New("invalid hash length") 276 } 277 nodeHashes[i] = append(nodeHashes[i], b) 278 offsetInd++ 279 } 280 } 281 return nodeHashes, leftRightIndexes, nil 282 } 283 284 func getFileOffsetsAndNodeIndexes(totalLeaves, depth, startInd, endInd int) ([]int, [][]int) { 285 286 nodeIndexes, leftRightIndexes := getNodeIndexes(totalLeaves, depth, startInd, endInd) 287 offsets := make([]int, 0) 288 totalNodes := 0 289 curNodesTot := totalLeaves 290 for i := 0; i < len(nodeIndexes); i++ { 291 for _, ind := range nodeIndexes[i] { 292 offsetInd := ind + totalNodes 293 offsets = append(offsets, offsetInd*HashSize) 294 } 295 totalNodes += curNodesTot 296 curNodesTot = (curNodesTot + 1) / 2 297 } 298 299 return offsets, leftRightIndexes 300 } 301 302 func getNodeIndexes(totalLeaves, depth, startInd, endInd int) ([][]int, [][]int) { 303 304 indexes := make([][]int, 0) 305 leftRightIndexes := make([][]int, 0) 306 totalNodes := totalLeaves 307 for i := depth - 1; i >= 0; i-- { 308 if startInd == 0 && endInd == totalNodes-1 { 309 break 310 } 311 312 nodeOffsets := make([]int, 0) 313 lftRtInd := make([]int, 0) 314 if startInd&1 == 1 { 315 nodeOffsets = append(nodeOffsets, startInd-1) 316 lftRtInd = append(lftRtInd, Left) 317 } 318 319 if endInd != totalNodes-1 && endInd&1 == 0 { 320 nodeOffsets = append(nodeOffsets, endInd+1) 321 lftRtInd = append(lftRtInd, Right) 322 } 323 324 indexes = append(indexes, nodeOffsets) 325 leftRightIndexes = append(leftRightIndexes, lftRtInd) 326 startInd = startInd / 2 327 endInd = endInd / 2 328 totalNodes = (totalNodes + 1) / 2 329 } 330 return indexes, leftRightIndexes 331 }