trpc.group/trpc-go/trpc-go@v1.0.3/internal/dat/dat.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  // Package dat provides a double array trie.
    15  // A DAT is used to filter protobuf fields specified by HttpRule.
// These fields will be ignored if they are also present in HTTP request query parameters
    17  // to prevent repeated reference.
    18  package dat
    19  
    20  import (
    21  	"errors"
    22  	"math"
    23  	"sort"
    24  )
    25  
// Sentinel errors that fetch can return while the trie is being built.
var (
	// errByDictOrder is returned when a sibling's code is smaller than the code of the
	// sibling seen before it, i.e. the field paths are not in dictionary order.
	errByDictOrder = errors.New("not by dict order")
	// errEncoded is returned when a field name has no entry in the field dictionary.
	errEncoded     = errors.New("field name not encoded")
)
    30  
// Tuning constants for array allocation and slot search.
const (
	defaultArraySize         = 64   // initial length of the base/check/used arrays allocated by Build
	minExpansionRate         = 1.05 // lower bound of the growth factor applied in loopForBegin, based on experience
	nextCheckPosStrategyRate = 0.95 // occupancy ratio at which loopForBegin advances nextCheckPos, based on experience
)
    36  
// DoubleArrayTrie is a double array trie.
// It's based on https://github.com/komiya-atsushi/darts-java.
// State Transition Equation:
//
//	base[0] = 1
//	base[s] + c = t
//	check[t] = base[s]
//
// A slot whose check value is 0 is treated as free while building. A terminal
// (end-of-path) slot stores a negative base value, written by insert.
type DoubleArrayTrie struct {
	base         []int      // base array: begin value of a slot's children, or a negative value at a terminal slot
	check        []int      // check array: begin value that owns the slot, 0 while the slot is free
	used         []bool     // used array: used[begin] is set once insert has taken that begin value
	size         int        // one past the highest slot assigned by insert; Build shrinks the arrays to it
	allocSize    int        // current length of base/check/used arrays, maintained by resize
	fps          fieldPaths // field paths the trie is built from, sorted by Build
	dict         fieldDict  // field name -> dictionary code
	progress     int        // number of terminal slots written so far by insert
	nextCheckPos int        // position loopForBegin resumes scanning from, to avoid starting over from 0
}
    55  
// node is node of DAT used during construction.
// It describes the run of field paths fps[left:right] that share the same
// prefix of length depth.
type node struct {
	code  int // code = dictCodeOfFieldName + 1, dictCodeOfFieldName: [0, 1, 2, ..., n-1]; 0 marks an end-of-path node
	depth int // depth of node; the root has depth 0
	left  int // left boundary (inclusive index into fps)
	right int // right boundary (exclusive index into fps)
}
    63  
    64  // Build performs static construction of a DAT.
    65  func Build(fps [][]string) (*DoubleArrayTrie, error) {
    66  	// sort
    67  	sort.Sort(fieldPaths(fps))
    68  
    69  	// init dat
    70  	dat := &DoubleArrayTrie{
    71  		fps:  fps,
    72  		dict: newFieldDict(fps),
    73  	}
    74  	dat.resize(defaultArraySize)
    75  	dat.base[0] = 1
    76  
    77  	// root node handling
    78  	root := &node{
    79  		right: len(dat.fps),
    80  	}
    81  	children, err := dat.fetch(root)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	if _, err := dat.insert(children); err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	// shrink
    90  	dat.resize(dat.size)
    91  
    92  	return dat, nil
    93  }
    94  
    95  // CommonPrefixSearch check if input fieldPath has common prefix with fps in DAT.
    96  func (dat *DoubleArrayTrie) CommonPrefixSearch(fieldPath []string) bool {
    97  	var pos int
    98  	baseValue := dat.base[0]
    99  
   100  	for _, name := range fieldPath {
   101  		// get dict code
   102  		v, ok := dat.dict[name]
   103  		if !ok {
   104  			break
   105  		}
   106  		code := v + 1 // code = dictCodeOfFieldName + 1
   107  
   108  		// check if leaf node has been reached, that is, check if next node is NULL according to
   109  		// the State Transition Equation.
   110  		if baseValue == dat.check[baseValue] && dat.base[baseValue] < 0 {
   111  			// has reached leaf node,it's the common prefix.
   112  			return true
   113  		}
   114  
   115  		// state transition
   116  		pos = baseValue + code
   117  		if pos >= len(dat.check) || baseValue != dat.check[pos] { // mismatch
   118  			return false
   119  		}
   120  		baseValue = dat.base[pos]
   121  	}
   122  
   123  	// check again if leaf node has been reached for last state transition
   124  	if baseValue == dat.check[baseValue] && dat.base[baseValue] < 0 {
   125  		// has reached leaf node,it's the common prefix.
   126  		return true
   127  	}
   128  
   129  	return false
   130  }
   131  
   132  // fetch returns children nodes given parent node.
   133  // If the fps in DAT is like:
   134  //
   135  //	["foobar", "foo", "bar"]
   136  //	["foobar", "baz"]
   137  //	["foo", "qux"]
   138  //
   139  // children, _ := dat.fetch(root),children should be ["foobar", "foo"],
   140  // and their depths should all be 1.
   141  func (dat *DoubleArrayTrie) fetch(parent *node) ([]*node, error) {
   142  	var (
   143  		children []*node // children nodes would be returned
   144  		prev     int     // code of prev child node
   145  	)
   146  
   147  	// search range [parent.left, parent.right)
   148  	// for root node,search range [0, len(dat.fps))
   149  	for i := parent.left; i < parent.right; i++ {
   150  		if len(dat.fps[i]) < parent.depth { // all fp of fps[i] have been fetched
   151  			continue
   152  		}
   153  
   154  		var curr int // code of curr child node
   155  		if len(dat.fps[i]) > parent.depth {
   156  			v, ok := dat.dict[dat.fps[i][parent.depth]]
   157  			if !ok { // not encoded
   158  				return nil, errEncoded
   159  			}
   160  			curr = v + 1 // code = dictCodeOfFieldName + 1
   161  		}
   162  
   163  		// not by dict order
   164  		if prev > curr {
   165  			return nil, errByDictOrder
   166  		}
   167  
   168  		// Normally, if curr == prev, skip this.
   169  		// But curr == prev && len(children) == 0 makes an exception,
   170  		// it means fetching fp from fps[i] comes to an end and an empty node should be added
   171  		// like an EOF.
   172  		if curr != prev || len(children) == 0 {
   173  			// update right boundary of prev child node
   174  			if len(children) != 0 {
   175  				children[len(children)-1].right = i
   176  			}
   177  			// curr child node
   178  			// no need to update right boundary,
   179  			// let next child node update this node's right boundary
   180  			children = append(children, &node{
   181  				code:  curr,
   182  				depth: parent.depth + 1, // depth +1
   183  				left:  i,
   184  			})
   185  		}
   186  
   187  		prev = curr
   188  	}
   189  
   190  	// update right boundary of the last child node
   191  	if len(children) > 0 {
   192  		children[len(children)-1].right = parent.right // same right boundary as parent node
   193  	}
   194  
   195  	return children, nil
   196  }
   197  
   198  // max returns the bigger int value.
   199  func max(x, y int) int {
   200  	if x > y {
   201  		return x
   202  	}
   203  	return y
   204  }
   205  
   206  // loopForBegin loops for begin value that meets the condition.
   207  func (dat *DoubleArrayTrie) loopForBegin(children []*node) (int, error) {
   208  	var (
   209  		begin        int                                         // begin to loop for
   210  		numOfNonZero int                                         // number of non zero
   211  		pos          = max(children[0].code, dat.nextCheckPos-1) // prevent start over from 0 to loop for begin value
   212  	)
   213  
   214  	for first := true; ; { // whether first time to meet a non zero
   215  		pos++
   216  		if dat.allocSize <= pos { // expand
   217  			dat.resize(pos + 1)
   218  		}
   219  		if dat.check[pos] != 0 { // occupied
   220  			numOfNonZero++
   221  			continue
   222  		} else {
   223  			if first {
   224  				dat.nextCheckPos = pos
   225  				first = false
   226  			}
   227  		}
   228  
   229  		// try this begin value
   230  		begin = pos - children[0].code
   231  
   232  		// compare with lastChildPos to check if expansion is needed
   233  		if lastChildPos := begin + children[len(children)-1].code; dat.allocSize <= lastChildPos {
   234  			// rate = {total number of fieldPaths} / ({number of processed fieldPaths} + 1), but not less than 1.05
   235  			rate := math.Max(minExpansionRate, float64(1.0*len(dat.fps)/(dat.progress+1)))
   236  			dat.resize(int(float64(dat.allocSize) * rate))
   237  		}
   238  
   239  		if dat.used[begin] { // check dup
   240  			continue
   241  		}
   242  
   243  		// check if remaining children nodes could be inserted
   244  		conflict := func() bool {
   245  			for i := 1; i < len(children); i++ {
   246  				if dat.check[begin+children[i].code] != 0 {
   247  					return true
   248  				}
   249  			}
   250  			return false
   251  		}
   252  		// if conflicting, next pos
   253  		if conflict() {
   254  			continue
   255  		}
   256  		// no conflicting, found the begin value
   257  		break
   258  	}
   259  
   260  	// if nodes from nextCheckPos to pos are all occupied, set nextCheckPos to pos
   261  	if float64((1.0*numOfNonZero)/(pos-dat.nextCheckPos+1)) >= nextCheckPosStrategyRate {
   262  		dat.nextCheckPos = pos
   263  	}
   264  
   265  	return begin, nil
   266  }
   267  
   268  // insert inserts children nodes into DAT, returns begin value that is looking for.
   269  func (dat *DoubleArrayTrie) insert(children []*node) (int, error) {
   270  	// loop for begin value
   271  	begin, err := dat.loopForBegin(children)
   272  	if err != nil {
   273  		return 0, err
   274  	}
   275  
   276  	dat.used[begin] = true
   277  	dat.size = max(dat.size, begin+children[len(children)-1].code+1)
   278  
   279  	// check arrays assignment
   280  	for i := range children {
   281  		dat.check[begin+children[i].code] = begin
   282  	}
   283  
   284  	// dfs
   285  	for _, child := range children {
   286  		grandchildren, err := dat.fetch(child)
   287  		if err != nil {
   288  			return 0, err
   289  		}
   290  		if len(grandchildren) == 0 { // no children nodes
   291  			dat.base[begin+child.code] = -child.left - 1
   292  			dat.progress++
   293  			continue
   294  		}
   295  		t, err := dat.insert(grandchildren)
   296  		if err != nil {
   297  			return 0, err
   298  		}
   299  		// base arrays assignment
   300  		dat.base[begin+child.code] = t
   301  	}
   302  
   303  	return begin, nil
   304  }
   305  
   306  // resize changes the size of the arrays.
   307  func (dat *DoubleArrayTrie) resize(newSize int) {
   308  	newBase := make([]int, newSize, newSize)
   309  	newCheck := make([]int, newSize, newSize)
   310  	newUsed := make([]bool, newSize, newSize)
   311  
   312  	if dat.allocSize > 0 {
   313  		copy(newBase, dat.base)
   314  		copy(newCheck, dat.check)
   315  		copy(newUsed, dat.used)
   316  	}
   317  
   318  	dat.base = newBase
   319  	dat.check = newCheck
   320  	dat.used = newUsed
   321  
   322  	dat.allocSize = newSize
   323  }
   324  
   325  type fieldPaths [][]string
   326  
   327  // Len implements sort.Interface
   328  func (fps fieldPaths) Len() int { return len(fps) }
   329  
   330  // Swap implements sort.Interface
   331  func (fps fieldPaths) Swap(i, j int) { fps[i], fps[j] = fps[j], fps[i] }
   332  
   333  // Less implements sort.Interface
   334  func (fps fieldPaths) Less(i, j int) bool {
   335  	var k int
   336  	for k = 0; k < len(fps[i]) && k < len(fps[j]); k++ {
   337  		if fps[i][k] < fps[j][k] {
   338  			return true
   339  		}
   340  		if fps[i][k] > fps[j][k] {
   341  			return false
   342  		}
   343  	}
   344  	return k < len(fps[j])
   345  }
   346  
   347  type fieldDict map[string]int // FieldName -> DictCodeOfFieldName
   348  
   349  func newFieldDict(fps fieldPaths) fieldDict {
   350  	dict := make(map[string]int)
   351  	// rm dup
   352  	for _, fieldPath := range fps {
   353  		for _, name := range fieldPath {
   354  			dict[name] = 0
   355  		}
   356  	}
   357  
   358  	// sort
   359  	fields := make([]string, 0, len(dict))
   360  	for name := range dict {
   361  		fields = append(fields, name)
   362  	}
   363  	sort.Sort(sort.StringSlice(fields))
   364  
   365  	// dict assignment
   366  
   367  	for code, name := range fields {
   368  		dict[name] = code
   369  	}
   370  	return dict
   371  }