github.com/sunrise-zone/sunrise-node@v0.13.1-sr2/share/ipld/nmt_adder.go (about)

     1  package ipld
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"github.com/ipfs/boxo/blockservice"
     9  	"github.com/ipfs/boxo/ipld/merkledag"
    10  	"github.com/ipfs/go-cid"
    11  	ipld "github.com/ipfs/go-ipld-format"
    12  
    13  	"github.com/celestiaorg/nmt"
    14  )
    15  
    16  type ctxKey int
    17  
    18  const (
    19  	proofsAdderKey ctxKey = iota
    20  )
    21  
    22  // NmtNodeAdder adds ipld.Nodes to the underlying ipld.Batch if it is inserted
    23  // into a nmt tree.
    24  type NmtNodeAdder struct {
    25  	// lock protects Batch, Set and error from parallel writes / reads
    26  	lock   sync.Mutex
    27  	ctx    context.Context
    28  	add    *ipld.Batch
    29  	leaves *cid.Set
    30  	err    error
    31  }
    32  
    33  // NewNmtNodeAdder returns a new NmtNodeAdder with the provided context and
    34  // batch. Note that the context provided should have a timeout
    35  // It is not thread-safe.
    36  func NewNmtNodeAdder(ctx context.Context, bs blockservice.BlockService, opts ...ipld.BatchOption) *NmtNodeAdder {
    37  	return &NmtNodeAdder{
    38  		add:    ipld.NewBatch(ctx, merkledag.NewDAGService(bs), opts...),
    39  		ctx:    ctx,
    40  		leaves: cid.NewSet(),
    41  	}
    42  }
    43  
    44  // Visit is a NodeVisitor that can be used during the creation of a new NMT to
    45  // create and add ipld.Nodes to the Batch while computing the root of the NMT.
    46  func (n *NmtNodeAdder) Visit(hash []byte, children ...[]byte) {
    47  	n.lock.Lock()
    48  	defer n.lock.Unlock()
    49  
    50  	if n.err != nil {
    51  		return // protect from further visits if there is an error
    52  	}
    53  	id := MustCidFromNamespacedSha256(hash)
    54  	switch len(children) {
    55  	case 1:
    56  		if n.leaves.Visit(id) {
    57  			n.err = n.add.Add(n.ctx, newNMTNode(id, children[0]))
    58  		}
    59  	case 2:
    60  		n.err = n.add.Add(n.ctx, newNMTNode(id, append(children[0], children[1]...)))
    61  	default:
    62  		panic("expected a binary tree")
    63  	}
    64  }
    65  
    66  // Commit checks for errors happened during Visit and if absent commits data to inner Batch.
    67  func (n *NmtNodeAdder) Commit() error {
    68  	n.lock.Lock()
    69  	defer n.lock.Unlock()
    70  
    71  	if n.err != nil {
    72  		return fmt.Errorf("before batch commit: %w", n.err)
    73  	}
    74  
    75  	n.err = n.add.Commit()
    76  	if n.err != nil {
    77  		return fmt.Errorf("after batch commit: %w", n.err)
    78  	}
    79  	return nil
    80  }
    81  
    82  // MaxSizeBatchOption sets the maximum amount of buffered data before writing
    83  // blocks.
    84  func MaxSizeBatchOption(size int) ipld.BatchOption {
    85  	return ipld.MaxSizeBatchOption(BatchSize(size))
    86  }
    87  
    88  // BatchSize calculates the amount of nodes that are generated from block of 'squareSizes'
    89  // to be batched in one write.
    90  func BatchSize(squareSize int) int {
    91  	// (squareSize*2-1) - amount of nodes in a generated binary tree
    92  	// squareSize*2 - the total number of trees, both for rows and cols
    93  	// (squareSize*squareSize) - all the shares
    94  	//
    95  	// Note that while our IPLD tree looks like this:
    96  	// ---X
    97  	// -X---X
    98  	// X-X-X-X
    99  	// here we count leaves only once: the CIDs are the same for columns and rows
   100  	// and for the last two layers as well:
   101  	return (squareSize*2-1)*squareSize*2 - (squareSize * squareSize)
   102  }
   103  
   104  // ProofsAdder is used to collect proof nodes, while traversing merkle tree
   105  type ProofsAdder struct {
   106  	lock   sync.RWMutex
   107  	proofs map[cid.Cid][]byte
   108  }
   109  
   110  // NewProofsAdder creates new instance of ProofsAdder.
   111  func NewProofsAdder(squareSize int) *ProofsAdder {
   112  	return &ProofsAdder{
   113  		// preallocate map to fit all inner nodes for given square size
   114  		proofs: make(map[cid.Cid][]byte, innerNodesAmount(squareSize)),
   115  	}
   116  }
   117  
   118  // CtxWithProofsAdder creates context, that will contain ProofsAdder. If context is leaked to
   119  // another go-routine, proofs will be not collected by gc. To prevent it, use Purge after Proofs
   120  // are collected from adder, to preemptively release memory allocated for proofs.
   121  func CtxWithProofsAdder(ctx context.Context, adder *ProofsAdder) context.Context {
   122  	return context.WithValue(ctx, proofsAdderKey, adder)
   123  }
   124  
   125  // ProofsAdderFromCtx extracts ProofsAdder from context
   126  func ProofsAdderFromCtx(ctx context.Context) *ProofsAdder {
   127  	val := ctx.Value(proofsAdderKey)
   128  	adder, ok := val.(*ProofsAdder)
   129  	if !ok || adder == nil {
   130  		return nil
   131  	}
   132  	return adder
   133  }
   134  
   135  // Proofs returns proofs collected by ProofsAdder
   136  func (a *ProofsAdder) Proofs() map[cid.Cid][]byte {
   137  	if a == nil {
   138  		return nil
   139  	}
   140  
   141  	a.lock.RLock()
   142  	defer a.lock.RUnlock()
   143  	return a.proofs
   144  }
   145  
   146  // VisitFn returns NodeVisitorFn, that will collect proof nodes while traversing merkle tree.
   147  func (a *ProofsAdder) VisitFn() nmt.NodeVisitorFn {
   148  	if a == nil {
   149  		return nil
   150  	}
   151  
   152  	a.lock.RLock()
   153  	defer a.lock.RUnlock()
   154  
   155  	// proofs are already collected, don't collect second time
   156  	if len(a.proofs) > 0 {
   157  		return nil
   158  	}
   159  	return a.visitInnerNodes
   160  }
   161  
   162  // Purge removed proofs from ProofsAdder allowing GC to collect the memory
   163  func (a *ProofsAdder) Purge() {
   164  	if a == nil {
   165  		return
   166  	}
   167  
   168  	a.lock.Lock()
   169  	defer a.lock.Unlock()
   170  
   171  	a.proofs = nil
   172  }
   173  
   174  func (a *ProofsAdder) visitInnerNodes(hash []byte, children ...[]byte) {
   175  	switch len(children) {
   176  	case 1:
   177  		break
   178  	case 2:
   179  		id := MustCidFromNamespacedSha256(hash)
   180  		a.addProof(id, append(children[0], children[1]...))
   181  	default:
   182  		panic("expected a binary tree")
   183  	}
   184  }
   185  
   186  func (a *ProofsAdder) addProof(id cid.Cid, proof []byte) {
   187  	a.lock.Lock()
   188  	defer a.lock.Unlock()
   189  	a.proofs[id] = proof
   190  }
   191  
   192  // innerNodesAmount return amount of inner nodes in eds with given size
   193  func innerNodesAmount(squareSize int) int {
   194  	return 2 * (squareSize - 1) * squareSize
   195  }