github.com/ethersphere/bee/v2@v2.2.0/pkg/bmt/bmt.go

     1  // Copyright 2021 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package bmt
     6  
     7  import (
     8  	"encoding/binary"
     9  	"hash"
    10  
    11  	"github.com/ethersphere/bee/v2/pkg/swarm"
    12  )
    13  
    14  var _ Hash = (*Hasher)(nil)
    15  
    16  var (
    17  	zerospan    = make([]byte, 8)
    18  	zerosection = make([]byte, 64)
    19  )
    20  
// Hasher is a reusable hasher for fixed maximum size chunks representing a BMT.
//
// Data is added sequentially through Write. Every completely filled section
// (two segments) is hashed in its own goroutine, and Hash finishes the final
// section and collects the root. Each Hasher owns a single prebuilt tree (the
// bmt field), which is reused for every chunk it hashes to amortise memory
// allocation.
// NOTE(review): earlier docs mention a pool of trees and order-agnostic
// concurrent segment writes. Neither is implemented in this file; confirm
// against the rest of the package.
//
// The same hasher instance must not be called concurrently on more than one chunk.
//
// The same hasher instance is synchronously reusable. Hash and Sum do not
// rewind the write cursor, so call Reset before hashing a new chunk.
type Hasher struct {
	*Conf              // shared configuration: segmentSize, maxSize, depth, hasher factory, zerohashes
	bmt    *tree       // prebuilt BMT (data buffer, leaves and nodes) reused for every chunk
	size   int         // bytes written to Hasher since last Reset()
	pos    int         // index of the section Hash must finish as the final section
	offset int         // offset within the section at the start of the latest Write; not read in this file
	result chan []byte // unbuffered; receives the root hash from whichever goroutine completes the tree
	errc   chan error  // buffered (1); senders never block, so only the first unread error is kept
	span   []byte      // span header, hashed before the root in Hash (see SetHeader/SetHeaderInt64)
}
    42  
    43  // NewHasher gives back an instance of a Hasher struct
    44  func NewHasher(hasherFact func() hash.Hash) *Hasher {
    45  	conf := NewConf(hasherFact, swarm.BmtBranches, 32)
    46  
    47  	return &Hasher{
    48  		Conf:   conf,
    49  		result: make(chan []byte),
    50  		errc:   make(chan error, 1),
    51  		span:   make([]byte, SpanSize),
    52  		bmt:    newTree(conf.segmentSize, conf.maxSize, conf.depth, conf.hasher),
    53  	}
    54  }
    55  
    56  // Capacity returns the maximum amount of bytes that will be processed by this hasher implementation.
    57  // since BMT assumes a balanced binary tree, capacity it is always a power of 2
    58  func (h *Hasher) Capacity() int {
    59  	return h.maxSize
    60  }
    61  
    62  // LengthToSpan creates a binary data span size representation.
    63  // It is required for calculating the BMT hash.
    64  func LengthToSpan(length int64) []byte {
    65  	span := make([]byte, SpanSize)
    66  	binary.LittleEndian.PutUint64(span, uint64(length))
    67  	return span
    68  }
    69  
    70  // LengthFromSpan returns length from span.
    71  func LengthFromSpan(span []byte) uint64 {
    72  	return binary.LittleEndian.Uint64(span)
    73  }
    74  
    75  // SetHeaderInt64 sets the metadata preamble to the little endian binary representation of int64 argument for the current hash operation.
    76  func (h *Hasher) SetHeaderInt64(length int64) {
    77  	binary.LittleEndian.PutUint64(h.span, uint64(length))
    78  }
    79  
    80  // SetHeader sets the metadata preamble to the span bytes given argument for the current hash operation.
    81  func (h *Hasher) SetHeader(span []byte) {
    82  	copy(h.span, span)
    83  }
    84  
    85  // Size returns the digest size of the hash
    86  func (h *Hasher) Size() int {
    87  	return h.segmentSize
    88  }
    89  
    90  // BlockSize returns the optimal write size to the Hasher
    91  func (h *Hasher) BlockSize() int {
    92  	return 2 * h.segmentSize
    93  }
    94  
    95  // Hash returns the BMT root hash of the buffer and an error
    96  // using Hash presupposes sequential synchronous writes (io.Writer interface).
    97  func (h *Hasher) Hash(b []byte) ([]byte, error) {
    98  	if h.size == 0 {
    99  		return doHash(h.hasher(), h.span, h.zerohashes[h.depth])
   100  	}
   101  	copy(h.bmt.buffer[h.size:], zerosection)
   102  	// write the last section with final flag set to true
   103  	go h.processSection(h.pos, true)
   104  	select {
   105  	case result := <-h.result:
   106  		return doHash(h.hasher(), h.span, result)
   107  	case err := <-h.errc:
   108  		return nil, err
   109  	}
   110  }
   111  
   112  // Sum returns the BMT root hash of the buffer, unsafe version of Hash
   113  func (h *Hasher) Sum(b []byte) []byte {
   114  	s, _ := h.Hash(b)
   115  	return s
   116  }
   117  
   118  // Write calls sequentially add to the buffer to be hashed,
   119  // with every full segment calls processSection in a go routine.
   120  func (h *Hasher) Write(b []byte) (int, error) {
   121  	l := len(b)
   122  	max := h.maxSize - h.size
   123  	if l > max {
   124  		l = max
   125  	}
   126  	copy(h.bmt.buffer[h.size:], b)
   127  	secsize := 2 * h.segmentSize
   128  	from := h.size / secsize
   129  	h.offset = h.size % secsize
   130  	h.size += l
   131  	to := h.size / secsize
   132  	if l == max {
   133  		to--
   134  	}
   135  	h.pos = to
   136  	for i := from; i < to; i++ {
   137  		go h.processSection(i, false)
   138  	}
   139  	return l, nil
   140  }
   141  
   142  // Reset prepares the Hasher for reuse
   143  func (h *Hasher) Reset() {
   144  	h.pos = 0
   145  	h.size = 0
   146  	h.offset = 0
   147  	copy(h.span, zerospan)
   148  }
   149  
   150  // processSection writes the hash of i-th section into level 1 node of the BMT tree.
   151  func (h *Hasher) processSection(i int, final bool) {
   152  	secsize := 2 * h.segmentSize
   153  	offset := i * secsize
   154  	level := 1
   155  	// select the leaf node for the section
   156  	n := h.bmt.leaves[i]
   157  	isLeft := n.isLeft
   158  	hasher := n.hasher
   159  	n = n.parent
   160  	// hash the section
   161  	section, err := doHash(hasher, h.bmt.buffer[offset:offset+secsize])
   162  	if err != nil {
   163  		select {
   164  		case h.errc <- err:
   165  		default:
   166  		}
   167  		return
   168  	}
   169  	// write hash into parent node
   170  	if final {
   171  		// for the last segment use writeFinalNode
   172  		h.writeFinalNode(level, n, isLeft, section)
   173  	} else {
   174  		h.writeNode(n, isLeft, section)
   175  	}
   176  }
   177  
   178  // writeNode pushes the data to the node.
   179  // if it is the first of 2 sisters written, the routine terminates.
   180  // if it is the second, it calculates the hash and writes it
   181  // to the parent node recursively.
   182  // since hashing the parent is synchronous the same hasher can be used.
   183  func (h *Hasher) writeNode(n *node, isLeft bool, s []byte) {
   184  	var err error
   185  	level := 1
   186  	for {
   187  		// at the root of the bmt just write the result to the result channel
   188  		if n == nil {
   189  			h.result <- s
   190  			return
   191  		}
   192  		// otherwise assign child hash to left or right segment
   193  		if isLeft {
   194  			n.left = s
   195  		} else {
   196  			n.right = s
   197  		}
   198  		// the child-thread first arriving will terminate
   199  		if n.toggle() {
   200  			return
   201  		}
   202  		// the thread coming second now can be sure both left and right children are written
   203  		// so it calculates the hash of left|right and pushes it to the parent
   204  		s, err = doHash(n.hasher, n.left, n.right)
   205  		if err != nil {
   206  			select {
   207  			case h.errc <- err:
   208  			default:
   209  			}
   210  			return
   211  		}
   212  		isLeft = n.isLeft
   213  		n = n.parent
   214  		level++
   215  	}
   216  }
   217  
// writeFinalNode follows the path from the final data section up to the BMT
// root via parents. level is the tree level of n (processSection passes 1 for
// a leaf's parent) and indexes h.zerohashes.
// For unbalanced trees it fills in the missing right sister nodes using
// h.zerohashes, the lookup table of BMT subtree root hashes for all-zero
// sections. Otherwise it behaves like `writeNode`.
//
// s == nil means this goroutine carries no hash: it arrived first at an
// earlier node, so the sister goroutine arriving second there carries the
// hash instead. Unlike writeNode, this goroutine never stops early. It walks
// all the way to the root so it can set the zero right sister and toggle each
// left node on the final path. At the root it sends on h.result only if it
// still carries a hash. A hashing error is reported via a non-blocking send
// on h.errc and the walk stops.
func (h *Hasher) writeFinalNode(level int, n *node, isLeft bool, s []byte) {
	var err error
	for {
		// at the root of the bmt, send the result only if this goroutine carries it
		if n == nil {
			if s != nil {
				h.result <- s
			}
			return
		}
		var noHash bool
		if isLeft {
			// coming from the left sister branch:
			// the final section's path goes via the left child node, so the
			// right sister holds no data and gets the all-zero subtree hash for this level
			n.right = h.zerohashes[level]
			if s != nil {
				n.left = s
				// a left final node carrying a hash must be reached by the first (and only)
				// goroutine, so the toggle is already in its passive state and need not be called;
				// this goroutine still has to carry on pushing the hash to the parent
				noHash = false
			} else {
				// no hash carried: toggle; if this goroutine is first again, propagate nil and hash nothing
				noHash = n.toggle()
			}
		} else {
			// coming from the right sister branch
			if s != nil {
				// a hash was pushed from the right child node: store it as the right segment and toggle
				n.right = s
				// if toggle is true, we arrived first, so no hashing; just push nil to the parent
				noHash = n.toggle()
			} else {
				// s is nil: this goroutine arrived first at the previous node, so two other
				// goroutines will meet (and toggle) here; do nothing and keep s = nil for the parent
				noHash = true
			}
		}
		// the goroutine that arrived first just continues with s reset to nil;
		// the one arriving second can be sure both left and right children are written,
		// so it calculates the hash of left|right and pushes it to the parent
		if noHash {
			s = nil
		} else {
			s, err = doHash(n.hasher, n.left, n.right)
			if err != nil {
				select {
				case h.errc <- err:
				default:
				}
				return
			}
		}
		// iterate to parent
		isLeft = n.isLeft
		n = n.parent
		level++
	}
}
   283  
   284  // calculates the Keccak256 SHA3 hash of the data
   285  func sha3hash(data ...[]byte) ([]byte, error) {
   286  	return doHash(swarm.NewHasher(), data...)
   287  }
   288  
   289  // calculates Hash of the data
   290  func doHash(h hash.Hash, data ...[]byte) ([]byte, error) {
   291  	h.Reset()
   292  	for _, v := range data {
   293  		if _, err := h.Write(v); err != nil {
   294  			return nil, err
   295  		}
   296  	}
   297  	return h.Sum(nil), nil
   298  }