github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/nlp/lda.go (about) 1 package nlp 2 3 import ( 4 "fmt" 5 "math/rand" 6 "sort" 7 "time" 8 ) 9 10 // LdaVerbose determines whether progress information should be printed during 11 // LDA. For debugging. 12 var LdaVerbose = false 13 14 // ----- INTERFACE FUNCTIONS --------------------------------------------------- 15 16 // Lda performs LDA on the given data. docTokens should contain tokenized 17 // documents, such that docTokens[i][j] is the j'th token in the i'th document. 18 // k is the number of topics. Returns the topics and token-topic assignment, 19 // respective to docTokens. 20 // 21 // Topics are returned in a map from word to a probability vector, such that 22 // the i'th position is the probability of the i'th topic generating that word. 23 // For each i, the i'th position of all words sum to 1. 24 func Lda(docTokens [][]string, k int) (map[string][]float64, [][]int) { 25 return LdaThreads(docTokens, k, 1) 26 } 27 28 // LdaThreads is like the function Lda but runs on multiple subroutines. 29 // Calling this function with 1 thread is equivalent to calling Lda. 30 func LdaThreads(docTokens [][]string, k, numThreads int) (map[string][]float64, 31 [][]int) { 32 // Check input. 33 if k < 1 { 34 panic(fmt.Sprintf("k must be positive. Got %d.", k)) 35 } 36 if numThreads < 1 { 37 panic(fmt.Sprintf("Number of threads must be positive. Got %d.", 38 numThreads)) 39 } 40 41 // Create word map. 42 words := map[string]int{} 43 for _, doc := range docTokens { 44 for _, word := range doc { 45 if _, ok := words[word]; !ok { 46 words[word] = len(words) 47 } 48 } 49 } 50 if len(words) == 0 { 51 panic("Found 0 words in documents.") 52 } 53 if LdaVerbose { 54 fmt.Println("LDA:", len(words), "words in dictionary") 55 } 56 57 // Convert tokens to indexes. 
58 docs := make([][]int, len(docTokens)) 59 for i := range docs { 60 docs[i] = make([]int, len(docTokens[i])) 61 for j := range docs[i] { 62 docs[i][j] = words[docTokens[i][j]] 63 } 64 } 65 66 topics := newDists(k, len(words), 0.1/float64(len(words))) 67 68 // Initial topic assignment. 69 doct := make([][]int, len(docs)) 70 for i := range docs { 71 doct[i] = make([]int, len(docs[i])) 72 for j := range doct[i] { 73 t := rand.Intn(k) 74 doct[i][j] = t 75 topics[t].add(docs[i][j]) 76 } 77 } 78 79 lastChange := 0 // How many words changed their topic in the last iteration. 80 for _, t := range doct { 81 lastChange += len(t) 82 } 83 breakSignals := 0 84 85 // Fun part! 86 for { 87 newTopics := newDists(k, len(words), 0.1/float64(len(words))) 88 89 // Big buffers for speed. 90 push := make(chan int, numThreads*1000) 91 pull := make(chan int, numThreads*1000) 92 change := make(chan int, numThreads) 93 done := make(chan int, numThreads) 94 95 // Pusher thread - pushes documnet index to threads. 96 go func() { 97 for i := range docs { 98 push <- i 99 } 100 close(push) 101 }() 102 103 // Puller thread - updates new topics with done documents. 104 go func() { 105 count := 0 106 progress := -1 107 for i := range pull { 108 // Print progress if verbose. 109 if LdaVerbose { 110 count++ 111 newProgress := count * 100 / len(doct) 112 if newProgress > progress { 113 progress = newProgress 114 fmt.Printf("\rLDA: [%d%%]", progress) 115 } 116 } 117 118 // Update document. 119 for j := range doct[i] { 120 newTopics[doct[i][j]].add(docs[i][j]) 121 } 122 } 123 124 if LdaVerbose { 125 fmt.Println() 126 } 127 done <- 0 128 }() 129 130 // changeCount thread - counts how many word changed their topic. 131 changeCount := 0 132 go func() { 133 for count := range change { 134 changeCount += count 135 } 136 done <- 0 137 }() 138 139 // Worker threads. 140 for thread := 0; thread < numThreads; thread++ { 141 go func() { 142 // Make a local copy of topics. 
143 myTopics := copyDists(topics) 144 myChangeCount := 0 145 myRand := newRand() // Thread-local random to prevent waiting on rand's default source. 146 ts := make([]float64, k) // Reusable slice for randomly picking topics. 147 148 // For each document. 149 for i := range push { 150 // Create distribution of profiles. 151 d := newDist(k, 0.1/float64(k)) 152 for j := range doct[i] { 153 d.add(doct[i][j]) 154 } 155 156 // Reassign each word. 157 for j := range doct[i] { 158 t := doct[i][j] 159 word := docs[i][j] 160 161 // Unassign. 162 d.sub(t) 163 myTopics[t].sub(word) 164 165 // Pick new topic. 166 for k := range ts { 167 ts[k] = d.p(k) * myTopics[k].p(word) 168 } 169 t2 := pickRandom(ts, myRand) 170 if t2 != doct[i][j] { 171 myChangeCount++ 172 } 173 174 // Assign. 175 doct[i][j] = t2 176 d.add(t2) 177 myTopics[t2].add(word) 178 } 179 180 // Report this doc is done. 181 pull <- i 182 } 183 184 change <- myChangeCount 185 done <- 0 186 }() 187 } 188 189 // Wait for threads. 190 for i := 0; i < numThreads; i++ { 191 <-done 192 } 193 close(pull) 194 close(change) 195 <-done 196 <-done 197 198 // Update topics. 199 topics = newTopics 200 201 // Check halting condition. 202 if changeCount >= lastChange { 203 breakSignals++ 204 if breakSignals == 5 { 205 break 206 } 207 } 208 209 if LdaVerbose { 210 fmt.Printf("LDA: Changes: %d (%d) %.3f\n", changeCount, breakSignals, 211 float64(changeCount)/float64(lastChange)) 212 } 213 lastChange = changeCount 214 } 215 216 // Make return values. 217 topicDists := make([][]float64, len(topics)) 218 for i := range topicDists { 219 topicDists[i] = topics[i].dist() 220 } 221 222 dict := map[string][]float64{} 223 for word, i := range words { 224 d := make([]float64, k) 225 for j := range d { 226 d[j] = topicDists[j][i] 227 } 228 dict[word] = d 229 } 230 231 return dict, doct 232 } 233 234 // ----- HELPERS --------------------------------------------------------------- 235 236 // dist is a distribution on elements by counts. 
type dist struct {
	sum    float64   // Sum of all counts.
	count  []float64 // count[i] is the current count of element i.
	alpha  float64   // Per-element smoothing weight, scaled by sum in p.
	alphas float64   // alpha times the number of elements.
}

// newDist creates an empty distribution over n elements with smoothing weight
// alpha.
func newDist(n int, alpha float64) *dist {
	return &dist{
		sum:    0,
		count:  make([]float64, n),
		alpha:  alpha,
		alphas: alpha * float64(n),
	}
}

// newDists creates k independent empty distributions, each over n elements
// with smoothing weight alpha.
func newDists(k, n int, alpha float64) []*dist {
	result := make([]*dist, k)
	for i := range result {
		result[i] = newDist(n, alpha)
	}
	return result
}

// p returns the smoothed probability of element i:
//
//	(count[i] + alpha*sum) / (sum + alphas*sum)
//
// The pseudo-count grows with sum, so the smoothing is relative to the total
// rather than a fixed additive constant. The probabilities of all elements sum
// to 1. An empty distribution (sum == 0) returns 0 for every element.
func (d *dist) p(i int) float64 {
	if d.sum == 0 {
		return 0
	}
	return (d.count[i] + d.alpha*d.sum) / (d.sum + d.alphas*d.sum)
}

// add increments the count of i by 1.
func (d *dist) add(i int) {
	d.count[i]++
	d.sum++
}

// sub decrements the count of i by 1. Panics if the count would become
// negative; in that case the distribution is left unchanged.
func (d *dist) sub(i int) {
	// Checked before mutating so that a recovered panic does not leave count
	// and sum out of sync.
	if d.count[i] < 1 {
		panic(fmt.Sprintf("Reached negative count for i=%d.", i))
	}
	d.count[i]--
	d.sum--
}

// dist returns a copy of the counts normalized by their sum (without
// smoothing). If the sum is 0, the counts are returned as they are.
func (d *dist) dist() []float64 {
	result := make([]float64, len(d.count))
	copy(result, d.count)
	if d.sum != 0 {
		for i := range result {
			result[i] /= d.sum
		}
	}
	return result
}

// copy deep-copies a distribution.
func (d *dist) copy() *dist {
	count := make([]float64, len(d.count))
	copy(count, d.count)
	return &dist{sum: d.sum, count: count, alpha: d.alpha, alphas: d.alphas}
}

// copyDists deep-copies a slice of distributions.
func copyDists(dists []*dist) []*dist {
	result := make([]*dist, len(dists))
	for i, d := range dists {
		result[i] = d.copy()
	}
	return result
}

// top returns the indexes of the n elements with the highest counts, highest
// first. n is capped at the number of elements.
313 func (d *dist) top(n int) []int { 314 s := newDistSorter(d) 315 sort.Sort(s) 316 if n > len(s.perm) { 317 n = len(s.perm) 318 } 319 return s.perm[:n] 320 } 321 322 // distSorter is a distribution sorting interface. 323 type distSorter struct { 324 *dist 325 perm []int 326 } 327 328 func newDistSorter(d *dist) *distSorter { 329 s := &distSorter{d, make([]int, len(d.count))} 330 for i := range s.perm { 331 s.perm[i] = i 332 } 333 return s 334 } 335 336 func (d *distSorter) Len() int { 337 return len(d.perm) 338 } 339 340 func (d *distSorter) Less(i, j int) bool { 341 return d.count[d.perm[i]] > d.count[d.perm[j]] 342 } 343 344 func (d *distSorter) Swap(i, j int) { 345 d.perm[i], d.perm[j] = d.perm[j], d.perm[i] 346 } 347 348 // newRand creates a new random generator. 349 func newRand() *rand.Rand { 350 return rand.New(rand.NewSource(time.Now().UnixNano())) 351 } 352 353 // pickRandom picks a random index from a, with a probability proportional to 354 // its value. Using a local random-generator to prevent waiting on rand's 355 // default source. 356 func pickRandom(a []float64, rnd *rand.Rand) int { 357 if len(a) == 0 { 358 panic("Cannot pick element from an empty distribution.") 359 } 360 361 sum := float64(0) 362 for i := range a { 363 if a[i] < 0 { 364 panic(fmt.Sprintf("Got negative value in distribution: %v", a[i])) 365 } 366 sum += a[i] 367 } 368 if sum == 0 { 369 return rnd.Intn(len(a)) 370 } 371 372 r := rnd.Float64() * sum 373 i := 0 374 for i < len(a) && r > a[i] { 375 r -= a[i] 376 i++ 377 } 378 if i == len(a) { 379 i-- 380 } 381 return i 382 }