github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/merkletree2/position.go (about) 1 package merkletree2 2 3 import ( 4 "math/big" 5 ) 6 7 // Position represents the position of a node in the tree. When converted to 8 // bytes, a Position can be interpreted as a 1 followed (from left to right) by 9 // a sequence of log2(Config.ChildrenPerNode)-bit symbols, where each such 10 // symbol identifies which child to descend to in a path from the root to a 11 // node. The sequence is padded with 0s on the left to the nearest byte. For 12 // example, in a binary tree the root has position 0x01 (i.e. 0b00000001), and 13 // the second child of the first child of the root has position 0x05 14 // (0b00000101). 15 type Position big.Int 16 17 func (t *Config) GetRootPosition() *Position { 18 return (*Position)(big.NewInt(1)) 19 } 20 21 func (t *Config) GetChild(p *Position, c ChildIndex) *Position { 22 var q big.Int 23 q.Lsh((*big.Int)(p), uint(t.BitsPerIndex)) 24 q.Bits()[0] = q.Bits()[0] | big.Word(c) 25 return (*Position)(&q) 26 } 27 28 func (p *Position) GetBytes() []byte { 29 return (*big.Int)(p).Bytes() 30 } 31 32 func (p *Position) AsString() string { 33 return string(p.GetBytes()) 34 } 35 36 func (p *Position) SetBytes(b []byte) { 37 (*big.Int)(p).SetBytes(b) 38 } 39 40 func NewPositionFromBytes(pos []byte) *Position { 41 var p big.Int 42 p.SetBytes(pos) 43 return (*Position)(&p) 44 } 45 46 // Set updates p to the value of q 47 func (p *Position) Set(q *Position) { 48 (*big.Int)(p).Set((*big.Int)(q)) 49 } 50 51 // Clone returns a pointer to a deep copy of a position 52 func (p *Position) Clone() *Position { 53 var q Position 54 q.Set(p) 55 return &q 56 } 57 58 func (p *Position) isOnPathToKey(k Key) bool { 59 // If the Key is shorter than current prefix 60 if len(k)*8 < (*big.Int)(p).BitLen()-1 { 61 return false 62 } 63 var q big.Int 64 q.SetBytes([]byte(k)) 65 q.SetBit(&q, len(k)*8, 1) 66 q.Rsh(&q, uint(q.BitLen()-(*big.Int)(p).BitLen())) 67 return (*big.Int)(p).Cmp(&q) == 0 68 } 69 70 func (p *Position) Equals(q *Position) bool { 71 return (*big.Int)(p).CmpAbs((*big.Int)(q)) == 0 72 } 73 74 // getParent return nil if the p is the root 75 func (t *Config) getParent(p *Position) *Position { 76 if (*big.Int)(p).BitLen() < 2 { 77 return nil 78 } 79 80 f := p.Clone() 81 t.updateToParent(f) 82 83 return f 84 } 85 86 func (t *Config) updateToParent(p *Position) { 87 ((*big.Int)(p)).Rsh((*big.Int)(p), uint(t.BitsPerIndex)) 88 } 89 90 // Behavior if p has no parent at the requested level is undefined. 91 func (t *Config) updateToParentAtLevel(p *Position, level uint) { 92 shift := (*big.Int)(p).BitLen() - 1 - int(t.BitsPerIndex)*int(level) 93 ((*big.Int)(p)).Rsh((*big.Int)(p), uint(shift)) 94 } 95 96 // updateToParentAndAllSiblings takes as input p and a slice of size 97 // t.cfg.ChildrenPerNode - 1. It populates the slice with the siblings of p, and 98 // updates p to be its parent. 99 func (t *Config) updateToParentAndAllSiblings(p *Position, sibs []Position) { 100 if (*big.Int)(p).BitLen() < 2 { 101 return 102 } 103 104 // Optimization for binary trees 105 if t.ChildrenPerNode == 2 { 106 sibs[0].Set(p) 107 lsBits := &(((*big.Int)(&sibs[0]).Bits())[0]) 108 *lsBits = (*lsBits ^ 1) 109 110 } else { 111 112 pChildIndex := big.Word(t.getDeepestChildIndex(p)) 113 114 mask := ^((big.Word)((1 << t.BitsPerIndex) - 1)) 115 116 for i, j := uint(0), big.Word(0); j < big.Word(t.ChildrenPerNode); j++ { 117 if j == pChildIndex { 118 continue 119 } 120 121 sibs[i].Set(p) 122 // Set least significant bits to the j-th children 123 lsBits := &(((*big.Int)(&sibs[i]).Bits())[0]) 124 *lsBits = (*lsBits & mask) | j 125 i++ 126 } 127 } 128 129 t.updateToParent(p) 130 } 131 132 // getDeepestPositionForKey converts the key into the position the key would be 133 // stored at if the tree was full with only one key per leaf. 134 func (t *Config) getDeepestPositionForKey(k Key) (*Position, error) { 135 if len(k) != t.KeysByteLength { 136 return nil, NewInvalidKeyError() 137 } 138 var p Position 139 (*big.Int)(&p).SetBytes(k) 140 (*big.Int)(&p).SetBit((*big.Int)(&p), len(k)*8, 1) 141 return &p, nil 142 } 143 144 // Returns the lexicographically first key which could be found at any children 145 // of position p in the tree 146 func (t *Config) getMinKey(p *Position) Key { 147 var min big.Int 148 min.Set((*big.Int)(p)) 149 n := uint(t.KeysByteLength*8 + 1 - min.BitLen()) 150 min.Lsh(&min, n) 151 return min.Bytes()[1:] 152 } 153 154 func (t *Config) GetKeyIntervalUnderPosition(p *Position) (minKey, maxKey Key) { 155 var min, max big.Int 156 157 min.Set((*big.Int)(p)) 158 n := uint(t.KeysByteLength*8 + 1 - min.BitLen()) 159 min.Lsh(&min, n) 160 minKey = min.Bytes()[1:] 161 162 one := big.NewInt(1) 163 max.Lsh(one, n) 164 max.Sub(&max, one) 165 max.Or(&max, &min) 166 maxKey = max.Bytes()[1:] 167 168 return minKey, maxKey 169 } 170 171 // getDeepestPositionAtLevelAndSiblingsOnPathToKey returns a slice of positions, 172 // in descending order by level (siblings farther from the root come first) and 173 // in lexicographic order within each level. The first position in the slice is 174 // the position at level lastLevel on a path from the root to k (or the deepest 175 // possible position for such key if latLevel is greater than that). The 176 // following positions are all the siblings of the nodes on the longest possible 177 // path from the root to the key k with are at levels from lastLevel (excluded) 178 // to firstLevel (included). 179 // See TestGetDeepestPositionAtLevelAndSiblingsOnPathToKey for sample outputs. 180 func (t *Config) getDeepestPositionAtLevelAndSiblingsOnPathToKey(k Key, lastLevel int, firstLevel int) (sibs []Position) { 181 182 maxLevel := t.KeysByteLength * 8 / int(t.BitsPerIndex) 183 if lastLevel > maxLevel { 184 lastLevel = maxLevel 185 } 186 187 // first, shrink the key for efficiency 188 bytesNecessary := lastLevel * int(t.BitsPerIndex) / 8 189 if lastLevel*int(t.BitsPerIndex)%8 != 0 { 190 bytesNecessary++ 191 } 192 k = k[:bytesNecessary] 193 194 var buf Position 195 p := &buf 196 (*big.Int)(p).SetBytes(k) 197 (*big.Int)(p).SetBit((*big.Int)(p), len(k)*8, 1) 198 199 t.updateToParentAtLevel(p, uint(lastLevel)) 200 201 sibs = make([]Position, (lastLevel-firstLevel+1)*(t.ChildrenPerNode-1)+1) 202 sibs[0].Set(p) 203 for i, j := lastLevel, 0; i >= firstLevel; i-- { 204 sibsToFill := sibs[1+(t.ChildrenPerNode-1)*j : 1+(t.ChildrenPerNode-1)*(j+1)] 205 t.updateToParentAndAllSiblings(p, sibsToFill) 206 j++ 207 } 208 209 return sibs 210 } 211 212 // getLevel returns the level of p. The root is at level 0, and each node has 213 // level 1 higher than its parent. 214 func (t *Config) getLevel(p *Position) int { 215 return ((*big.Int)(p).BitLen() - 1) / int(t.BitsPerIndex) 216 } 217 218 // getParentAtLevel returns nil if p is at a level lower than `level`. The root 219 // is at level 0, and each node has level 1 higher than its parent. 220 func (t *Config) getParentAtLevel(p *Position, level uint) *Position { 221 shift := (*big.Int)(p).BitLen() - 1 - int(t.BitsPerIndex)*int(level) 222 if (*big.Int)(p).BitLen() < 2 || shift < 0 { 223 return nil 224 } 225 226 f := p.Clone() 227 t.updateToParentAtLevel(f, level) 228 return f 229 } 230 231 // positionToChildIndexPath returns the list of childIndexes to navigate from the 232 // root to p (in reverse order). 233 func (t *Config) positionToChildIndexPath(p *Position) (path []ChildIndex) { 234 path = make([]ChildIndex, t.getLevel(p)) 235 236 bitMask := big.Word(t.ChildrenPerNode - 1) 237 238 buff := p.Clone() 239 240 for i := range path { 241 path[i] = ChildIndex(((*big.Int)(buff)).Bits()[0] & bitMask) 242 ((*big.Int)(buff)).Rsh((*big.Int)(buff), uint(t.BitsPerIndex)) 243 } 244 245 return path 246 } 247 248 // getDeepestChildIndex returns the only ChildIndex i such that p is the i-th children of 249 // its parent. It returns 0 on the root. 250 func (t *Config) getDeepestChildIndex(p *Position) ChildIndex { 251 if (*big.Int)(p).BitLen() < 2 { 252 return ChildIndex(0) 253 } 254 return ChildIndex(((*big.Int)(p).Bits())[0] & ((1 << t.BitsPerIndex) - 1)) 255 } 256 257 func (p *Position) CmpInMerkleProofOrder(p2 *Position) int { 258 lp := (*big.Int)(p).BitLen() 259 lp2 := (*big.Int)(p2).BitLen() 260 if lp > lp2 { 261 return -1 262 } else if lp < lp2 { 263 return 1 264 } 265 return (*big.Int)(p).CmpAbs((*big.Int)(p2)) 266 }