github.com/sunrise-zone/sunrise-node@v0.13.1-sr2/share/eds/byzantine/bad_encoding_test.go (about) 1 package byzantine 2 3 import ( 4 "context" 5 "crypto/sha256" 6 "hash" 7 "testing" 8 "time" 9 10 core "github.com/cometbft/cometbft/types" 11 "github.com/ipfs/boxo/blockservice" 12 blocks "github.com/ipfs/go-block-format" 13 "github.com/ipfs/go-cid" 14 mhcore "github.com/multiformats/go-multihash/core" 15 "github.com/stretchr/testify/require" 16 17 "github.com/celestiaorg/nmt" 18 "github.com/celestiaorg/rsmt2d" 19 "github.com/sunrise-zone/sunrise-app/pkg/da" 20 "github.com/sunrise-zone/sunrise-app/test/util/malicious" 21 22 "github.com/sunrise-zone/sunrise-node/header" 23 "github.com/sunrise-zone/sunrise-node/share" 24 "github.com/sunrise-zone/sunrise-node/share/eds/edstest" 25 "github.com/sunrise-zone/sunrise-node/share/ipld" 26 "github.com/sunrise-zone/sunrise-node/share/sharetest" 27 ) 28 29 func TestBEFP_Validate(t *testing.T) { 30 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 31 defer t.Cleanup(cancel) 32 bServ := ipld.NewMemBlockservice() 33 34 square := edstest.RandByzantineEDS(t, 16) 35 dah, err := da.NewDataAvailabilityHeader(square) 36 require.NoError(t, err) 37 err = ipld.ImportEDS(ctx, square, bServ) 38 require.NoError(t, err) 39 40 var errRsmt2d *rsmt2d.ErrByzantineData 41 err = square.Repair(dah.RowRoots, dah.ColumnRoots) 42 require.ErrorAs(t, err, &errRsmt2d) 43 44 byzantine := NewErrByzantine(ctx, bServ, &dah, errRsmt2d) 45 var errByz *ErrByzantine 46 require.ErrorAs(t, byzantine, &errByz) 47 48 proof := CreateBadEncodingProof([]byte("hash"), 0, errByz) 49 befp, ok := proof.(*BadEncodingProof) 50 require.True(t, ok) 51 var test = []struct { 52 name string 53 prepareFn func() error 54 expectedResult func(error) 55 }{ 56 { 57 name: "valid BEFP", 58 prepareFn: func() error { 59 return proof.Validate(&header.ExtendedHeader{DAH: &dah}) 60 }, 61 expectedResult: func(err error) { 62 require.NoError(t, err) 63 }, 64 }, 65 { 66 name: "invalid BEFP for valid header", 67 prepareFn: func() error { 68 validSquare := edstest.RandEDS(t, 2) 69 validDah, err := da.NewDataAvailabilityHeader(validSquare) 70 require.NoError(t, err) 71 err = ipld.ImportEDS(ctx, validSquare, bServ) 72 require.NoError(t, err) 73 validShares := validSquare.Flattened() 74 errInvalidByz := NewErrByzantine(ctx, bServ, &validDah, 75 &rsmt2d.ErrByzantineData{ 76 Axis: rsmt2d.Row, 77 Index: 0, 78 Shares: validShares[0:4], 79 }, 80 ) 81 var errInvalid *ErrByzantine 82 require.ErrorAs(t, errInvalidByz, &errInvalid) 83 invalidBefp := CreateBadEncodingProof([]byte("hash"), 0, errInvalid) 84 return invalidBefp.Validate(&header.ExtendedHeader{DAH: &validDah}) 85 }, 86 expectedResult: func(err error) { 87 require.ErrorIs(t, err, errNMTTreeRootsMatch) 88 }, 89 }, 90 { 91 name: "incorrect share with Proof", 92 prepareFn: func() error { 93 // break the first shareWithProof to test negative case 94 sh := sharetest.RandShares(t, 2) 95 nmtProof := nmt.NewInclusionProof(0, 1, nil, false) 96 befp.Shares[0] = &ShareWithProof{sh[0], &nmtProof} 97 return proof.Validate(&header.ExtendedHeader{DAH: &dah}) 98 }, 99 expectedResult: func(err error) { 100 require.ErrorIs(t, err, errIncorrectShare) 101 }, 102 }, 103 { 104 name: "invalid amount of shares", 105 prepareFn: func() error { 106 befp.Shares = befp.Shares[0 : len(befp.Shares)/2] 107 return proof.Validate(&header.ExtendedHeader{DAH: &dah}) 108 }, 109 expectedResult: func(err error) { 110 require.ErrorIs(t, err, errIncorrectAmountOfShares) 111 }, 112 }, 113 { 114 name: "not enough shares to recompute the root", 115 prepareFn: func() error { 116 befp.Shares[0] = nil 117 return proof.Validate(&header.ExtendedHeader{DAH: &dah}) 118 }, 119 expectedResult: func(err error) { 120 require.ErrorIs(t, err, errIncorrectAmountOfShares) 121 }, 122 }, 123 { 124 name: "index out of bounds", 125 prepareFn: func() error { 126 befp.Index = 100 127 return proof.Validate(&header.ExtendedHeader{DAH: &dah}) 128 }, 129 expectedResult: func(err error) { 130 require.ErrorIs(t, err, errIncorrectIndex) 131 }, 132 }, 133 { 134 name: "heights mismatch", 135 prepareFn: func() error { 136 return proof.Validate(&header.ExtendedHeader{ 137 RawHeader: core.Header{ 138 Height: 42, 139 }, 140 DAH: &dah, 141 }) 142 }, 143 expectedResult: func(err error) { 144 require.ErrorIs(t, err, errHeightMismatch) 145 }, 146 }, 147 } 148 149 for _, tt := range test { 150 t.Run(tt.name, func(t *testing.T) { 151 err = tt.prepareFn() 152 tt.expectedResult(err) 153 }) 154 } 155 } 156 157 // TestIncorrectBadEncodingFraudProof asserts that BEFP is not generated for the correct data 158 func TestIncorrectBadEncodingFraudProof(t *testing.T) { 159 ctx, cancel := context.WithCancel(context.Background()) 160 defer cancel() 161 162 bServ := ipld.NewMemBlockservice() 163 164 squareSize := 8 165 shares := sharetest.RandShares(t, squareSize*squareSize) 166 167 eds, err := ipld.AddShares(ctx, shares, bServ) 168 require.NoError(t, err) 169 170 dah, err := share.NewRoot(eds) 171 require.NoError(t, err) 172 173 // get an arbitrary row 174 row := uint(squareSize / 2) 175 rowShares := eds.Row(row) 176 rowRoot := dah.RowRoots[row] 177 178 shareProofs, err := GetProofsForShares(ctx, bServ, ipld.MustCidFromNamespacedSha256(rowRoot), rowShares) 179 require.NoError(t, err) 180 181 // create a fake error for data that was encoded correctly 182 fakeError := ErrByzantine{ 183 Index: uint32(row), 184 Shares: shareProofs, 185 Axis: rsmt2d.Row, 186 } 187 188 h := &header.ExtendedHeader{ 189 RawHeader: core.Header{ 190 Height: 420, 191 }, 192 DAH: dah, 193 Commit: &core.Commit{ 194 BlockID: core.BlockID{ 195 Hash: []byte("made up hash"), 196 }, 197 }, 198 } 199 200 proof := CreateBadEncodingProof(h.Hash(), h.Height(), &fakeError) 201 err = proof.Validate(h) 202 require.Error(t, err) 203 } 204 205 func TestBEFP_ValidateOutOfOrderShares(t *testing.T) { 206 ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) 207 t.Cleanup(cancel) 208 209 size := 4 210 eds := edstest.RandEDS(t, size) 211 212 shares := eds.Flattened() 213 shares[0], shares[4] = shares[4], shares[0] // corrupting eds 214 215 bServ := newNamespacedBlockService() 216 batchAddr := ipld.NewNmtNodeAdder(ctx, bServ, ipld.MaxSizeBatchOption(size*2)) 217 218 eds, err := rsmt2d.ImportExtendedDataSquare(shares, 219 share.DefaultRSMT2DCodec(), 220 malicious.NewConstructor(uint64(size), nmt.NodeVisitor(batchAddr.Visit)), 221 ) 222 require.NoError(t, err, "failure to recompute the extended data square") 223 224 err = batchAddr.Commit() 225 require.NoError(t, err) 226 227 dah, err := da.NewDataAvailabilityHeader(eds) 228 require.NoError(t, err) 229 230 var errRsmt2d *rsmt2d.ErrByzantineData 231 err = eds.Repair(dah.RowRoots, dah.ColumnRoots) 232 require.ErrorAs(t, err, &errRsmt2d) 233 234 byzantine := NewErrByzantine(ctx, bServ, &dah, errRsmt2d) 235 var errByz *ErrByzantine 236 require.ErrorAs(t, byzantine, &errByz) 237 238 befp := CreateBadEncodingProof([]byte("hash"), 0, errByz) 239 err = befp.Validate(&header.ExtendedHeader{DAH: &dah}) 240 require.NoError(t, err) 241 } 242 243 // namespacedBlockService wraps `BlockService` and extends the verification part 244 // to avoid returning blocks that has out of order namespaces. 245 type namespacedBlockService struct { 246 blockservice.BlockService 247 // the data structure that is used on the networking level, in order 248 // to verify the order of the namespaces 249 prefix *cid.Prefix 250 } 251 252 func newNamespacedBlockService() *namespacedBlockService { 253 sha256NamespaceFlagged := uint64(0x7701) 254 // register the nmt hasher to validate the order of namespaces 255 mhcore.Register(sha256NamespaceFlagged, func() hash.Hash { 256 nh := nmt.NewNmtHasher(sha256.New(), share.NamespaceSize, true) 257 nh.Reset() 258 return nh 259 }) 260 261 bs := &namespacedBlockService{} 262 bs.BlockService = ipld.NewMemBlockservice() 263 264 bs.prefix = &cid.Prefix{ 265 Version: 1, 266 Codec: sha256NamespaceFlagged, 267 MhType: sha256NamespaceFlagged, 268 // equals to NmtHasher.Size() 269 MhLength: sha256.New().Size() + 2*share.NamespaceSize, 270 } 271 return bs 272 } 273 274 func (n *namespacedBlockService) GetBlock(ctx context.Context, c cid.Cid) (blocks.Block, error) { 275 block, err := n.BlockService.GetBlock(ctx, c) 276 if err != nil { 277 return nil, err 278 } 279 280 _, err = n.prefix.Sum(block.RawData()) 281 if err != nil { 282 return nil, err 283 } 284 return block, nil 285 } 286 287 func (n *namespacedBlockService) GetBlocks(ctx context.Context, cids []cid.Cid) <-chan blocks.Block { 288 blockCh := n.BlockService.GetBlocks(ctx, cids) 289 resultCh := make(chan blocks.Block) 290 291 go func() { 292 for { 293 select { 294 case <-ctx.Done(): 295 close(resultCh) 296 return 297 case block, ok := <-blockCh: 298 if !ok { 299 close(resultCh) 300 return 301 } 302 if _, err := n.prefix.Sum(block.RawData()); err != nil { 303 continue 304 } 305 resultCh <- block 306 } 307 } 308 }() 309 return resultCh 310 }