github.com/linapex/ethereum-dpos-chinese@v0.0.0-20190316121959-b78b3a4a1ece/swarm/bmt/bmt.go (about)

     1  
     2  //<developer>
     3  //    <name>linapex 曹一峰</name>
     4  //    <email>linapex@163.com</email>
     5  //    <wx>superexc</wx>
     6  //    <qqgroup>128148617</qqgroup>
     7  //    <url>https://jsq.ink</url>
     8  //    <role>pku engineer</role>
     9  //    <date>2019-03-16 12:09:47</date>
    10  //</624342670495453184>
    11  
    12  //
    13  //
    14  //
    15  //
    16  //
    17  //
    18  //
    19  //
    20  //
    21  //
    22  //
    23  //
    24  //
    25  //
    26  //
    27  
    28  //
    29  package bmt
    30  
    31  import (
    32  	"fmt"
    33  	"hash"
    34  	"strings"
    35  	"sync"
    36  	"sync/atomic"
    37  )
    38  
    39  /*
    40  
    41  
    42  
    43  
    44  
    45  
    46  
    47  
    48  
    49  
    50  
    51  
    52  
    53  
    54  
    55  
    56  
    57    
    58  
    59    
    60  
    61    
    62   
    63   
    64   
    65   
    66  */
    67  
    68  
    69  const (
    70  //
    71  //
    72  	PoolSize = 8
    73  )
    74  
    75  //
    76  //
    77  type BaseHasherFunc func() hash.Hash
    78  
    79  //
    80  //
    81  //
    82  //
    83  //
    84  //
    85  //
    86  //
    87  //
    88  //
    89  type Hasher struct {
    90  pool *TreePool //
    91  bmt  *tree     //
    92  }
    93  
    94  //
    95  //
    96  func New(p *TreePool) *Hasher {
    97  	return &Hasher{
    98  		pool: p,
    99  	}
   100  }
   101  
   102  //
   103  //
   104  //
   105  type TreePool struct {
   106  	lock         sync.Mutex
   107  c            chan *tree     //
   108  hasher       BaseHasherFunc //
   109  SegmentSize  int            //
   110  SegmentCount int            //
   111  Capacity     int            //
   112  Depth        int            //
   113  Size         int            //
   114  count        int            //
   115  zerohashes   [][]byte       //
   116  }
   117  
   118  //
   119  //
   120  func NewTreePool(hasher BaseHasherFunc, segmentCount, capacity int) *TreePool {
   121  //
   122  	depth := calculateDepthFor(segmentCount)
   123  	segmentSize := hasher().Size()
   124  	zerohashes := make([][]byte, depth+1)
   125  	zeros := make([]byte, segmentSize)
   126  	zerohashes[0] = zeros
   127  	h := hasher()
   128  	for i := 1; i < depth+1; i++ {
   129  		zeros = doSum(h, nil, zeros, zeros)
   130  		zerohashes[i] = zeros
   131  	}
   132  	return &TreePool{
   133  		c:            make(chan *tree, capacity),
   134  		hasher:       hasher,
   135  		SegmentSize:  segmentSize,
   136  		SegmentCount: segmentCount,
   137  		Capacity:     capacity,
   138  		Size:         segmentCount * segmentSize,
   139  		Depth:        depth,
   140  		zerohashes:   zerohashes,
   141  	}
   142  }
   143  
   144  //
   145  func (p *TreePool) Drain(n int) {
   146  	p.lock.Lock()
   147  	defer p.lock.Unlock()
   148  	for len(p.c) > n {
   149  		<-p.c
   150  		p.count--
   151  	}
   152  }
   153  
   154  //
   155  //
   156  //
   157  func (p *TreePool) reserve() *tree {
   158  	p.lock.Lock()
   159  	defer p.lock.Unlock()
   160  	var t *tree
   161  	if p.count == p.Capacity {
   162  		return <-p.c
   163  	}
   164  	select {
   165  	case t = <-p.c:
   166  	default:
   167  		t = newTree(p.SegmentSize, p.Depth, p.hasher)
   168  		p.count++
   169  	}
   170  	return t
   171  }
   172  
   173  //
   174  //
   175  func (p *TreePool) release(t *tree) {
   176  p.c <- t //
   177  }
   178  
   179  //
   180  //
   181  //
   182  //
   183  type tree struct {
   184  leaves  []*node     //
   185  cursor  int         //
   186  offset  int         //
   187  section []byte      //
   188  result  chan []byte //
   189  span    []byte      //
   190  }
   191  
   192  //
   193  type node struct {
   194  isLeft      bool      //
   195  parent      *node     //
   196  state       int32     //
   197  left, right []byte    //
   198  hasher      hash.Hash //
   199  }
   200  
   201  //
   202  func newNode(index int, parent *node, hasher hash.Hash) *node {
   203  	return &node{
   204  		parent: parent,
   205  		isLeft: index%2 == 0,
   206  		hasher: hasher,
   207  	}
   208  }
   209  
   210  //
   211  func (t *tree) draw(hash []byte) string {
   212  	var left, right []string
   213  	var anc []*node
   214  	for i, n := range t.leaves {
   215  		left = append(left, fmt.Sprintf("%v", hashstr(n.left)))
   216  		if i%2 == 0 {
   217  			anc = append(anc, n.parent)
   218  		}
   219  		right = append(right, fmt.Sprintf("%v", hashstr(n.right)))
   220  	}
   221  	anc = t.leaves
   222  	var hashes [][]string
   223  	for l := 0; len(anc) > 0; l++ {
   224  		var nodes []*node
   225  		hash := []string{""}
   226  		for i, n := range anc {
   227  			hash = append(hash, fmt.Sprintf("%v|%v", hashstr(n.left), hashstr(n.right)))
   228  			if i%2 == 0 && n.parent != nil {
   229  				nodes = append(nodes, n.parent)
   230  			}
   231  		}
   232  		hash = append(hash, "")
   233  		hashes = append(hashes, hash)
   234  		anc = nodes
   235  	}
   236  	hashes = append(hashes, []string{"", fmt.Sprintf("%v", hashstr(hash)), ""})
   237  	total := 60
   238  	del := "                             "
   239  	var rows []string
   240  	for i := len(hashes) - 1; i >= 0; i-- {
   241  		var textlen int
   242  		hash := hashes[i]
   243  		for _, s := range hash {
   244  			textlen += len(s)
   245  		}
   246  		if total < textlen {
   247  			total = textlen + len(hash)
   248  		}
   249  		delsize := (total - textlen) / (len(hash) - 1)
   250  		if delsize > len(del) {
   251  			delsize = len(del)
   252  		}
   253  		row := fmt.Sprintf("%v: %v", len(hashes)-i-1, strings.Join(hash, del[:delsize]))
   254  		rows = append(rows, row)
   255  
   256  	}
   257  	rows = append(rows, strings.Join(left, "  "))
   258  	rows = append(rows, strings.Join(right, "  "))
   259  	return strings.Join(rows, "\n") + "\n"
   260  }
   261  
   262  //
   263  //
   264  func newTree(segmentSize, depth int, hashfunc func() hash.Hash) *tree {
   265  	n := newNode(0, nil, hashfunc())
   266  	prevlevel := []*node{n}
   267  //
   268  //
   269  	count := 2
   270  	for level := depth - 2; level >= 0; level-- {
   271  		nodes := make([]*node, count)
   272  		for i := 0; i < count; i++ {
   273  			parent := prevlevel[i/2]
   274  			var hasher hash.Hash
   275  			if level == 0 {
   276  				hasher = hashfunc()
   277  			}
   278  			nodes[i] = newNode(i, parent, hasher)
   279  		}
   280  		prevlevel = nodes
   281  		count *= 2
   282  	}
   283  //
   284  	return &tree{
   285  		leaves:  prevlevel,
   286  		result:  make(chan []byte),
   287  		section: make([]byte, 2*segmentSize),
   288  	}
   289  }
   290  
   291  //
   292  
   293  //
   294  func (h *Hasher) Size() int {
   295  	return h.pool.SegmentSize
   296  }
   297  
   298  //
   299  func (h *Hasher) BlockSize() int {
   300  	return 2 * h.pool.SegmentSize
   301  }
   302  
   303  //
   304  //
   305  //
   306  //
   307  //
   308  func (h *Hasher) Sum(b []byte) (s []byte) {
   309  	t := h.getTree()
   310  //
   311  	go h.writeSection(t.cursor, t.section, true, true)
   312  //
   313  	s = <-t.result
   314  	span := t.span
   315  //
   316  	h.releaseTree()
   317  //
   318  	if len(span) == 0 {
   319  		return append(b, s...)
   320  	}
   321  	return doSum(h.pool.hasher(), b, span, s)
   322  }
   323  
   324  //
   325  
   326  //
   327  //
   328  func (h *Hasher) Write(b []byte) (int, error) {
   329  	l := len(b)
   330  	if l == 0 || l > h.pool.Size {
   331  		return 0, nil
   332  	}
   333  	t := h.getTree()
   334  	secsize := 2 * h.pool.SegmentSize
   335  //
   336  	smax := secsize - t.offset
   337  //
   338  	if t.offset < secsize {
   339  //
   340  		copy(t.section[t.offset:], b)
   341  //
   342  //
   343  		if smax == 0 {
   344  			smax = secsize
   345  		}
   346  		if l <= smax {
   347  			t.offset += l
   348  			return l, nil
   349  		}
   350  	} else {
   351  //
   352  		if t.cursor == h.pool.SegmentCount*2 {
   353  			return 0, nil
   354  		}
   355  	}
   356  //
   357  	for smax < l {
   358  //
   359  		go h.writeSection(t.cursor, t.section, true, false)
   360  //
   361  		t.section = make([]byte, secsize)
   362  //
   363  		copy(t.section, b[smax:])
   364  //
   365  		t.cursor++
   366  //
   367  		smax += secsize
   368  	}
   369  	t.offset = l - smax + secsize
   370  	return l, nil
   371  }
   372  
   373  //
   374  func (h *Hasher) Reset() {
   375  	h.releaseTree()
   376  }
   377  
   378  //
   379  
   380  //
   381  //
   382  //
   383  func (h *Hasher) ResetWithLength(span []byte) {
   384  	h.Reset()
   385  	h.getTree().span = span
   386  }
   387  
   388  //
   389  //
   390  func (h *Hasher) releaseTree() {
   391  	t := h.bmt
   392  	if t == nil {
   393  		return
   394  	}
   395  	h.bmt = nil
   396  	go func() {
   397  		t.cursor = 0
   398  		t.offset = 0
   399  		t.span = nil
   400  		t.section = make([]byte, h.pool.SegmentSize*2)
   401  		select {
   402  		case <-t.result:
   403  		default:
   404  		}
   405  		h.pool.release(t)
   406  	}()
   407  }
   408  
   409  //
   410  func (h *Hasher) NewAsyncWriter(double bool) *AsyncHasher {
   411  	secsize := h.pool.SegmentSize
   412  	if double {
   413  		secsize *= 2
   414  	}
   415  	write := func(i int, section []byte, final bool) {
   416  		h.writeSection(i, section, double, final)
   417  	}
   418  	return &AsyncHasher{
   419  		Hasher:  h,
   420  		double:  double,
   421  		secsize: secsize,
   422  		write:   write,
   423  	}
   424  }
   425  
   426  //
   427  type SectionWriter interface {
   428  Reset()                                       //
   429  Write(index int, data []byte)                 //
   430  Sum(b []byte, length int, span []byte) []byte //
   431  SectionSize() int                             //
   432  }
   433  
   434  //
   435  //
   436  //
   437  //
   438  //
   439  //
   440  //
   441  //
   442  //
   443  //
   444  //
   445  //
   446  //
   447  //
   448  type AsyncHasher struct {
   449  *Hasher            //
   450  mtx     sync.Mutex //
   451  double  bool       //
   452  secsize int        //
   453  	write   func(i int, section []byte, final bool)
   454  }
   455  
   456  //
   457  
   458  //
   459  func (sw *AsyncHasher) SectionSize() int {
   460  	return sw.secsize
   461  }
   462  
   463  //
   464  //
   465  //
   466  func (sw *AsyncHasher) Write(i int, section []byte) {
   467  	sw.mtx.Lock()
   468  	defer sw.mtx.Unlock()
   469  	t := sw.getTree()
   470  //
   471  //
   472  	if i < t.cursor {
   473  //
   474  		go sw.write(i, section, false)
   475  		return
   476  	}
   477  //
   478  	if t.offset > 0 {
   479  		if i == t.cursor {
   480  //
   481  //
   482  			t.section = make([]byte, sw.secsize)
   483  			copy(t.section, section)
   484  			go sw.write(i, t.section, true)
   485  			return
   486  		}
   487  //
   488  		go sw.write(t.cursor, t.section, false)
   489  	}
   490  //
   491  //
   492  	t.cursor = i
   493  	t.offset = i*sw.secsize + 1
   494  	t.section = make([]byte, sw.secsize)
   495  	copy(t.section, section)
   496  }
   497  
   498  //
   499  //
   500  //
   501  //
   502  //
   503  //
   504  //
   505  //
   506  //
   507  func (sw *AsyncHasher) Sum(b []byte, length int, meta []byte) (s []byte) {
   508  	sw.mtx.Lock()
   509  	t := sw.getTree()
   510  	if length == 0 {
   511  		sw.mtx.Unlock()
   512  		s = sw.pool.zerohashes[sw.pool.Depth]
   513  	} else {
   514  //
   515  //
   516  		maxsec := (length - 1) / sw.secsize
   517  		if t.offset > 0 {
   518  			go sw.write(t.cursor, t.section, maxsec == t.cursor)
   519  		}
   520  //
   521  		t.cursor = maxsec
   522  		t.offset = length
   523  		result := t.result
   524  		sw.mtx.Unlock()
   525  //
   526  		s = <-result
   527  	}
   528  //
   529  	sw.releaseTree()
   530  //
   531  	if len(meta) == 0 {
   532  		return append(b, s...)
   533  	}
   534  //
   535  	return doSum(sw.pool.hasher(), b, meta, s)
   536  }
   537  
   538  //
   539  func (h *Hasher) writeSection(i int, section []byte, double bool, final bool) {
   540  //
   541  	var n *node
   542  	var isLeft bool
   543  	var hasher hash.Hash
   544  	var level int
   545  	t := h.getTree()
   546  	if double {
   547  		level++
   548  		n = t.leaves[i]
   549  		hasher = n.hasher
   550  		isLeft = n.isLeft
   551  		n = n.parent
   552  //
   553  		section = doSum(hasher, nil, section)
   554  	} else {
   555  		n = t.leaves[i/2]
   556  		hasher = n.hasher
   557  		isLeft = i%2 == 0
   558  	}
   559  //
   560  	if final {
   561  //
   562  		h.writeFinalNode(level, n, hasher, isLeft, section)
   563  	} else {
   564  		h.writeNode(n, hasher, isLeft, section)
   565  	}
   566  }
   567  
   568  //
   569  //
   570  //
   571  //
   572  //
   573  func (h *Hasher) writeNode(n *node, bh hash.Hash, isLeft bool, s []byte) {
   574  	level := 1
   575  	for {
   576  //
   577  		if n == nil {
   578  			h.getTree().result <- s
   579  			return
   580  		}
   581  //
   582  		if isLeft {
   583  			n.left = s
   584  		} else {
   585  			n.right = s
   586  		}
   587  //
   588  		if n.toggle() {
   589  			return
   590  		}
   591  //
   592  //
   593  		s = doSum(bh, nil, n.left, n.right)
   594  		isLeft = n.isLeft
   595  		n = n.parent
   596  		level++
   597  	}
   598  }
   599  
   600  //
   601  //
   602  //
   603  //
   604  //
   605  func (h *Hasher) writeFinalNode(level int, n *node, bh hash.Hash, isLeft bool, s []byte) {
   606  
   607  	for {
   608  //
   609  		if n == nil {
   610  			if s != nil {
   611  				h.getTree().result <- s
   612  			}
   613  			return
   614  		}
   615  		var noHash bool
   616  		if isLeft {
   617  //
   618  //
   619  //
   620  			n.right = h.pool.zerohashes[level]
   621  			if s != nil {
   622  				n.left = s
   623  //
   624  //
   625  //
   626  				noHash = false
   627  			} else {
   628  //
   629  				noHash = n.toggle()
   630  			}
   631  		} else {
   632  //
   633  			if s != nil {
   634  //
   635  				n.right = s
   636  //
   637  				noHash = n.toggle()
   638  
   639  			} else {
   640  //
   641  //
   642  				noHash = true
   643  			}
   644  		}
   645  //
   646  //
   647  //
   648  		if noHash {
   649  			s = nil
   650  		} else {
   651  			s = doSum(bh, nil, n.left, n.right)
   652  		}
   653  //
   654  		isLeft = n.isLeft
   655  		n = n.parent
   656  		level++
   657  	}
   658  }
   659  
   660  //
   661  func (h *Hasher) getTree() *tree {
   662  	if h.bmt != nil {
   663  		return h.bmt
   664  	}
   665  	t := h.pool.reserve()
   666  	h.bmt = t
   667  	return t
   668  }
   669  
   670  //
   671  //
   672  //
   673  func (n *node) toggle() bool {
   674  	return atomic.AddInt32(&n.state, 1)%2 == 1
   675  }
   676  
   677  //
   678  func doSum(h hash.Hash, b []byte, data ...[]byte) []byte {
   679  	h.Reset()
   680  	for _, v := range data {
   681  		h.Write(v)
   682  	}
   683  	return h.Sum(b)
   684  }
   685  
   686  //
   687  func hashstr(b []byte) string {
   688  	end := len(b)
   689  	if end > 4 {
   690  		end = 4
   691  	}
   692  	return fmt.Sprintf("%x", b[:end])
   693  }
   694  
   695  //
   696  func calculateDepthFor(n int) (d int) {
   697  	c := 2
   698  	for ; c < n; c *= 2 {
   699  		d++
   700  	}
   701  	return d + 1
   702  }
   703