github.com/lazyledger/lazyledger-core@v0.35.0-dev.0.20210613111200-4c651f053571/p2p/ipld/read_test.go (about)

     1  package ipld
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/sha256"
     7  	"fmt"
     8  	"math"
     9  	"math/rand"
    10  	"sort"
    11  	"testing"
    12  	"time"
    13  
    14  	format "github.com/ipfs/go-ipld-format"
    15  	mdutils "github.com/ipfs/go-merkledag/test"
    16  	"github.com/lazyledger/nmt"
    17  	"github.com/lazyledger/nmt/namespace"
    18  	"github.com/lazyledger/rsmt2d"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/lazyledger/lazyledger-core/ipfs"
    23  	"github.com/lazyledger/lazyledger-core/ipfs/plugin"
    24  	"github.com/lazyledger/lazyledger-core/libs/log"
    25  	"github.com/lazyledger/lazyledger-core/p2p/ipld/wrapper"
    26  	"github.com/lazyledger/lazyledger-core/types"
    27  	"github.com/lazyledger/lazyledger-core/types/consts"
    28  )
    29  
    30  func TestGetLeafData(t *testing.T) {
    31  	const leaves = 16
    32  
    33  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    34  	defer cancel()
    35  
    36  	// generate random data for the nmt
    37  	data := generateRandNamespacedRawData(leaves, consts.NamespaceSize, consts.ShareSize)
    38  
    39  	// create a random tree
    40  	dag := mdutils.Mock()
    41  	root, err := getNmtRoot(ctx, dag, data)
    42  	require.NoError(t, err)
    43  
    44  	// compute the root and create a cid for the root hash
    45  	rootCid, err := plugin.CidFromNamespacedSha256(root.Bytes())
    46  	require.NoError(t, err)
    47  
    48  	for i, leaf := range data {
    49  		data, err := GetLeafData(ctx, rootCid, uint32(i), uint32(len(data)), dag)
    50  		assert.NoError(t, err)
    51  		assert.Equal(t, leaf, data)
    52  	}
    53  }
    54  
    55  func TestBlockRecovery(t *testing.T) {
    56  	originalSquareWidth := 8
    57  	shareCount := originalSquareWidth * originalSquareWidth
    58  	extendedSquareWidth := 2 * originalSquareWidth
    59  	extendedShareCount := extendedSquareWidth * extendedSquareWidth
    60  
    61  	// generate test data
    62  	quarterShares := generateRandNamespacedRawData(shareCount, consts.NamespaceSize, consts.MsgShareSize)
    63  	allShares := generateRandNamespacedRawData(shareCount, consts.NamespaceSize, consts.MsgShareSize)
    64  
    65  	testCases := []struct {
    66  		name      string
    67  		shares    [][]byte
    68  		expectErr bool
    69  		errString string
    70  		d         int // number of shares to delete
    71  	}{
    72  		{"missing 1/2 shares", quarterShares, false, "", extendedShareCount / 2},
    73  		{"missing 1/4 shares", quarterShares, false, "", extendedShareCount / 4},
    74  		{"max missing data", quarterShares, false, "", (originalSquareWidth + 1) * (originalSquareWidth + 1)},
    75  		{"missing all but one shares", allShares, true, "failed to solve data square", extendedShareCount - 1},
    76  	}
    77  	for _, tc := range testCases {
    78  		tc := tc
    79  
    80  		t.Run(tc.name, func(t *testing.T) {
    81  			squareSize := uint64(math.Sqrt(float64(len(tc.shares))))
    82  
    83  			// create trees for creating roots
    84  			tree := wrapper.NewErasuredNamespacedMerkleTree(squareSize)
    85  			recoverTree := wrapper.NewErasuredNamespacedMerkleTree(squareSize)
    86  
    87  			eds, err := rsmt2d.ComputeExtendedDataSquare(tc.shares, rsmt2d.NewRSGF8Codec(), tree.Constructor)
    88  			require.NoError(t, err)
    89  
    90  			// calculate roots using the first complete square
    91  			rowRoots := eds.RowRoots()
    92  			colRoots := eds.ColumnRoots()
    93  
    94  			flat := flatten(eds)
    95  
    96  			// recover a partially complete square
    97  			reds, err := rsmt2d.RepairExtendedDataSquare(
    98  				rowRoots,
    99  				colRoots,
   100  				removeRandShares(flat, tc.d),
   101  				rsmt2d.NewRSGF8Codec(),
   102  				recoverTree.Constructor,
   103  			)
   104  
   105  			if tc.expectErr {
   106  				require.Error(t, err)
   107  				require.Contains(t, err.Error(), tc.errString)
   108  				return
   109  			}
   110  			assert.NoError(t, err)
   111  
   112  			// check that the squares are equal
   113  			assert.Equal(t, flatten(eds), flatten(reds))
   114  		})
   115  	}
   116  }
   117  
   118  func TestRetrieveBlockData(t *testing.T) {
   119  	logger := log.TestingLogger()
   120  	type test struct {
   121  		name       string
   122  		squareSize int
   123  		expectErr  bool
   124  		errStr     string
   125  	}
   126  	tests := []test{
   127  		{"Empty block", 1, false, ""},
   128  		{"4 KB block", 4, false, ""},
   129  		{"16 KB block", 8, false, ""},
   130  		{"16 KB block timeout expected", 8, true, "not found"},
   131  		{"max square size", consts.MaxSquareSize, false, ""},
   132  	}
   133  
   134  	for _, tc := range tests {
   135  		// TODO(Wondertan): remove this
   136  		if tc.squareSize > 8 {
   137  			continue
   138  		}
   139  
   140  		tc := tc
   141  		t.Run(fmt.Sprintf("%s size %d", tc.name, tc.squareSize), func(t *testing.T) {
   142  			ctx := context.Background()
   143  			dag := mdutils.Mock()
   144  			croute := ipfs.MockRouting()
   145  
   146  			blockData := generateRandomBlockData(tc.squareSize*tc.squareSize, consts.MsgShareSize-2)
   147  			block := &types.Block{
   148  				Data:       blockData,
   149  				LastCommit: &types.Commit{},
   150  			}
   151  
   152  			// if an error is exected, don't put the block
   153  			if !tc.expectErr {
   154  				err := PutBlock(ctx, dag, block, croute, logger)
   155  				require.NoError(t, err)
   156  			}
   157  
   158  			shareData, _ := blockData.ComputeShares()
   159  			rawData := shareData.RawShares()
   160  
   161  			tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(tc.squareSize))
   162  			eds, err := rsmt2d.ComputeExtendedDataSquare(rawData, rsmt2d.NewRSGF8Codec(), tree.Constructor)
   163  			require.NoError(t, err)
   164  
   165  			rawRowRoots := eds.RowRoots()
   166  			rawColRoots := eds.ColumnRoots()
   167  			rowRoots := rootsToDigests(rawRowRoots)
   168  			colRoots := rootsToDigests(rawColRoots)
   169  
   170  			// limit with deadline retrieval specifically
   171  			ctx, cancel := context.WithTimeout(ctx, time.Second*2)
   172  			defer cancel()
   173  
   174  			rblockData, err := RetrieveBlockData(
   175  				ctx,
   176  				&types.DataAvailabilityHeader{
   177  					RowsRoots:   rowRoots,
   178  					ColumnRoots: colRoots,
   179  				},
   180  				dag,
   181  				rsmt2d.NewRSGF8Codec(),
   182  			)
   183  
   184  			if tc.expectErr {
   185  				require.Error(t, err)
   186  				require.Contains(t, err.Error(), tc.errStr)
   187  				return
   188  			}
   189  			require.NoError(t, err)
   190  
   191  			nsShares, _ := rblockData.ComputeShares()
   192  			assert.Equal(t, rawData, nsShares.RawShares())
   193  		})
   194  	}
   195  }
   196  
   197  func flatten(eds *rsmt2d.ExtendedDataSquare) [][]byte {
   198  	flattenedEDSSize := eds.Width() * eds.Width()
   199  	out := make([][]byte, flattenedEDSSize)
   200  	count := 0
   201  	for i := uint(0); i < eds.Width(); i++ {
   202  		for _, share := range eds.Row(i) {
   203  			out[count] = share
   204  			count++
   205  		}
   206  	}
   207  	return out
   208  }
   209  
   210  // getNmtRoot generates the nmt root of some namespaced data
   211  func getNmtRoot(
   212  	ctx context.Context,
   213  	dag format.NodeAdder,
   214  	namespacedData [][]byte,
   215  ) (namespace.IntervalDigest, error) {
   216  	na := NewNmtNodeAdder(ctx, format.NewBatch(ctx, dag))
   217  	tree := nmt.New(sha256.New, nmt.NamespaceIDSize(consts.NamespaceSize), nmt.NodeVisitor(na.Visit))
   218  	for _, leaf := range namespacedData {
   219  		err := tree.Push(leaf)
   220  		if err != nil {
   221  			return namespace.IntervalDigest{}, err
   222  		}
   223  	}
   224  
   225  	return tree.Root(), na.Commit()
   226  }
   227  
   228  // this code is copy pasted from the plugin, and should likely be exported in the plugin instead
   229  func generateRandNamespacedRawData(total int, nidSize int, leafSize int) [][]byte {
   230  	data := make([][]byte, total)
   231  	for i := 0; i < total; i++ {
   232  		nid := make([]byte, nidSize)
   233  		_, err := rand.Read(nid)
   234  		if err != nil {
   235  			panic(err)
   236  		}
   237  		data[i] = nid
   238  	}
   239  
   240  	sortByteArrays(data)
   241  	for i := 0; i < total; i++ {
   242  		d := make([]byte, leafSize)
   243  		_, err := rand.Read(d)
   244  		if err != nil {
   245  			panic(err)
   246  		}
   247  		data[i] = append(data[i], d...)
   248  	}
   249  
   250  	return data
   251  }
   252  
   253  func sortByteArrays(src [][]byte) {
   254  	sort.Slice(src, func(i, j int) bool { return bytes.Compare(src[i], src[j]) < 0 })
   255  }
   256  
   257  // removes d shares from data
   258  func removeRandShares(data [][]byte, d int) [][]byte {
   259  	count := len(data)
   260  	// remove shares randomly
   261  	for i := 0; i < d; {
   262  		ind := rand.Intn(count)
   263  		if len(data[ind]) == 0 {
   264  			continue
   265  		}
   266  		data[ind] = nil
   267  		i++
   268  	}
   269  	return data
   270  }
   271  
   272  func rootsToDigests(roots [][]byte) []namespace.IntervalDigest {
   273  	out := make([]namespace.IntervalDigest, len(roots))
   274  	for i, root := range roots {
   275  		idigest, err := namespace.IntervalDigestFromBytes(consts.NamespaceSize, root)
   276  		if err != nil {
   277  			panic(err)
   278  		}
   279  		out[i] = idigest
   280  	}
   281  	return out
   282  }
   283  
   284  func generateRandomBlockData(msgCount, msgSize int) types.Data {
   285  	var out types.Data
   286  	if msgCount == 1 {
   287  		return out
   288  	}
   289  	out.Messages = generateRandomMessages(msgCount-1, msgSize)
   290  	out.Txs = generateRandomContiguousShares(1)
   291  	return out
   292  }
   293  
   294  func generateRandomMessages(count, msgSize int) types.Messages {
   295  	shares := generateRandNamespacedRawData(count, consts.NamespaceSize, msgSize)
   296  	msgs := make([]types.Message, count)
   297  	for i, s := range shares {
   298  		msgs[i] = types.Message{
   299  			Data:        s[consts.NamespaceSize:],
   300  			NamespaceID: s[:consts.NamespaceSize],
   301  		}
   302  	}
   303  	return types.Messages{MessagesList: msgs}
   304  }
   305  
   306  func generateRandomContiguousShares(count int) types.Txs {
   307  	// the size of a length delimited tx that takes up an entire share
   308  	const adjustedTxSize = consts.TxShareSize - 2
   309  	txs := make(types.Txs, count)
   310  	for i := 0; i < count; i++ {
   311  		tx := make([]byte, adjustedTxSize)
   312  		_, err := rand.Read(tx)
   313  		if err != nil {
   314  			panic(err)
   315  		}
   316  		txs[i] = types.Tx(tx)
   317  	}
   318  	return txs
   319  }