github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/merkletree2/position.go (about)

     1  package merkletree2
     2  
     3  import (
     4  	"math/big"
     5  )
     6  
     7  // Position represents the position of a node in the tree. When converted to
     8  // bytes, a Position can be interpreted as a 1 followed (from left to right) by
     9  // a sequence of log2(Config.ChildrenPerNode)-bit symbols, where each such
    10  // symbol identifies which child to descend to in a path from the root to a
    11  // node. The sequence is padded with 0s on the left to the nearest byte. For
    12  // example, in a binary tree the root has position 0x01 (i.e. 0b00000001), and
    13  // the second child of the first child of the root has position 0x05
    14  // (0b00000101).
    15  type Position big.Int
    16  
    17  func (t *Config) GetRootPosition() *Position {
    18  	return (*Position)(big.NewInt(1))
    19  }
    20  
    21  func (t *Config) GetChild(p *Position, c ChildIndex) *Position {
    22  	var q big.Int
    23  	q.Lsh((*big.Int)(p), uint(t.BitsPerIndex))
    24  	q.Bits()[0] = q.Bits()[0] | big.Word(c)
    25  	return (*Position)(&q)
    26  }
    27  
    28  func (p *Position) GetBytes() []byte {
    29  	return (*big.Int)(p).Bytes()
    30  }
    31  
    32  func (p *Position) AsString() string {
    33  	return string(p.GetBytes())
    34  }
    35  
    36  func (p *Position) SetBytes(b []byte) {
    37  	(*big.Int)(p).SetBytes(b)
    38  }
    39  
    40  func NewPositionFromBytes(pos []byte) *Position {
    41  	var p big.Int
    42  	p.SetBytes(pos)
    43  	return (*Position)(&p)
    44  }
    45  
    46  // Set updates p to the value of q
    47  func (p *Position) Set(q *Position) {
    48  	(*big.Int)(p).Set((*big.Int)(q))
    49  }
    50  
    51  // Clone returns a pointer to a deep copy of a position
    52  func (p *Position) Clone() *Position {
    53  	var q Position
    54  	q.Set(p)
    55  	return &q
    56  }
    57  
    58  func (p *Position) isOnPathToKey(k Key) bool {
    59  	// If the Key is shorter than current prefix
    60  	if len(k)*8 < (*big.Int)(p).BitLen()-1 {
    61  		return false
    62  	}
    63  	var q big.Int
    64  	q.SetBytes([]byte(k))
    65  	q.SetBit(&q, len(k)*8, 1)
    66  	q.Rsh(&q, uint(q.BitLen()-(*big.Int)(p).BitLen()))
    67  	return (*big.Int)(p).Cmp(&q) == 0
    68  }
    69  
    70  func (p *Position) Equals(q *Position) bool {
    71  	return (*big.Int)(p).CmpAbs((*big.Int)(q)) == 0
    72  }
    73  
    74  // getParent return nil if the p is the root
    75  func (t *Config) getParent(p *Position) *Position {
    76  	if (*big.Int)(p).BitLen() < 2 {
    77  		return nil
    78  	}
    79  
    80  	f := p.Clone()
    81  	t.updateToParent(f)
    82  
    83  	return f
    84  }
    85  
    86  func (t *Config) updateToParent(p *Position) {
    87  	((*big.Int)(p)).Rsh((*big.Int)(p), uint(t.BitsPerIndex))
    88  }
    89  
    90  // Behavior if p has no parent at the requested level is undefined.
    91  func (t *Config) updateToParentAtLevel(p *Position, level uint) {
    92  	shift := (*big.Int)(p).BitLen() - 1 - int(t.BitsPerIndex)*int(level)
    93  	((*big.Int)(p)).Rsh((*big.Int)(p), uint(shift))
    94  }
    95  
    96  // updateToParentAndAllSiblings takes as input p and a slice of size
    97  // t.cfg.ChildrenPerNode - 1. It populates the slice with the siblings of p, and
    98  // updates p to be its parent.
    99  func (t *Config) updateToParentAndAllSiblings(p *Position, sibs []Position) {
   100  	if (*big.Int)(p).BitLen() < 2 {
   101  		return
   102  	}
   103  
   104  	// Optimization for binary trees
   105  	if t.ChildrenPerNode == 2 {
   106  		sibs[0].Set(p)
   107  		lsBits := &(((*big.Int)(&sibs[0]).Bits())[0])
   108  		*lsBits = (*lsBits ^ 1)
   109  
   110  	} else {
   111  
   112  		pChildIndex := big.Word(t.getDeepestChildIndex(p))
   113  
   114  		mask := ^((big.Word)((1 << t.BitsPerIndex) - 1))
   115  
   116  		for i, j := uint(0), big.Word(0); j < big.Word(t.ChildrenPerNode); j++ {
   117  			if j == pChildIndex {
   118  				continue
   119  			}
   120  
   121  			sibs[i].Set(p)
   122  			// Set least significant bits to the j-th children
   123  			lsBits := &(((*big.Int)(&sibs[i]).Bits())[0])
   124  			*lsBits = (*lsBits & mask) | j
   125  			i++
   126  		}
   127  	}
   128  
   129  	t.updateToParent(p)
   130  }
   131  
   132  // getDeepestPositionForKey converts the key into the position the key would be
   133  // stored at if the tree was full with only one key per leaf.
   134  func (t *Config) getDeepestPositionForKey(k Key) (*Position, error) {
   135  	if len(k) != t.KeysByteLength {
   136  		return nil, NewInvalidKeyError()
   137  	}
   138  	var p Position
   139  	(*big.Int)(&p).SetBytes(k)
   140  	(*big.Int)(&p).SetBit((*big.Int)(&p), len(k)*8, 1)
   141  	return &p, nil
   142  }
   143  
   144  // Returns the lexicographically first key which could be found at any children
   145  // of position p in the tree
   146  func (t *Config) getMinKey(p *Position) Key {
   147  	var min big.Int
   148  	min.Set((*big.Int)(p))
   149  	n := uint(t.KeysByteLength*8 + 1 - min.BitLen())
   150  	min.Lsh(&min, n)
   151  	return min.Bytes()[1:]
   152  }
   153  
   154  func (t *Config) GetKeyIntervalUnderPosition(p *Position) (minKey, maxKey Key) {
   155  	var min, max big.Int
   156  
   157  	min.Set((*big.Int)(p))
   158  	n := uint(t.KeysByteLength*8 + 1 - min.BitLen())
   159  	min.Lsh(&min, n)
   160  	minKey = min.Bytes()[1:]
   161  
   162  	one := big.NewInt(1)
   163  	max.Lsh(one, n)
   164  	max.Sub(&max, one)
   165  	max.Or(&max, &min)
   166  	maxKey = max.Bytes()[1:]
   167  
   168  	return minKey, maxKey
   169  }
   170  
   171  // getDeepestPositionAtLevelAndSiblingsOnPathToKey returns a slice of positions,
   172  // in descending order by level (siblings farther from the root come first) and
   173  // in lexicographic order within each level. The first position in the slice is
   174  // the position at level lastLevel on a path from the root to k (or the deepest
   175  // possible position for such key if latLevel is greater than that). The
   176  // following positions are all the siblings of the nodes on the longest possible
   177  // path from the root to the key k with are at levels from lastLevel (excluded)
   178  // to firstLevel (included).
   179  // See TestGetDeepestPositionAtLevelAndSiblingsOnPathToKey for sample outputs.
   180  func (t *Config) getDeepestPositionAtLevelAndSiblingsOnPathToKey(k Key, lastLevel int, firstLevel int) (sibs []Position) {
   181  
   182  	maxLevel := t.KeysByteLength * 8 / int(t.BitsPerIndex)
   183  	if lastLevel > maxLevel {
   184  		lastLevel = maxLevel
   185  	}
   186  
   187  	// first, shrink the key for efficiency
   188  	bytesNecessary := lastLevel * int(t.BitsPerIndex) / 8
   189  	if lastLevel*int(t.BitsPerIndex)%8 != 0 {
   190  		bytesNecessary++
   191  	}
   192  	k = k[:bytesNecessary]
   193  
   194  	var buf Position
   195  	p := &buf
   196  	(*big.Int)(p).SetBytes(k)
   197  	(*big.Int)(p).SetBit((*big.Int)(p), len(k)*8, 1)
   198  
   199  	t.updateToParentAtLevel(p, uint(lastLevel))
   200  
   201  	sibs = make([]Position, (lastLevel-firstLevel+1)*(t.ChildrenPerNode-1)+1)
   202  	sibs[0].Set(p)
   203  	for i, j := lastLevel, 0; i >= firstLevel; i-- {
   204  		sibsToFill := sibs[1+(t.ChildrenPerNode-1)*j : 1+(t.ChildrenPerNode-1)*(j+1)]
   205  		t.updateToParentAndAllSiblings(p, sibsToFill)
   206  		j++
   207  	}
   208  
   209  	return sibs
   210  }
   211  
   212  // getLevel returns the level of p. The root is at level 0, and each node has
   213  // level 1 higher than its parent.
   214  func (t *Config) getLevel(p *Position) int {
   215  	return ((*big.Int)(p).BitLen() - 1) / int(t.BitsPerIndex)
   216  }
   217  
   218  // getParentAtLevel returns nil if p is at a level lower than `level`. The root
   219  // is at level 0, and each node has level 1 higher than its parent.
   220  func (t *Config) getParentAtLevel(p *Position, level uint) *Position {
   221  	shift := (*big.Int)(p).BitLen() - 1 - int(t.BitsPerIndex)*int(level)
   222  	if (*big.Int)(p).BitLen() < 2 || shift < 0 {
   223  		return nil
   224  	}
   225  
   226  	f := p.Clone()
   227  	t.updateToParentAtLevel(f, level)
   228  	return f
   229  }
   230  
   231  // positionToChildIndexPath returns the list of childIndexes to navigate from the
   232  // root to p (in reverse order).
   233  func (t *Config) positionToChildIndexPath(p *Position) (path []ChildIndex) {
   234  	path = make([]ChildIndex, t.getLevel(p))
   235  
   236  	bitMask := big.Word(t.ChildrenPerNode - 1)
   237  
   238  	buff := p.Clone()
   239  
   240  	for i := range path {
   241  		path[i] = ChildIndex(((*big.Int)(buff)).Bits()[0] & bitMask)
   242  		((*big.Int)(buff)).Rsh((*big.Int)(buff), uint(t.BitsPerIndex))
   243  	}
   244  
   245  	return path
   246  }
   247  
   248  // getDeepestChildIndex returns the only ChildIndex i such that p is the i-th children of
   249  // its parent. It returns 0 on the root.
   250  func (t *Config) getDeepestChildIndex(p *Position) ChildIndex {
   251  	if (*big.Int)(p).BitLen() < 2 {
   252  		return ChildIndex(0)
   253  	}
   254  	return ChildIndex(((*big.Int)(p).Bits())[0] & ((1 << t.BitsPerIndex) - 1))
   255  }
   256  
   257  func (p *Position) CmpInMerkleProofOrder(p2 *Position) int {
   258  	lp := (*big.Int)(p).BitLen()
   259  	lp2 := (*big.Int)(p2).BitLen()
   260  	if lp > lp2 {
   261  		return -1
   262  	} else if lp < lp2 {
   263  		return 1
   264  	}
   265  	return (*big.Int)(p).CmpAbs((*big.Int)(p2))
   266  }