github.com/sunrise-zone/sunrise-node@v0.13.1-sr2/share/eds/byzantine/bad_encoding_test.go (about)

     1  package byzantine
     2  
     3  import (
     4  	"context"
     5  	"crypto/sha256"
     6  	"hash"
     7  	"testing"
     8  	"time"
     9  
    10  	core "github.com/cometbft/cometbft/types"
    11  	"github.com/ipfs/boxo/blockservice"
    12  	blocks "github.com/ipfs/go-block-format"
    13  	"github.com/ipfs/go-cid"
    14  	mhcore "github.com/multiformats/go-multihash/core"
    15  	"github.com/stretchr/testify/require"
    16  
    17  	"github.com/celestiaorg/nmt"
    18  	"github.com/celestiaorg/rsmt2d"
    19  	"github.com/sunrise-zone/sunrise-app/pkg/da"
    20  	"github.com/sunrise-zone/sunrise-app/test/util/malicious"
    21  
    22  	"github.com/sunrise-zone/sunrise-node/header"
    23  	"github.com/sunrise-zone/sunrise-node/share"
    24  	"github.com/sunrise-zone/sunrise-node/share/eds/edstest"
    25  	"github.com/sunrise-zone/sunrise-node/share/ipld"
    26  	"github.com/sunrise-zone/sunrise-node/share/sharetest"
    27  )
    28  
    29  func TestBEFP_Validate(t *testing.T) {
    30  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
    31  	defer t.Cleanup(cancel)
    32  	bServ := ipld.NewMemBlockservice()
    33  
    34  	square := edstest.RandByzantineEDS(t, 16)
    35  	dah, err := da.NewDataAvailabilityHeader(square)
    36  	require.NoError(t, err)
    37  	err = ipld.ImportEDS(ctx, square, bServ)
    38  	require.NoError(t, err)
    39  
    40  	var errRsmt2d *rsmt2d.ErrByzantineData
    41  	err = square.Repair(dah.RowRoots, dah.ColumnRoots)
    42  	require.ErrorAs(t, err, &errRsmt2d)
    43  
    44  	byzantine := NewErrByzantine(ctx, bServ, &dah, errRsmt2d)
    45  	var errByz *ErrByzantine
    46  	require.ErrorAs(t, byzantine, &errByz)
    47  
    48  	proof := CreateBadEncodingProof([]byte("hash"), 0, errByz)
    49  	befp, ok := proof.(*BadEncodingProof)
    50  	require.True(t, ok)
    51  	var test = []struct {
    52  		name           string
    53  		prepareFn      func() error
    54  		expectedResult func(error)
    55  	}{
    56  		{
    57  			name: "valid BEFP",
    58  			prepareFn: func() error {
    59  				return proof.Validate(&header.ExtendedHeader{DAH: &dah})
    60  			},
    61  			expectedResult: func(err error) {
    62  				require.NoError(t, err)
    63  			},
    64  		},
    65  		{
    66  			name: "invalid BEFP for valid header",
    67  			prepareFn: func() error {
    68  				validSquare := edstest.RandEDS(t, 2)
    69  				validDah, err := da.NewDataAvailabilityHeader(validSquare)
    70  				require.NoError(t, err)
    71  				err = ipld.ImportEDS(ctx, validSquare, bServ)
    72  				require.NoError(t, err)
    73  				validShares := validSquare.Flattened()
    74  				errInvalidByz := NewErrByzantine(ctx, bServ, &validDah,
    75  					&rsmt2d.ErrByzantineData{
    76  						Axis:   rsmt2d.Row,
    77  						Index:  0,
    78  						Shares: validShares[0:4],
    79  					},
    80  				)
    81  				var errInvalid *ErrByzantine
    82  				require.ErrorAs(t, errInvalidByz, &errInvalid)
    83  				invalidBefp := CreateBadEncodingProof([]byte("hash"), 0, errInvalid)
    84  				return invalidBefp.Validate(&header.ExtendedHeader{DAH: &validDah})
    85  			},
    86  			expectedResult: func(err error) {
    87  				require.ErrorIs(t, err, errNMTTreeRootsMatch)
    88  			},
    89  		},
    90  		{
    91  			name: "incorrect share with Proof",
    92  			prepareFn: func() error {
    93  				// break the first shareWithProof to test negative case
    94  				sh := sharetest.RandShares(t, 2)
    95  				nmtProof := nmt.NewInclusionProof(0, 1, nil, false)
    96  				befp.Shares[0] = &ShareWithProof{sh[0], &nmtProof}
    97  				return proof.Validate(&header.ExtendedHeader{DAH: &dah})
    98  			},
    99  			expectedResult: func(err error) {
   100  				require.ErrorIs(t, err, errIncorrectShare)
   101  			},
   102  		},
   103  		{
   104  			name: "invalid amount of shares",
   105  			prepareFn: func() error {
   106  				befp.Shares = befp.Shares[0 : len(befp.Shares)/2]
   107  				return proof.Validate(&header.ExtendedHeader{DAH: &dah})
   108  			},
   109  			expectedResult: func(err error) {
   110  				require.ErrorIs(t, err, errIncorrectAmountOfShares)
   111  			},
   112  		},
   113  		{
   114  			name: "not enough shares to recompute the root",
   115  			prepareFn: func() error {
   116  				befp.Shares[0] = nil
   117  				return proof.Validate(&header.ExtendedHeader{DAH: &dah})
   118  			},
   119  			expectedResult: func(err error) {
   120  				require.ErrorIs(t, err, errIncorrectAmountOfShares)
   121  			},
   122  		},
   123  		{
   124  			name: "index out of bounds",
   125  			prepareFn: func() error {
   126  				befp.Index = 100
   127  				return proof.Validate(&header.ExtendedHeader{DAH: &dah})
   128  			},
   129  			expectedResult: func(err error) {
   130  				require.ErrorIs(t, err, errIncorrectIndex)
   131  			},
   132  		},
   133  		{
   134  			name: "heights mismatch",
   135  			prepareFn: func() error {
   136  				return proof.Validate(&header.ExtendedHeader{
   137  					RawHeader: core.Header{
   138  						Height: 42,
   139  					},
   140  					DAH: &dah,
   141  				})
   142  			},
   143  			expectedResult: func(err error) {
   144  				require.ErrorIs(t, err, errHeightMismatch)
   145  			},
   146  		},
   147  	}
   148  
   149  	for _, tt := range test {
   150  		t.Run(tt.name, func(t *testing.T) {
   151  			err = tt.prepareFn()
   152  			tt.expectedResult(err)
   153  		})
   154  	}
   155  }
   156  
   157  // TestIncorrectBadEncodingFraudProof asserts that BEFP is not generated for the correct data
   158  func TestIncorrectBadEncodingFraudProof(t *testing.T) {
   159  	ctx, cancel := context.WithCancel(context.Background())
   160  	defer cancel()
   161  
   162  	bServ := ipld.NewMemBlockservice()
   163  
   164  	squareSize := 8
   165  	shares := sharetest.RandShares(t, squareSize*squareSize)
   166  
   167  	eds, err := ipld.AddShares(ctx, shares, bServ)
   168  	require.NoError(t, err)
   169  
   170  	dah, err := share.NewRoot(eds)
   171  	require.NoError(t, err)
   172  
   173  	// get an arbitrary row
   174  	row := uint(squareSize / 2)
   175  	rowShares := eds.Row(row)
   176  	rowRoot := dah.RowRoots[row]
   177  
   178  	shareProofs, err := GetProofsForShares(ctx, bServ, ipld.MustCidFromNamespacedSha256(rowRoot), rowShares)
   179  	require.NoError(t, err)
   180  
   181  	// create a fake error for data that was encoded correctly
   182  	fakeError := ErrByzantine{
   183  		Index:  uint32(row),
   184  		Shares: shareProofs,
   185  		Axis:   rsmt2d.Row,
   186  	}
   187  
   188  	h := &header.ExtendedHeader{
   189  		RawHeader: core.Header{
   190  			Height: 420,
   191  		},
   192  		DAH: dah,
   193  		Commit: &core.Commit{
   194  			BlockID: core.BlockID{
   195  				Hash: []byte("made up hash"),
   196  			},
   197  		},
   198  	}
   199  
   200  	proof := CreateBadEncodingProof(h.Hash(), h.Height(), &fakeError)
   201  	err = proof.Validate(h)
   202  	require.Error(t, err)
   203  }
   204  
   205  func TestBEFP_ValidateOutOfOrderShares(t *testing.T) {
   206  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
   207  	t.Cleanup(cancel)
   208  
   209  	size := 4
   210  	eds := edstest.RandEDS(t, size)
   211  
   212  	shares := eds.Flattened()
   213  	shares[0], shares[4] = shares[4], shares[0] // corrupting eds
   214  
   215  	bServ := newNamespacedBlockService()
   216  	batchAddr := ipld.NewNmtNodeAdder(ctx, bServ, ipld.MaxSizeBatchOption(size*2))
   217  
   218  	eds, err := rsmt2d.ImportExtendedDataSquare(shares,
   219  		share.DefaultRSMT2DCodec(),
   220  		malicious.NewConstructor(uint64(size), nmt.NodeVisitor(batchAddr.Visit)),
   221  	)
   222  	require.NoError(t, err, "failure to recompute the extended data square")
   223  
   224  	err = batchAddr.Commit()
   225  	require.NoError(t, err)
   226  
   227  	dah, err := da.NewDataAvailabilityHeader(eds)
   228  	require.NoError(t, err)
   229  
   230  	var errRsmt2d *rsmt2d.ErrByzantineData
   231  	err = eds.Repair(dah.RowRoots, dah.ColumnRoots)
   232  	require.ErrorAs(t, err, &errRsmt2d)
   233  
   234  	byzantine := NewErrByzantine(ctx, bServ, &dah, errRsmt2d)
   235  	var errByz *ErrByzantine
   236  	require.ErrorAs(t, byzantine, &errByz)
   237  
   238  	befp := CreateBadEncodingProof([]byte("hash"), 0, errByz)
   239  	err = befp.Validate(&header.ExtendedHeader{DAH: &dah})
   240  	require.NoError(t, err)
   241  }
   242  
   243  // namespacedBlockService wraps `BlockService` and extends the verification part
   244  // to avoid returning blocks that has out of order namespaces.
   245  type namespacedBlockService struct {
   246  	blockservice.BlockService
   247  	// the data structure that is used on the networking level, in order
   248  	// to verify the order of the namespaces
   249  	prefix *cid.Prefix
   250  }
   251  
   252  func newNamespacedBlockService() *namespacedBlockService {
   253  	sha256NamespaceFlagged := uint64(0x7701)
   254  	// register the nmt hasher to validate the order of namespaces
   255  	mhcore.Register(sha256NamespaceFlagged, func() hash.Hash {
   256  		nh := nmt.NewNmtHasher(sha256.New(), share.NamespaceSize, true)
   257  		nh.Reset()
   258  		return nh
   259  	})
   260  
   261  	bs := &namespacedBlockService{}
   262  	bs.BlockService = ipld.NewMemBlockservice()
   263  
   264  	bs.prefix = &cid.Prefix{
   265  		Version: 1,
   266  		Codec:   sha256NamespaceFlagged,
   267  		MhType:  sha256NamespaceFlagged,
   268  		// equals to NmtHasher.Size()
   269  		MhLength: sha256.New().Size() + 2*share.NamespaceSize,
   270  	}
   271  	return bs
   272  }
   273  
   274  func (n *namespacedBlockService) GetBlock(ctx context.Context, c cid.Cid) (blocks.Block, error) {
   275  	block, err := n.BlockService.GetBlock(ctx, c)
   276  	if err != nil {
   277  		return nil, err
   278  	}
   279  
   280  	_, err = n.prefix.Sum(block.RawData())
   281  	if err != nil {
   282  		return nil, err
   283  	}
   284  	return block, nil
   285  }
   286  
   287  func (n *namespacedBlockService) GetBlocks(ctx context.Context, cids []cid.Cid) <-chan blocks.Block {
   288  	blockCh := n.BlockService.GetBlocks(ctx, cids)
   289  	resultCh := make(chan blocks.Block)
   290  
   291  	go func() {
   292  		for {
   293  			select {
   294  			case <-ctx.Done():
   295  				close(resultCh)
   296  				return
   297  			case block, ok := <-blockCh:
   298  				if !ok {
   299  					close(resultCh)
   300  					return
   301  				}
   302  				if _, err := n.prefix.Sum(block.RawData()); err != nil {
   303  					continue
   304  				}
   305  				resultCh <- block
   306  			}
   307  		}
   308  	}()
   309  	return resultCh
   310  }