github.com/sunrise-zone/sunrise-node@v0.13.1-sr2/share/ipld/get_shares_test.go (about)

     1  package ipld
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/sha256"
     7  	"errors"
     8  	mrand "math/rand"
     9  	"sort"
    10  	"strconv"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/ipfs/boxo/blockservice"
    15  	"github.com/ipfs/go-cid"
    16  	"github.com/stretchr/testify/assert"
    17  	"github.com/stretchr/testify/require"
    18  
    19  	"github.com/celestiaorg/rsmt2d"
    20  	"github.com/sunrise-zone/sunrise-app/pkg/wrapper"
    21  
    22  	"github.com/sunrise-zone/sunrise-node/libs/utils"
    23  	"github.com/sunrise-zone/sunrise-node/share"
    24  	"github.com/sunrise-zone/sunrise-node/share/eds/edstest"
    25  	"github.com/sunrise-zone/sunrise-node/share/sharetest"
    26  )
    27  
    28  func TestGetShare(t *testing.T) {
    29  	const size = 8
    30  
    31  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    32  	defer cancel()
    33  	bServ := NewMemBlockservice()
    34  
    35  	// generate random shares for the nmt
    36  	shares := sharetest.RandShares(t, size*size)
    37  	eds, err := AddShares(ctx, shares, bServ)
    38  	require.NoError(t, err)
    39  
    40  	for i, leaf := range shares {
    41  		row := i / size
    42  		pos := i - (size * row)
    43  		rowRoots, err := eds.RowRoots()
    44  		require.NoError(t, err)
    45  		share, err := GetShare(ctx, bServ, MustCidFromNamespacedSha256(rowRoots[row]), pos, size*2)
    46  		require.NoError(t, err)
    47  		assert.Equal(t, leaf, share)
    48  	}
    49  }
    50  
    51  func TestBlockRecovery(t *testing.T) {
    52  	originalSquareWidth := 8
    53  	shareCount := originalSquareWidth * originalSquareWidth
    54  	extendedSquareWidth := 2 * originalSquareWidth
    55  	extendedShareCount := extendedSquareWidth * extendedSquareWidth
    56  
    57  	// generate test data
    58  	quarterShares := sharetest.RandShares(t, shareCount)
    59  	allShares := sharetest.RandShares(t, shareCount)
    60  
    61  	testCases := []struct {
    62  		name      string
    63  		shares    []share.Share
    64  		expectErr bool
    65  		errString string
    66  		d         int // number of shares to delete
    67  	}{
    68  		{"missing 1/2 shares", quarterShares, false, "", extendedShareCount / 2},
    69  		{"missing 1/4 shares", quarterShares, false, "", extendedShareCount / 4},
    70  		{"max missing data", quarterShares, false, "", (originalSquareWidth + 1) * (originalSquareWidth + 1)},
    71  		{"missing all but one shares", allShares, true, "failed to solve data square", extendedShareCount - 1},
    72  	}
    73  	for _, tc := range testCases {
    74  		tc := tc
    75  
    76  		t.Run(tc.name, func(t *testing.T) {
    77  			squareSize := utils.SquareSize(len(tc.shares))
    78  
    79  			testEds, err := rsmt2d.ComputeExtendedDataSquare(
    80  				tc.shares,
    81  				share.DefaultRSMT2DCodec(),
    82  				wrapper.NewConstructor(squareSize),
    83  			)
    84  			require.NoError(t, err)
    85  
    86  			// calculate roots using the first complete square
    87  			rowRoots, err := testEds.RowRoots()
    88  			require.NoError(t, err)
    89  			colRoots, err := testEds.ColRoots()
    90  			require.NoError(t, err)
    91  
    92  			flat := testEds.Flattened()
    93  
    94  			// recover a partially complete square
    95  			rdata := removeRandShares(flat, tc.d)
    96  			testEds, err = rsmt2d.ImportExtendedDataSquare(
    97  				rdata,
    98  				share.DefaultRSMT2DCodec(),
    99  				wrapper.NewConstructor(squareSize),
   100  			)
   101  			require.NoError(t, err)
   102  
   103  			err = testEds.Repair(rowRoots, colRoots)
   104  			if tc.expectErr {
   105  				require.Error(t, err)
   106  				require.Contains(t, err.Error(), tc.errString)
   107  				return
   108  			}
   109  			assert.NoError(t, err)
   110  
   111  			reds, err := rsmt2d.ImportExtendedDataSquare(rdata, share.DefaultRSMT2DCodec(), wrapper.NewConstructor(squareSize))
   112  			require.NoError(t, err)
   113  			// check that the squares are equal
   114  			assert.Equal(t, testEds.Flattened(), reds.Flattened())
   115  		})
   116  	}
   117  }
   118  
   119  func Test_ConvertEDStoShares(t *testing.T) {
   120  	squareWidth := 16
   121  	shares := sharetest.RandShares(t, squareWidth*squareWidth)
   122  
   123  	// compute extended square
   124  	testEds, err := rsmt2d.ComputeExtendedDataSquare(
   125  		shares,
   126  		share.DefaultRSMT2DCodec(),
   127  		wrapper.NewConstructor(uint64(squareWidth)),
   128  	)
   129  	require.NoError(t, err)
   130  
   131  	resshares := testEds.FlattenedODS()
   132  	require.Equal(t, shares, resshares)
   133  }
   134  
   135  // removes d shares from data
   136  func removeRandShares(data [][]byte, d int) [][]byte {
   137  	count := len(data)
   138  	// remove shares randomly
   139  	for i := 0; i < d; {
   140  		ind := mrand.Intn(count)
   141  		if len(data[ind]) == 0 {
   142  			continue
   143  		}
   144  		data[ind] = nil
   145  		i++
   146  	}
   147  	return data
   148  }
   149  
   150  func TestGetSharesByNamespace(t *testing.T) {
   151  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   152  	t.Cleanup(cancel)
   153  	bServ := NewMemBlockservice()
   154  
   155  	var tests = []struct {
   156  		rawData []share.Share
   157  	}{
   158  		{rawData: sharetest.RandShares(t, 4)},
   159  		{rawData: sharetest.RandShares(t, 16)},
   160  	}
   161  
   162  	for i, tt := range tests {
   163  		t.Run(strconv.Itoa(i), func(t *testing.T) {
   164  			// choose random namespace from rand shares
   165  			expected := tt.rawData[len(tt.rawData)/2]
   166  			namespace := share.GetNamespace(expected)
   167  
   168  			// change rawData to contain several shares with same namespace
   169  			tt.rawData[(len(tt.rawData)/2)+1] = expected
   170  			// put raw data in BlockService
   171  			eds, err := AddShares(ctx, tt.rawData, bServ)
   172  			require.NoError(t, err)
   173  
   174  			var shares []share.Share
   175  			rowRoots, err := eds.RowRoots()
   176  			require.NoError(t, err)
   177  			for _, row := range rowRoots {
   178  				rcid := MustCidFromNamespacedSha256(row)
   179  				rowShares, _, err := GetSharesByNamespace(ctx, bServ, rcid, namespace, len(rowRoots))
   180  				if errors.Is(err, ErrNamespaceOutsideRange) {
   181  					continue
   182  				}
   183  				require.NoError(t, err)
   184  
   185  				shares = append(shares, rowShares...)
   186  			}
   187  
   188  			assert.Equal(t, 2, len(shares))
   189  			for _, share := range shares {
   190  				assert.Equal(t, expected, share)
   191  			}
   192  		})
   193  	}
   194  }
   195  
   196  func TestCollectLeavesByNamespace_IncompleteData(t *testing.T) {
   197  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   198  	t.Cleanup(cancel)
   199  	bServ := NewMemBlockservice()
   200  
   201  	shares := sharetest.RandShares(t, 16)
   202  
   203  	// set all shares to the same namespace id
   204  	namespace := share.GetNamespace(shares[0])
   205  	for _, shr := range shares {
   206  		copy(share.GetNamespace(shr), namespace)
   207  	}
   208  
   209  	eds, err := AddShares(ctx, shares, bServ)
   210  	require.NoError(t, err)
   211  
   212  	roots, err := eds.RowRoots()
   213  	require.NoError(t, err)
   214  
   215  	// remove the second share from the first row
   216  	rcid := MustCidFromNamespacedSha256(roots[0])
   217  	node, err := GetNode(ctx, bServ, rcid)
   218  	require.NoError(t, err)
   219  
   220  	// Left side of the tree contains the original shares
   221  	data, err := GetNode(ctx, bServ, node.Links()[0].Cid)
   222  	require.NoError(t, err)
   223  
   224  	// Second share is the left side's right child
   225  	l, err := GetNode(ctx, bServ, data.Links()[0].Cid)
   226  	require.NoError(t, err)
   227  	r, err := GetNode(ctx, bServ, l.Links()[1].Cid)
   228  	require.NoError(t, err)
   229  	err = bServ.DeleteBlock(ctx, r.Cid())
   230  	require.NoError(t, err)
   231  
   232  	namespaceData := NewNamespaceData(len(shares), namespace, WithLeaves())
   233  	err = namespaceData.CollectLeavesByNamespace(ctx, bServ, rcid)
   234  	require.Error(t, err)
   235  	leaves := namespaceData.Leaves()
   236  	assert.Nil(t, leaves[1])
   237  	assert.Equal(t, 4, len(leaves))
   238  }
   239  
   240  func TestCollectLeavesByNamespace_AbsentNamespaceId(t *testing.T) {
   241  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   242  	t.Cleanup(cancel)
   243  	bServ := NewMemBlockservice()
   244  
   245  	shares := sharetest.RandShares(t, 1024)
   246  
   247  	// set all shares to the same namespace
   248  	namespaces, err := randomNamespaces(5)
   249  	require.NoError(t, err)
   250  	minNamespace := namespaces[0]
   251  	minIncluded := namespaces[1]
   252  	midNamespace := namespaces[2]
   253  	maxIncluded := namespaces[3]
   254  	maxNamespace := namespaces[4]
   255  
   256  	secondNamespaceFrom := mrand.Intn(len(shares)-2) + 1
   257  	for i, shr := range shares {
   258  		if i < secondNamespaceFrom {
   259  			copy(share.GetNamespace(shr), minIncluded)
   260  			continue
   261  		}
   262  		copy(share.GetNamespace(shr), maxIncluded)
   263  	}
   264  
   265  	var tests = []struct {
   266  		name             string
   267  		data             []share.Share
   268  		missingNamespace share.Namespace
   269  		isAbsence        bool
   270  	}{
   271  		{name: "Namespace less than the minimum namespace in data", data: shares, missingNamespace: minNamespace},
   272  		{name: "Namespace greater than the maximum namespace in data", data: shares, missingNamespace: maxNamespace},
   273  		{name: "Namespace in range but still missing", data: shares, missingNamespace: midNamespace, isAbsence: true},
   274  	}
   275  
   276  	for _, tt := range tests {
   277  		t.Run(tt.name, func(t *testing.T) {
   278  			eds, err := AddShares(ctx, shares, bServ)
   279  			require.NoError(t, err)
   280  			assertNoRowContainsNID(ctx, t, bServ, eds, tt.missingNamespace, tt.isAbsence)
   281  		})
   282  	}
   283  }
   284  
   285  func TestCollectLeavesByNamespace_MultipleRowsContainingSameNamespaceId(t *testing.T) {
   286  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   287  	t.Cleanup(cancel)
   288  	bServ := NewMemBlockservice()
   289  
   290  	shares := sharetest.RandShares(t, 16)
   291  
   292  	// set all shares to the same namespace and data but the last one
   293  	namespace := share.GetNamespace(shares[0])
   294  	commonNamespaceData := shares[0]
   295  
   296  	for i, nspace := range shares {
   297  		if i == len(shares)-1 {
   298  			break
   299  		}
   300  
   301  		copy(nspace, commonNamespaceData)
   302  	}
   303  
   304  	eds, err := AddShares(ctx, shares, bServ)
   305  	require.NoError(t, err)
   306  
   307  	rowRoots, err := eds.RowRoots()
   308  	require.NoError(t, err)
   309  
   310  	for _, row := range rowRoots {
   311  		rcid := MustCidFromNamespacedSha256(row)
   312  		data := NewNamespaceData(len(shares), namespace, WithLeaves())
   313  		err := data.CollectLeavesByNamespace(ctx, bServ, rcid)
   314  		if errors.Is(err, ErrNamespaceOutsideRange) {
   315  			continue
   316  		}
   317  		assert.Nil(t, err)
   318  		leaves := data.Leaves()
   319  		for _, node := range leaves {
   320  			// test that the data returned by collectLeavesByNamespace for nid
   321  			// matches the commonNamespaceData that was copied across almost all data
   322  			assert.Equal(t, commonNamespaceData, share.GetData(node.RawData()))
   323  		}
   324  	}
   325  }
   326  
   327  func TestGetSharesWithProofsByNamespace(t *testing.T) {
   328  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   329  	t.Cleanup(cancel)
   330  	bServ := NewMemBlockservice()
   331  
   332  	var tests = []struct {
   333  		rawData []share.Share
   334  	}{
   335  		{rawData: sharetest.RandShares(t, 4)},
   336  		{rawData: sharetest.RandShares(t, 16)},
   337  		{rawData: sharetest.RandShares(t, 64)},
   338  	}
   339  
   340  	for i, tt := range tests {
   341  		t.Run(strconv.Itoa(i), func(t *testing.T) {
   342  			rand := mrand.New(mrand.NewSource(time.Now().UnixNano()))
   343  			// choose random range in shares slice
   344  			from := rand.Intn(len(tt.rawData))
   345  			to := rand.Intn(len(tt.rawData))
   346  
   347  			if to < from {
   348  				from, to = to, from
   349  			}
   350  
   351  			expected := tt.rawData[from]
   352  			namespace := share.GetNamespace(expected)
   353  
   354  			// change rawData to contain several shares with same namespace
   355  			for i := from; i <= to; i++ {
   356  				tt.rawData[i] = expected
   357  			}
   358  
   359  			// put raw data in BlockService
   360  			eds, err := AddShares(ctx, tt.rawData, bServ)
   361  			require.NoError(t, err)
   362  
   363  			var shares []share.Share
   364  			rowRoots, err := eds.RowRoots()
   365  			require.NoError(t, err)
   366  			for _, row := range rowRoots {
   367  				rcid := MustCidFromNamespacedSha256(row)
   368  				rowShares, proof, err := GetSharesByNamespace(ctx, bServ, rcid, namespace, len(rowRoots))
   369  				if namespace.IsOutsideRange(row, row) {
   370  					require.ErrorIs(t, err, ErrNamespaceOutsideRange)
   371  					continue
   372  				}
   373  				require.NoError(t, err)
   374  				if len(rowShares) > 0 {
   375  					require.NotNil(t, proof)
   376  					// append shares to check integrity later
   377  					shares = append(shares, rowShares...)
   378  
   379  					// construct nodes from shares by prepending namespace
   380  					var leaves [][]byte
   381  					for _, shr := range rowShares {
   382  						leaves = append(leaves, append(share.GetNamespace(shr), shr...))
   383  					}
   384  
   385  					// verify namespace
   386  					verified := proof.VerifyNamespace(
   387  						sha256.New(),
   388  						namespace.ToNMT(),
   389  						leaves,
   390  						NamespacedSha256FromCID(rcid))
   391  					require.True(t, verified)
   392  
   393  					// verify inclusion
   394  					verified = proof.VerifyInclusion(
   395  						sha256.New(),
   396  						namespace.ToNMT(),
   397  						rowShares,
   398  						NamespacedSha256FromCID(rcid))
   399  					require.True(t, verified)
   400  				}
   401  			}
   402  
   403  			// validate shares
   404  			assert.Equal(t, to-from+1, len(shares))
   405  			for _, share := range shares {
   406  				assert.Equal(t, expected, share)
   407  			}
   408  		})
   409  	}
   410  }
   411  
   412  func TestBatchSize(t *testing.T) {
   413  	tests := []struct {
   414  		name      string
   415  		origWidth int
   416  	}{
   417  		{"2", 2},
   418  		{"4", 4},
   419  		{"8", 8},
   420  		{"16", 16},
   421  		{"32", 32},
   422  		{"64", 64},
   423  	}
   424  	for _, tt := range tests {
   425  		t.Run(tt.name, func(t *testing.T) {
   426  			ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(tt.origWidth))
   427  			defer cancel()
   428  
   429  			bs := NewMemBlockservice()
   430  
   431  			randEds := edstest.RandEDS(t, tt.origWidth)
   432  			_, err := AddShares(ctx, randEds.FlattenedODS(), bs)
   433  			require.NoError(t, err)
   434  
   435  			out, err := bs.Blockstore().AllKeysChan(ctx)
   436  			require.NoError(t, err)
   437  
   438  			var count int
   439  			for range out {
   440  				count++
   441  			}
   442  			extendedWidth := tt.origWidth * 2
   443  			assert.Equalf(t, count, BatchSize(extendedWidth), "batchSize(%v)", extendedWidth)
   444  		})
   445  	}
   446  }
   447  
   448  func assertNoRowContainsNID(
   449  	ctx context.Context,
   450  	t *testing.T,
   451  	bServ blockservice.BlockService,
   452  	eds *rsmt2d.ExtendedDataSquare,
   453  	namespace share.Namespace,
   454  	isAbsent bool,
   455  ) {
   456  	rowRoots, err := eds.RowRoots()
   457  	require.NoError(t, err)
   458  	rowRootCount := len(rowRoots)
   459  	// get all row root cids
   460  	rowRootCIDs := make([]cid.Cid, rowRootCount)
   461  	for i, rowRoot := range rowRoots {
   462  		rowRootCIDs[i] = MustCidFromNamespacedSha256(rowRoot)
   463  	}
   464  
   465  	// for each row root cid check if the min namespace exists
   466  	var absentCount, foundAbsenceRows int
   467  	for _, rowRoot := range rowRoots {
   468  		var outsideRange bool
   469  		if !namespace.IsOutsideRange(rowRoot, rowRoot) {
   470  			// namespace does belong to namespace range of the row
   471  			absentCount++
   472  		} else {
   473  			outsideRange = true
   474  		}
   475  		data := NewNamespaceData(rowRootCount, namespace, WithProofs())
   476  		rootCID := MustCidFromNamespacedSha256(rowRoot)
   477  		err := data.CollectLeavesByNamespace(ctx, bServ, rootCID)
   478  		if outsideRange {
   479  			require.ErrorIs(t, err, ErrNamespaceOutsideRange)
   480  			continue
   481  		}
   482  		require.NoError(t, err)
   483  
   484  		// if no error returned, check absence proof
   485  		foundAbsenceRows++
   486  		verified := data.Proof().VerifyNamespace(sha256.New(), namespace.ToNMT(), nil, rowRoot)
   487  		require.True(t, verified)
   488  	}
   489  
   490  	if isAbsent {
   491  		require.Equal(t, foundAbsenceRows, absentCount)
   492  		// there should be max 1 row that has namespace range containing namespace
   493  		require.LessOrEqual(t, absentCount, 1)
   494  	}
   495  }
   496  
   497  func randomNamespaces(total int) ([]share.Namespace, error) {
   498  	namespaces := make([]share.Namespace, total)
   499  	for i := range namespaces {
   500  		namespaces[i] = sharetest.RandV0Namespace()
   501  	}
   502  	sort.Slice(namespaces, func(i, j int) bool { return bytes.Compare(namespaces[i], namespaces[j]) < 0 })
   503  	return namespaces, nil
   504  }