github.com/lazyledger/lazyledger-core@v0.35.0-dev.0.20210613111200-4c651f053571/p2p/ipld/read_test.go (about) 1 package ipld 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/sha256" 7 "fmt" 8 "math" 9 "math/rand" 10 "sort" 11 "testing" 12 "time" 13 14 format "github.com/ipfs/go-ipld-format" 15 mdutils "github.com/ipfs/go-merkledag/test" 16 "github.com/lazyledger/nmt" 17 "github.com/lazyledger/nmt/namespace" 18 "github.com/lazyledger/rsmt2d" 19 "github.com/stretchr/testify/assert" 20 "github.com/stretchr/testify/require" 21 22 "github.com/lazyledger/lazyledger-core/ipfs" 23 "github.com/lazyledger/lazyledger-core/ipfs/plugin" 24 "github.com/lazyledger/lazyledger-core/libs/log" 25 "github.com/lazyledger/lazyledger-core/p2p/ipld/wrapper" 26 "github.com/lazyledger/lazyledger-core/types" 27 "github.com/lazyledger/lazyledger-core/types/consts" 28 ) 29 30 func TestGetLeafData(t *testing.T) { 31 const leaves = 16 32 33 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 34 defer cancel() 35 36 // generate random data for the nmt 37 data := generateRandNamespacedRawData(leaves, consts.NamespaceSize, consts.ShareSize) 38 39 // create a random tree 40 dag := mdutils.Mock() 41 root, err := getNmtRoot(ctx, dag, data) 42 require.NoError(t, err) 43 44 // compute the root and create a cid for the root hash 45 rootCid, err := plugin.CidFromNamespacedSha256(root.Bytes()) 46 require.NoError(t, err) 47 48 for i, leaf := range data { 49 data, err := GetLeafData(ctx, rootCid, uint32(i), uint32(len(data)), dag) 50 assert.NoError(t, err) 51 assert.Equal(t, leaf, data) 52 } 53 } 54 55 func TestBlockRecovery(t *testing.T) { 56 originalSquareWidth := 8 57 shareCount := originalSquareWidth * originalSquareWidth 58 extendedSquareWidth := 2 * originalSquareWidth 59 extendedShareCount := extendedSquareWidth * extendedSquareWidth 60 61 // generate test data 62 quarterShares := generateRandNamespacedRawData(shareCount, consts.NamespaceSize, consts.MsgShareSize) 63 allShares := generateRandNamespacedRawData(shareCount, consts.NamespaceSize, consts.MsgShareSize) 64 65 testCases := []struct { 66 name string 67 shares [][]byte 68 expectErr bool 69 errString string 70 d int // number of shares to delete 71 }{ 72 {"missing 1/2 shares", quarterShares, false, "", extendedShareCount / 2}, 73 {"missing 1/4 shares", quarterShares, false, "", extendedShareCount / 4}, 74 {"max missing data", quarterShares, false, "", (originalSquareWidth + 1) * (originalSquareWidth + 1)}, 75 {"missing all but one shares", allShares, true, "failed to solve data square", extendedShareCount - 1}, 76 } 77 for _, tc := range testCases { 78 tc := tc 79 80 t.Run(tc.name, func(t *testing.T) { 81 squareSize := uint64(math.Sqrt(float64(len(tc.shares)))) 82 83 // create trees for creating roots 84 tree := wrapper.NewErasuredNamespacedMerkleTree(squareSize) 85 recoverTree := wrapper.NewErasuredNamespacedMerkleTree(squareSize) 86 87 eds, err := rsmt2d.ComputeExtendedDataSquare(tc.shares, rsmt2d.NewRSGF8Codec(), tree.Constructor) 88 require.NoError(t, err) 89 90 // calculate roots using the first complete square 91 rowRoots := eds.RowRoots() 92 colRoots := eds.ColumnRoots() 93 94 flat := flatten(eds) 95 96 // recover a partially complete square 97 reds, err := rsmt2d.RepairExtendedDataSquare( 98 rowRoots, 99 colRoots, 100 removeRandShares(flat, tc.d), 101 rsmt2d.NewRSGF8Codec(), 102 recoverTree.Constructor, 103 ) 104 105 if tc.expectErr { 106 require.Error(t, err) 107 require.Contains(t, err.Error(), tc.errString) 108 return 109 } 110 assert.NoError(t, err) 111 112 // check that the squares are equal 113 assert.Equal(t, flatten(eds), flatten(reds)) 114 }) 115 } 116 } 117 118 func TestRetrieveBlockData(t *testing.T) { 119 logger := log.TestingLogger() 120 type test struct { 121 name string 122 squareSize int 123 expectErr bool 124 errStr string 125 } 126 tests := []test{ 127 {"Empty block", 1, false, ""}, 128 {"4 KB block", 4, false, ""}, 129 {"16 KB block", 8, false, ""}, 130 {"16 KB block timeout expected", 8, true, "not found"}, 131 {"max square size", consts.MaxSquareSize, false, ""}, 132 } 133 134 for _, tc := range tests { 135 // TODO(Wondertan): remove this 136 if tc.squareSize > 8 { 137 continue 138 } 139 140 tc := tc 141 t.Run(fmt.Sprintf("%s size %d", tc.name, tc.squareSize), func(t *testing.T) { 142 ctx := context.Background() 143 dag := mdutils.Mock() 144 croute := ipfs.MockRouting() 145 146 blockData := generateRandomBlockData(tc.squareSize*tc.squareSize, consts.MsgShareSize-2) 147 block := &types.Block{ 148 Data: blockData, 149 LastCommit: &types.Commit{}, 150 } 151 152 // if an error is exected, don't put the block 153 if !tc.expectErr { 154 err := PutBlock(ctx, dag, block, croute, logger) 155 require.NoError(t, err) 156 } 157 158 shareData, _ := blockData.ComputeShares() 159 rawData := shareData.RawShares() 160 161 tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(tc.squareSize)) 162 eds, err := rsmt2d.ComputeExtendedDataSquare(rawData, rsmt2d.NewRSGF8Codec(), tree.Constructor) 163 require.NoError(t, err) 164 165 rawRowRoots := eds.RowRoots() 166 rawColRoots := eds.ColumnRoots() 167 rowRoots := rootsToDigests(rawRowRoots) 168 colRoots := rootsToDigests(rawColRoots) 169 170 // limit with deadline retrieval specifically 171 ctx, cancel := context.WithTimeout(ctx, time.Second*2) 172 defer cancel() 173 174 rblockData, err := RetrieveBlockData( 175 ctx, 176 &types.DataAvailabilityHeader{ 177 RowsRoots: rowRoots, 178 ColumnRoots: colRoots, 179 }, 180 dag, 181 rsmt2d.NewRSGF8Codec(), 182 ) 183 184 if tc.expectErr { 185 require.Error(t, err) 186 require.Contains(t, err.Error(), tc.errStr) 187 return 188 } 189 require.NoError(t, err) 190 191 nsShares, _ := rblockData.ComputeShares() 192 assert.Equal(t, rawData, nsShares.RawShares()) 193 }) 194 } 195 } 196 197 func flatten(eds *rsmt2d.ExtendedDataSquare) [][]byte { 198 flattenedEDSSize := eds.Width() * eds.Width() 199 out := make([][]byte, flattenedEDSSize) 200 count := 0 201 for i := uint(0); i < eds.Width(); i++ { 202 for _, share := range eds.Row(i) { 203 out[count] = share 204 count++ 205 } 206 } 207 return out 208 } 209 210 // getNmtRoot generates the nmt root of some namespaced data 211 func getNmtRoot( 212 ctx context.Context, 213 dag format.NodeAdder, 214 namespacedData [][]byte, 215 ) (namespace.IntervalDigest, error) { 216 na := NewNmtNodeAdder(ctx, format.NewBatch(ctx, dag)) 217 tree := nmt.New(sha256.New, nmt.NamespaceIDSize(consts.NamespaceSize), nmt.NodeVisitor(na.Visit)) 218 for _, leaf := range namespacedData { 219 err := tree.Push(leaf) 220 if err != nil { 221 return namespace.IntervalDigest{}, err 222 } 223 } 224 225 return tree.Root(), na.Commit() 226 } 227 228 // this code is copy pasted from the plugin, and should likely be exported in the plugin instead 229 func generateRandNamespacedRawData(total int, nidSize int, leafSize int) [][]byte { 230 data := make([][]byte, total) 231 for i := 0; i < total; i++ { 232 nid := make([]byte, nidSize) 233 _, err := rand.Read(nid) 234 if err != nil { 235 panic(err) 236 } 237 data[i] = nid 238 } 239 240 sortByteArrays(data) 241 for i := 0; i < total; i++ { 242 d := make([]byte, leafSize) 243 _, err := rand.Read(d) 244 if err != nil { 245 panic(err) 246 } 247 data[i] = append(data[i], d...) 248 } 249 250 return data 251 } 252 253 func sortByteArrays(src [][]byte) { 254 sort.Slice(src, func(i, j int) bool { return bytes.Compare(src[i], src[j]) < 0 }) 255 } 256 257 // removes d shares from data 258 func removeRandShares(data [][]byte, d int) [][]byte { 259 count := len(data) 260 // remove shares randomly 261 for i := 0; i < d; { 262 ind := rand.Intn(count) 263 if len(data[ind]) == 0 { 264 continue 265 } 266 data[ind] = nil 267 i++ 268 } 269 return data 270 } 271 272 func rootsToDigests(roots [][]byte) []namespace.IntervalDigest { 273 out := make([]namespace.IntervalDigest, len(roots)) 274 for i, root := range roots { 275 idigest, err := namespace.IntervalDigestFromBytes(consts.NamespaceSize, root) 276 if err != nil { 277 panic(err) 278 } 279 out[i] = idigest 280 } 281 return out 282 } 283 284 func generateRandomBlockData(msgCount, msgSize int) types.Data { 285 var out types.Data 286 if msgCount == 1 { 287 return out 288 } 289 out.Messages = generateRandomMessages(msgCount-1, msgSize) 290 out.Txs = generateRandomContiguousShares(1) 291 return out 292 } 293 294 func generateRandomMessages(count, msgSize int) types.Messages { 295 shares := generateRandNamespacedRawData(count, consts.NamespaceSize, msgSize) 296 msgs := make([]types.Message, count) 297 for i, s := range shares { 298 msgs[i] = types.Message{ 299 Data: s[consts.NamespaceSize:], 300 NamespaceID: s[:consts.NamespaceSize], 301 } 302 } 303 return types.Messages{MessagesList: msgs} 304 } 305 306 func generateRandomContiguousShares(count int) types.Txs { 307 // the size of a length delimited tx that takes up an entire share 308 const adjustedTxSize = consts.TxShareSize - 2 309 txs := make(types.Txs, count) 310 for i := 0; i < count; i++ { 311 tx := make([]byte, adjustedTxSize) 312 _, err := rand.Read(tx) 313 if err != nil { 314 panic(err) 315 } 316 txs[i] = types.Tx(tx) 317 } 318 return txs 319 }