trpc.group/trpc-go/trpc-go@v1.0.3/internal/dat/dat.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 // Package dat provides a double array trie. 15 // A DAT is used to filter protobuf fields specified by HttpRule. 16 // These fields will be ignored if they also present in http request query parameters 17 // to prevent repeated reference. 18 package dat 19 20 import ( 21 "errors" 22 "math" 23 "sort" 24 ) 25 26 var ( 27 errByDictOrder = errors.New("not by dict order") 28 errEncoded = errors.New("field name not encoded") 29 ) 30 31 const ( 32 defaultArraySize = 64 // default array size of dat 33 minExpansionRate = 1.05 // minimal expansion rate, based on experience 34 nextCheckPosStrategyRate = 0.95 // next check pos strategy rate, based on experience 35 ) 36 37 // DoubleArrayTrie is a double array trie. 38 // It's based on https://github.com/komiya-atsushi/darts-java. 39 // State Transition Equation: 40 // 41 // base[0] = 1 42 // base[s] + c = t 43 // check[t] = base[s] 44 type DoubleArrayTrie struct { 45 base []int // base array 46 check []int // check array 47 used []bool // used array 48 size int // size of base/check/used arrays 49 allocSize int // allocated size of base/check/used arrays 50 fps fieldPaths // fieldPaths 51 dict fieldDict // fieldDict 52 progress int // number of processed fieldPaths 53 nextCheckPos int // record next index of begin to prevent start over from 0 54 } 55 56 // node is node of DAT. 57 type node struct { 58 code int // code = dictCodeOfFieldName + 1, dictCodeOfFieldName: [0, 1, 2, ..., n-1] 59 depth int // depth of node 60 left int // left boundary 61 right int // right boundary 62 } 63 64 // Build performs static construction of a DAT. 65 func Build(fps [][]string) (*DoubleArrayTrie, error) { 66 // sort 67 sort.Sort(fieldPaths(fps)) 68 69 // init dat 70 dat := &DoubleArrayTrie{ 71 fps: fps, 72 dict: newFieldDict(fps), 73 } 74 dat.resize(defaultArraySize) 75 dat.base[0] = 1 76 77 // root node handling 78 root := &node{ 79 right: len(dat.fps), 80 } 81 children, err := dat.fetch(root) 82 if err != nil { 83 return nil, err 84 } 85 if _, err := dat.insert(children); err != nil { 86 return nil, err 87 } 88 89 // shrink 90 dat.resize(dat.size) 91 92 return dat, nil 93 } 94 95 // CommonPrefixSearch check if input fieldPath has common prefix with fps in DAT. 96 func (dat *DoubleArrayTrie) CommonPrefixSearch(fieldPath []string) bool { 97 var pos int 98 baseValue := dat.base[0] 99 100 for _, name := range fieldPath { 101 // get dict code 102 v, ok := dat.dict[name] 103 if !ok { 104 break 105 } 106 code := v + 1 // code = dictCodeOfFieldName + 1 107 108 // check if leaf node has been reached, that is, check if next node is NULL according to 109 // the State Transition Equation. 110 if baseValue == dat.check[baseValue] && dat.base[baseValue] < 0 { 111 // has reached leaf node,it's the common prefix. 112 return true 113 } 114 115 // state transition 116 pos = baseValue + code 117 if pos >= len(dat.check) || baseValue != dat.check[pos] { // mismatch 118 return false 119 } 120 baseValue = dat.base[pos] 121 } 122 123 // check again if leaf node has been reached for last state transition 124 if baseValue == dat.check[baseValue] && dat.base[baseValue] < 0 { 125 // has reached leaf node,it's the common prefix. 126 return true 127 } 128 129 return false 130 } 131 132 // fetch returns children nodes given parent node. 133 // If the fps in DAT is like: 134 // 135 // ["foobar", "foo", "bar"] 136 // ["foobar", "baz"] 137 // ["foo", "qux"] 138 // 139 // children, _ := dat.fetch(root),children should be ["foobar", "foo"], 140 // and their depths should all be 1. 141 func (dat *DoubleArrayTrie) fetch(parent *node) ([]*node, error) { 142 var ( 143 children []*node // children nodes would be returned 144 prev int // code of prev child node 145 ) 146 147 // search range [parent.left, parent.right) 148 // for root node,search range [0, len(dat.fps)) 149 for i := parent.left; i < parent.right; i++ { 150 if len(dat.fps[i]) < parent.depth { // all fp of fps[i] have been fetched 151 continue 152 } 153 154 var curr int // code of curr child node 155 if len(dat.fps[i]) > parent.depth { 156 v, ok := dat.dict[dat.fps[i][parent.depth]] 157 if !ok { // not encoded 158 return nil, errEncoded 159 } 160 curr = v + 1 // code = dictCodeOfFieldName + 1 161 } 162 163 // not by dict order 164 if prev > curr { 165 return nil, errByDictOrder 166 } 167 168 // Normally, if curr == prev, skip this. 169 // But curr == prev && len(children) == 0 makes an exception, 170 // it means fetching fp from fps[i] comes to an end and an empty node should be added 171 // like an EOF. 172 if curr != prev || len(children) == 0 { 173 // update right boundary of prev child node 174 if len(children) != 0 { 175 children[len(children)-1].right = i 176 } 177 // curr child node 178 // no need to update right boundary, 179 // let next child node update this node's right boundary 180 children = append(children, &node{ 181 code: curr, 182 depth: parent.depth + 1, // depth +1 183 left: i, 184 }) 185 } 186 187 prev = curr 188 } 189 190 // update right boundary of the last child node 191 if len(children) > 0 { 192 children[len(children)-1].right = parent.right // same right boundary as parent node 193 } 194 195 return children, nil 196 } 197 198 // max returns the bigger int value. 199 func max(x, y int) int { 200 if x > y { 201 return x 202 } 203 return y 204 } 205 206 // loopForBegin loops for begin value that meets the condition. 207 func (dat *DoubleArrayTrie) loopForBegin(children []*node) (int, error) { 208 var ( 209 begin int // begin to loop for 210 numOfNonZero int // number of non zero 211 pos = max(children[0].code, dat.nextCheckPos-1) // prevent start over from 0 to loop for begin value 212 ) 213 214 for first := true; ; { // whether first time to meet a non zero 215 pos++ 216 if dat.allocSize <= pos { // expand 217 dat.resize(pos + 1) 218 } 219 if dat.check[pos] != 0 { // occupied 220 numOfNonZero++ 221 continue 222 } else { 223 if first { 224 dat.nextCheckPos = pos 225 first = false 226 } 227 } 228 229 // try this begin value 230 begin = pos - children[0].code 231 232 // compare with lastChildPos to check if expansion is needed 233 if lastChildPos := begin + children[len(children)-1].code; dat.allocSize <= lastChildPos { 234 // rate = {total number of fieldPaths} / ({number of processed fieldPaths} + 1), but not less than 1.05 235 rate := math.Max(minExpansionRate, float64(1.0*len(dat.fps)/(dat.progress+1))) 236 dat.resize(int(float64(dat.allocSize) * rate)) 237 } 238 239 if dat.used[begin] { // check dup 240 continue 241 } 242 243 // check if remaining children nodes could be inserted 244 conflict := func() bool { 245 for i := 1; i < len(children); i++ { 246 if dat.check[begin+children[i].code] != 0 { 247 return true 248 } 249 } 250 return false 251 } 252 // if conflicting, next pos 253 if conflict() { 254 continue 255 } 256 // no conflicting, found the begin value 257 break 258 } 259 260 // if nodes from nextCheckPos to pos are all occupied, set nextCheckPos to pos 261 if float64((1.0*numOfNonZero)/(pos-dat.nextCheckPos+1)) >= nextCheckPosStrategyRate { 262 dat.nextCheckPos = pos 263 } 264 265 return begin, nil 266 } 267 268 // insert inserts children nodes into DAT, returns begin value that is looking for. 269 func (dat *DoubleArrayTrie) insert(children []*node) (int, error) { 270 // loop for begin value 271 begin, err := dat.loopForBegin(children) 272 if err != nil { 273 return 0, err 274 } 275 276 dat.used[begin] = true 277 dat.size = max(dat.size, begin+children[len(children)-1].code+1) 278 279 // check arrays assignment 280 for i := range children { 281 dat.check[begin+children[i].code] = begin 282 } 283 284 // dfs 285 for _, child := range children { 286 grandchildren, err := dat.fetch(child) 287 if err != nil { 288 return 0, err 289 } 290 if len(grandchildren) == 0 { // no children nodes 291 dat.base[begin+child.code] = -child.left - 1 292 dat.progress++ 293 continue 294 } 295 t, err := dat.insert(grandchildren) 296 if err != nil { 297 return 0, err 298 } 299 // base arrays assignment 300 dat.base[begin+child.code] = t 301 } 302 303 return begin, nil 304 } 305 306 // resize changes the size of the arrays. 307 func (dat *DoubleArrayTrie) resize(newSize int) { 308 newBase := make([]int, newSize, newSize) 309 newCheck := make([]int, newSize, newSize) 310 newUsed := make([]bool, newSize, newSize) 311 312 if dat.allocSize > 0 { 313 copy(newBase, dat.base) 314 copy(newCheck, dat.check) 315 copy(newUsed, dat.used) 316 } 317 318 dat.base = newBase 319 dat.check = newCheck 320 dat.used = newUsed 321 322 dat.allocSize = newSize 323 } 324 325 type fieldPaths [][]string 326 327 // Len implements sort.Interface 328 func (fps fieldPaths) Len() int { return len(fps) } 329 330 // Swap implements sort.Interface 331 func (fps fieldPaths) Swap(i, j int) { fps[i], fps[j] = fps[j], fps[i] } 332 333 // Less implements sort.Interface 334 func (fps fieldPaths) Less(i, j int) bool { 335 var k int 336 for k = 0; k < len(fps[i]) && k < len(fps[j]); k++ { 337 if fps[i][k] < fps[j][k] { 338 return true 339 } 340 if fps[i][k] > fps[j][k] { 341 return false 342 } 343 } 344 return k < len(fps[j]) 345 } 346 347 type fieldDict map[string]int // FieldName -> DictCodeOfFieldName 348 349 func newFieldDict(fps fieldPaths) fieldDict { 350 dict := make(map[string]int) 351 // rm dup 352 for _, fieldPath := range fps { 353 for _, name := range fieldPath { 354 dict[name] = 0 355 } 356 } 357 358 // sort 359 fields := make([]string, 0, len(dict)) 360 for name := range dict { 361 fields = append(fields, name) 362 } 363 sort.Sort(sort.StringSlice(fields)) 364 365 // dict assignment 366 367 for code, name := range fields { 368 dict[name] = code 369 } 370 return dict 371 }