github.com/lazin/go-ngram@v0.0.0-20160527144230-80eaf16ac4eb/ngram.go (about)

     1  package ngram
     2  
     3  import (
     4  	"errors"
     5  	"math"
     6  	"sync"
     7  
     8  	"github.com/spaolacci/murmur3"
     9  )
    10  
    11  const (
    12  	maxN       = 8
    13  	defaultPad = "$"
    14  	defaultN   = 3
    15  )
    16  
// TokenID identifies a string stored in the index: Add returns it and
// GetString converts it back to the string.
type TokenID int

// nGramValue maps every token that produced a given n-gram to the number of
// times that n-gram was produced for the token.
type nGramValue map[TokenID]int
    21  
// NGramIndex is an n-gram index over strings. It can be used from its zero
// value (Add, Search and BestMatch initialize it lazily) or created with
// NewNGramIndex, which accepts the SetPad, SetN and SetWarp options.
type NGramIndex struct {
	pad   string                // padding added on both sides of a string before splitting
	n     int                   // n-gram size, counted in runes
	spool stringPool            // storage for the indexed strings; Add returns its ids
	index map[uint32]nGramValue // murmur3 hash of an n-gram -> per-token occurrence counts
	warp  float64               // exponent match uses to shape similarity scores

	// The embedded lock guards index; init also holds it while applying defaults.
	sync.RWMutex
}
    30  	sync.RWMutex
    31  }
    32  
    33  // SearchResult contains token id and similarity - value in range from 0.0 to 1.0
    34  type SearchResult struct {
    35  	TokenID    TokenID
    36  	Similarity float64
    37  }
    38  
    39  func (ngram *NGramIndex) splitInput(str string) ([]uint32, error) {
    40  	if len(str) == 0 {
    41  		return nil, errors.New("empty string")
    42  	}
    43  	pad := ngram.pad
    44  	n := ngram.n
    45  	input := pad + str + pad
    46  	prevIndexes := make([]int, maxN)
    47  	var counter int
    48  	results := make([]uint32, 0)
    49  
    50  	for index := range input {
    51  		counter++
    52  		if counter > n {
    53  			top := prevIndexes[(counter-n)%maxN]
    54  			substr := input[top:index]
    55  			hash := murmur3.Sum32([]byte(substr))
    56  			results = append(results, hash)
    57  		}
    58  		prevIndexes[counter%maxN] = index
    59  	}
    60  
    61  	for i := n - 1; i > 1; i-- {
    62  		if len(input) >= i {
    63  			top := prevIndexes[(len(input)-i)%maxN]
    64  			substr := input[top:]
    65  			hash := murmur3.Sum32([]byte(substr))
    66  			results = append(results, hash)
    67  		}
    68  	}
    69  
    70  	return results, nil
    71  }
    72  
    73  func (ngram *NGramIndex) init() {
    74  	ngram.Lock()
    75  	defer ngram.Unlock()
    76  
    77  	ngram.index = make(map[uint32]nGramValue)
    78  	if ngram.pad == "" {
    79  		ngram.pad = defaultPad
    80  	}
    81  	if ngram.n == 0 {
    82  		ngram.n = defaultN
    83  	}
    84  	if ngram.warp == 0.0 {
    85  		ngram.warp = 1.0
    86  	}
    87  }
    88  
    89  type Option func(*NGramIndex) error
    90  
    91  // SetPad must be used to pass padding character to NGramIndex c-tor
    92  func SetPad(c rune) Option {
    93  	return func(ngram *NGramIndex) error {
    94  		ngram.pad = string(c)
    95  		return nil
    96  	}
    97  }
    98  
    99  // SetN must be used to pass N (gram size) to NGramIndex c-tor
   100  func SetN(n int) Option {
   101  	return func(ngram *NGramIndex) error {
   102  		if n < 2 || n > maxN {
   103  			return errors.New("bad 'n' value for n-gram index")
   104  		}
   105  		ngram.n = n
   106  		return nil
   107  	}
   108  }
   109  
   110  // SetWarp must be used to pass warp to NGramIndex c-tor
   111  func SetWarp(warp float64) Option {
   112  	return func(ngram *NGramIndex) error {
   113  		if warp < 0.0 || warp > 1.0 {
   114  			return errors.New("bad 'warp' value for n-gram index")
   115  		}
   116  		ngram.warp = warp
   117  		return nil
   118  	}
   119  }
   120  
   121  // NewNGramIndex is N-gram index c-tor. In most cases must be used withot parameters.
   122  // You can pass parameters to c-tor using functions SetPad, SetWarp and SetN.
   123  func NewNGramIndex(opts ...Option) (*NGramIndex, error) {
   124  	ngram := new(NGramIndex)
   125  	for _, opt := range opts {
   126  		if err := opt(ngram); err != nil {
   127  			return nil, err
   128  		}
   129  	}
   130  	ngram.init()
   131  	return ngram, nil
   132  }
   133  
   134  // Add token to index. Function returns token id, this id can be converted
   135  // to string with function "GetString".
   136  func (ngram *NGramIndex) Add(input string) (TokenID, error) {
   137  	if ngram.index == nil {
   138  		ngram.init()
   139  	}
   140  	results, error := ngram.splitInput(input)
   141  	if error != nil {
   142  		return -1, error
   143  	}
   144  	ixstr, error := ngram.spool.Append(input)
   145  	if error != nil {
   146  		return -1, error
   147  	}
   148  	for _, hash := range results {
   149  		ngram.Lock()
   150  		if ngram.index[hash] == nil {
   151  			ngram.index[hash] = make(map[TokenID]int)
   152  		}
   153  		// insert string and counter
   154  		ngram.index[hash][ixstr]++
   155  		ngram.Unlock()
   156  	}
   157  	return ixstr, nil
   158  }
   159  
   160  // GetString converts token-id to string.
   161  func (ngram *NGramIndex) GetString(id TokenID) (string, error) {
   162  	return ngram.spool.ReadAt(id)
   163  }
   164  
   165  // countNgrams maps matched tokens to the number of ngrams, shared with input string
   166  func (ngram *NGramIndex) countNgrams(inputNgrams []uint32) map[TokenID]int {
   167  	counters := make(map[TokenID]int)
   168  	for _, ngramHash := range inputNgrams {
   169  		ngram.RLock()
   170  		for tok := range ngram.index[ngramHash] {
   171  			counters[tok]++
   172  		}
   173  		ngram.RUnlock()
   174  	}
   175  	return counters
   176  }
   177  
   178  func validateThresholdValues(thresholds []float64) (float64, error) {
   179  	var tval float64
   180  	if len(thresholds) == 1 {
   181  		tval = thresholds[0]
   182  		if tval < 0.0 || tval > 1.0 {
   183  			return 0.0, errors.New("threshold must be in range (0, 1)")
   184  		}
   185  	} else if len(thresholds) > 1 {
   186  		return 0.0, errors.New("too many arguments")
   187  	}
   188  	return tval, nil
   189  }
   190  
   191  func (ngram *NGramIndex) match(input string, tval float64) ([]SearchResult, error) {
   192  	inputNgrams, error := ngram.splitInput(input)
   193  	if error != nil {
   194  		return nil, error
   195  	}
   196  	output := make([]SearchResult, 0)
   197  	tokenCount := ngram.countNgrams(inputNgrams)
   198  	for token, count := range tokenCount {
   199  		var sim float64
   200  		allngrams := float64(len(inputNgrams))
   201  		matchngrams := float64(count)
   202  		if ngram.warp == 1.0 {
   203  			sim = matchngrams / allngrams
   204  		} else {
   205  			diffngrams := allngrams - matchngrams
   206  			sim = math.Pow(allngrams, ngram.warp) - math.Pow(diffngrams, ngram.warp)
   207  			sim /= math.Pow(allngrams, ngram.warp)
   208  		}
   209  		if sim >= tval {
   210  			res := SearchResult{Similarity: sim, TokenID: token}
   211  			output = append(output, res)
   212  		}
   213  	}
   214  	return output, nil
   215  }
   216  
   217  // Search for matches between query string (input) and indexed strings.
   218  // First parameter - threshold is optional and can be used to set minimal similarity
   219  // between input string and matching string. You can pass only one threshold value.
   220  // Results is an unordered array of 'SearchResult' structs. This struct contains similarity
   221  // value (float32 value from threshold to 1.0) and token-id.
   222  func (ngram *NGramIndex) Search(input string, threshold ...float64) ([]SearchResult, error) {
   223  	if ngram.index == nil {
   224  		ngram.init()
   225  	}
   226  	tval, error := validateThresholdValues(threshold)
   227  	if error != nil {
   228  		return nil, error
   229  	}
   230  	return ngram.match(input, tval)
   231  }
   232  
   233  // BestMatch is the same as Search except that it's returning only one best result instead of all.
   234  func (ngram *NGramIndex) BestMatch(input string, threshold ...float64) (*SearchResult, error) {
   235  	if ngram.index == nil {
   236  		ngram.init()
   237  	}
   238  	tval, error := validateThresholdValues(threshold)
   239  	if error != nil {
   240  		return nil, error
   241  	}
   242  	variants, error := ngram.match(input, tval)
   243  	if error != nil {
   244  		return nil, error
   245  	}
   246  	if len(variants) == 0 {
   247  		return nil, errors.New("no matches found")
   248  	}
   249  	var result SearchResult
   250  	maxsim := -1.0
   251  	for _, val := range variants {
   252  		if val.Similarity > maxsim {
   253  			maxsim = val.Similarity
   254  			result = val
   255  		}
   256  	}
   257  	return &result, nil
   258  }