github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/nlp/lda.go (about)

     1  package nlp
     2  
     3  import (
     4  	"fmt"
     5  	"math/rand"
     6  	"sort"
     7  	"time"
     8  )
     9  
// LdaVerbose determines whether progress information is printed to standard
// output while LDA runs (dictionary size, per-iteration progress percentage
// and per-iteration change counts). Intended for debugging.
var LdaVerbose = false
    13  
    14  // ----- INTERFACE FUNCTIONS ---------------------------------------------------
    15  
    16  // Lda performs LDA on the given data. docTokens should contain tokenized
    17  // documents, such that docTokens[i][j] is the j'th token in the i'th document.
    18  // k is the number of topics. Returns the topics and token-topic assignment,
    19  // respective to docTokens.
    20  //
    21  // Topics are returned in a map from word to a probability vector, such that
    22  // the i'th position is the probability of the i'th topic generating that word.
    23  // For each i, the i'th position of all words sum to 1.
    24  func Lda(docTokens [][]string, k int) (map[string][]float64, [][]int) {
    25  	return LdaThreads(docTokens, k, 1)
    26  }
    27  
// LdaThreads is like the function Lda but runs on multiple goroutines.
// Calling this function with 1 thread is equivalent to calling Lda.
//
// docTokens[i][j] is the j'th token of the i'th document, k is the number of
// topics and numThreads is the number of worker goroutines that resample
// documents in each iteration.
//
// It returns a map from word to a length-k vector holding, for each topic, the
// word's count in that topic divided by the topic's total count, together with
// the final topic assigned to each token (same shape as docTokens).
//
// Panics if k < 1, if numThreads < 1, or if the documents contain no tokens.
func LdaThreads(docTokens [][]string, k, numThreads int) (map[string][]float64,
	[][]int) {
	// Check input.
	if k < 1 {
		panic(fmt.Sprintf("k must be positive. Got %d.", k))
	}
	if numThreads < 1 {
		panic(fmt.Sprintf("Number of threads must be positive. Got %d.",
			numThreads))
	}

	// Create word map. Each distinct word gets the next free index, in order
	// of first appearance.
	words := map[string]int{}
	for _, doc := range docTokens {
		for _, word := range doc {
			if _, ok := words[word]; !ok {
				words[word] = len(words)
			}
		}
	}
	if len(words) == 0 {
		panic("Found 0 words in documents.")
	}
	if LdaVerbose {
		fmt.Println("LDA:", len(words), "words in dictionary")
	}

	// Convert tokens to indexes.
	docs := make([][]int, len(docTokens))
	for i := range docs {
		docs[i] = make([]int, len(docTokens[i]))
		for j := range docs[i] {
			docs[i][j] = words[docTokens[i][j]]
		}
	}

	// One word distribution per topic.
	topics := newDists(k, len(words), 0.1/float64(len(words)))

	// Initial topic assignment: every token gets a uniformly random topic.
	// doct[i][j] is the topic currently assigned to token j of document i.
	doct := make([][]int, len(docs))
	for i := range docs {
		doct[i] = make([]int, len(docs[i]))
		for j := range doct[i] {
			t := rand.Intn(k)
			doct[i][j] = t
			topics[t].add(docs[i][j])
		}
	}

	// Starts at the total number of tokens, as if every token had changed.
	lastChange := 0 // How many words changed their topic in the last iteration.
	for _, t := range doct {
		lastChange += len(t)
	}
	// Number of iterations so far (cumulative, never reset) in which the
	// change count did not drop below the previous iteration's.
	breakSignals := 0

	// Fun part!
	for {
		// Word counts per topic for the assignment produced by this iteration.
		newTopics := newDists(k, len(words), 0.1/float64(len(words)))

		// Big buffers for speed.
		push := make(chan int, numThreads*1000)
		pull := make(chan int, numThreads*1000)
		// change is buffered so that each worker can report without blocking.
		change := make(chan int, numThreads)
		// done receives one signal per worker, plus one from the puller and
		// one from the change counter.
		done := make(chan int, numThreads)

		// Pusher thread - pushes document index to threads.
		go func() {
			for i := range docs {
				push <- i
			}
			close(push)
		}()

		// Puller thread - updates new topics with done documents.
		go func() {
			count := 0
			progress := -1
			for i := range pull {
				// Print progress if verbose.
				if LdaVerbose {
					count++
					newProgress := count * 100 / len(doct)
					if newProgress > progress {
						progress = newProgress
						fmt.Printf("\rLDA: [%d%%]", progress)
					}
				}

				// Update document. The worker has finished with document i by
				// the time its index arrives here.
				for j := range doct[i] {
					newTopics[doct[i][j]].add(docs[i][j])
				}
			}

			if LdaVerbose {
				fmt.Println()
			}
			done <- 0
		}()

		// changeCount thread - counts how many words changed their topic.
		// changeCount is read by the main goroutine only after this goroutine
		// has signalled done.
		changeCount := 0
		go func() {
			for count := range change {
				changeCount += count
			}
			done <- 0
		}()

		// Worker threads.
		for thread := 0; thread < numThreads; thread++ {
			go func() {
				// Make a local copy of topics. Each worker updates only its own
				// copy while resampling; the shared counts are rebuilt in
				// newTopics by the puller.
				myTopics := copyDists(topics)
				myChangeCount := 0
				myRand := newRand()      // Thread-local random to prevent waiting on rand's default source.
				ts := make([]float64, k) // Reusable slice for randomly picking topics.

				// For each document. Every index is received by exactly one
				// worker, which is the only writer of doct[i] in this iteration.
				for i := range push {
					// Create distribution of topics within this document.
					d := newDist(k, 0.1/float64(k))
					for j := range doct[i] {
						d.add(doct[i][j])
					}

					// Reassign each word.
					for j := range doct[i] {
						t := doct[i][j]
						word := docs[i][j]

						// Unassign.
						d.sub(t)
						myTopics[t].sub(word)

						// Pick new topic. The weight of a topic is its
						// probability in this document times the word's
						// probability in that topic.
						// NOTE: the loop variable k shadows the topic count k;
						// here it is a topic index.
						for k := range ts {
							ts[k] = d.p(k) * myTopics[k].p(word)
						}
						t2 := pickRandom(ts, myRand)
						if t2 != doct[i][j] {
							myChangeCount++
						}

						// Assign.
						doct[i][j] = t2
						d.add(t2)
						myTopics[t2].add(word)
					}

					// Report this doc is done.
					pull <- i
				}

				change <- myChangeCount
				done <- 0
			}()
		}

		// Wait for threads. The first numThreads signals can only come from
		// workers, because the puller and the change counter signal only after
		// their input channels are closed below.
		for i := 0; i < numThreads; i++ {
			<-done
		}
		close(pull)
		close(change)
		// Wait for the puller and the change counter.
		<-done
		<-done

		// Update topics.
		topics = newTopics

		// Check halting condition: stop on the fifth iteration (counted
		// cumulatively) in which the number of changes did not decrease.
		if changeCount >= lastChange {
			breakSignals++
			if breakSignals == 5 {
				break
			}
		}

		if LdaVerbose {
			fmt.Printf("LDA: Changes: %d (%d) %.3f\n", changeCount, breakSignals,
				float64(changeCount)/float64(lastChange))
		}
		lastChange = changeCount
	}

	// Make return values.
	// topicDists[t][w] is the count of word w in topic t divided by the
	// topic's total count.
	topicDists := make([][]float64, len(topics))
	for i := range topicDists {
		topicDists[i] = topics[i].dist()
	}

	// Transpose into a per-word vector over topics.
	dict := map[string][]float64{}
	for word, i := range words {
		d := make([]float64, k)
		for j := range d {
			d[j] = topicDists[j][i]
		}
		dict[word] = d
	}

	return dict, doct
}
   233  
   234  // ----- HELPERS ---------------------------------------------------------------
   235  
   236  // dist is a distribution on elements by counts.
   237  type dist struct {
   238  	sum    float64
   239  	count  []float64
   240  	alpha  float64
   241  	alphas float64
   242  }
   243  
   244  // newDist creates a new empty distribution.
   245  func newDist(n int, alpha float64) *dist {
   246  	return &dist{0, make([]float64, n), alpha, alpha * float64(n)}
   247  }
   248  
   249  // newDists creates a slice of empty distributions.
   250  func newDists(k, n int, alpha float64) []*dist {
   251  	result := make([]*dist, k)
   252  	for i := range result {
   253  		result[i] = newDist(n, alpha)
   254  	}
   255  	return result
   256  }
   257  
   258  // p returns the probability of i, considering alpha.
   259  func (d *dist) p(i int) float64 {
   260  	if d.sum == 0 {
   261  		return 0
   262  	}
   263  	return (d.count[i] + d.alpha*d.sum) / (d.sum + d.alphas*d.sum)
   264  }
   265  
   266  // add increments i by 1.
   267  func (d *dist) add(i int) {
   268  	d.count[i]++
   269  	d.sum++
   270  }
   271  
   272  // sun decrements i by 1.
   273  func (d *dist) sub(i int) {
   274  	d.count[i]--
   275  	d.sum--
   276  
   277  	if d.count[i] < 0 {
   278  		panic(fmt.Sprintf("Reached negative count for i=%d.", i))
   279  	}
   280  }
   281  
   282  // dist returns the counts of this distribution, normalized by its sum.
   283  func (d *dist) dist() []float64 {
   284  	result := make([]float64, len(d.count))
   285  	copy(result, d.count)
   286  	if d.sum != 0 {
   287  		for i := range result {
   288  			result[i] /= d.sum
   289  		}
   290  	}
   291  	return result
   292  }
   293  
   294  // copy deep-copies a distribution.
   295  func (d *dist) copy() *dist {
   296  	count := make([]float64, len(d.count))
   297  	for i := range count {
   298  		count[i] = d.count[i]
   299  	}
   300  	return &dist{d.sum, count, d.alpha, d.alphas}
   301  }
   302  
   303  // copyDists deep-copies a slice of distributions.
   304  func copyDists(dists []*dist) []*dist {
   305  	result := make([]*dist, len(dists))
   306  	for i := range result {
   307  		result[i] = dists[i].copy()
   308  	}
   309  	return result
   310  }
   311  
   312  // top returns the n most likely items in the distribution.
   313  func (d *dist) top(n int) []int {
   314  	s := newDistSorter(d)
   315  	sort.Sort(s)
   316  	if n > len(s.perm) {
   317  		n = len(s.perm)
   318  	}
   319  	return s.perm[:n]
   320  }
   321  
   322  // distSorter is a distribution sorting interface.
   323  type distSorter struct {
   324  	*dist
   325  	perm []int
   326  }
   327  
   328  func newDistSorter(d *dist) *distSorter {
   329  	s := &distSorter{d, make([]int, len(d.count))}
   330  	for i := range s.perm {
   331  		s.perm[i] = i
   332  	}
   333  	return s
   334  }
   335  
   336  func (d *distSorter) Len() int {
   337  	return len(d.perm)
   338  }
   339  
   340  func (d *distSorter) Less(i, j int) bool {
   341  	return d.count[d.perm[i]] > d.count[d.perm[j]]
   342  }
   343  
   344  func (d *distSorter) Swap(i, j int) {
   345  	d.perm[i], d.perm[j] = d.perm[j], d.perm[i]
   346  }
   347  
   348  // newRand creates a new random generator.
   349  func newRand() *rand.Rand {
   350  	return rand.New(rand.NewSource(time.Now().UnixNano()))
   351  }
   352  
   353  // pickRandom picks a random index from a, with a probability proportional to
   354  // its value. Using a local random-generator to prevent waiting on rand's
   355  // default source.
   356  func pickRandom(a []float64, rnd *rand.Rand) int {
   357  	if len(a) == 0 {
   358  		panic("Cannot pick element from an empty distribution.")
   359  	}
   360  
   361  	sum := float64(0)
   362  	for i := range a {
   363  		if a[i] < 0 {
   364  			panic(fmt.Sprintf("Got negative value in distribution: %v", a[i]))
   365  		}
   366  		sum += a[i]
   367  	}
   368  	if sum == 0 {
   369  		return rnd.Intn(len(a))
   370  	}
   371  
   372  	r := rnd.Float64() * sum
   373  	i := 0
   374  	for i < len(a) && r > a[i] {
   375  		r -= a[i]
   376  		i++
   377  	}
   378  	if i == len(a) {
   379  		i--
   380  	}
   381  	return i
   382  }