github.com/SupersunnySea/draft@v0.16.0/pkg/linguist/data/generate_classifier.go (about)

     1  // +build ignore
     2  
     3  /*
     4     This program trains a naive bayesian classifier
     5     provided by https://github.com/jbrukh/bayesian
     6     on a set of source code files
     7     provided by https://github.com/github/linguist
     8  
     9     This file is meant to be run by go generate;
    10     refer to generate.go for its intended invocation.
    11  */
    12  package main
    13  
    14  import (
    15  	"container/heap"
    16  	"fmt"
    17  	"io/ioutil"
    18  	"log"
    19  	"os"
    20  	"runtime"
    21  
    22  	"github.com/Azure/draft/pkg/linguist/tokenizer"
    23  	"github.com/jbrukh/bayesian"
    24  )
    25  
    26  type sampleFile struct {
    27  	lang, fp string
    28  	tokens   []string
    29  }
    30  
    31  func main() {
    32  	const (
    33  		sourcePath = "./linguist/samples"
    34  		outfile    = "./classifier"
    35  		quiet      = false
    36  	)
    37  
    38  	log.SetFlags(0)
    39  	if quiet {
    40  		log.SetOutput(ioutil.Discard)
    41  	}
    42  
    43  	// first we only read all the paths of the sample files
    44  	// and their corresponding and language names into:
    45  	sampleFiles := []*sampleFile{}
    46  	// and store all the language names into:
    47  	languages := []string{}
    48  
    49  	/*
    50  			   github/linguist has directory structure:
    51  
    52  			   ...
    53  			   ├── samples
    54  			   │   ├── (name of programming language)
    55  			   │   │   ├── (sample file in language)
    56  			   │   │   ├── (sample file in language)
    57  			   │   │   └── (sample file in language)
    58  			   │   ├── (name of another programming language)
    59  			   │   │   └── (sample file)
    60  			   ...
    61  
    62  		       the following hard-coded logic expects this layout
    63  	*/
    64  
    65  	log.Println("Scanning", sourcePath, "...")
    66  	srcDir, err := os.Open(sourcePath)
    67  	checkErr(err)
    68  
    69  	subDirs, err := srcDir.Readdir(-1)
    70  	checkErr(err)
    71  
    72  	for _, langDir := range subDirs {
    73  		lang := langDir.Name()
    74  		if !langDir.IsDir() {
    75  			log.Println("unexpected file:", lang)
    76  			continue
    77  		}
    78  
    79  		languages = append(languages, lang)
    80  
    81  		samplePath := sourcePath + "/" + lang
    82  		sampleDir, err := os.Open(samplePath)
    83  		checkErr(err)
    84  		files, err := sampleDir.Readdir(-1)
    85  		checkErr(err)
    86  		for _, file := range files {
    87  			fp := samplePath + "/" + file.Name()
    88  			if file.IsDir() {
    89  				// Skip subdirectories
    90  				continue
    91  			}
    92  			sampleFiles = append(sampleFiles, &sampleFile{lang, fp, nil})
    93  		}
    94  		sampleDir.Close()
    95  	}
    96  	log.Println("Found", len(languages), "languages in", len(sampleFiles), "files")
    97  
    98  	// simple progress bar
    99  	progress := 0.0
   100  	total := float64(len(sampleFiles)) * 2.0
   101  	progressBar := func() {
   102  		progress++
   103  		fmt.Printf("Processing files ... %.2f%%\r", progress/total*100.0)
   104  	}
   105  
   106  	// then we concurrently read and tokenize the samples
   107  	sampleChan := make(chan *sampleFile)
   108  	readyChan := make(chan struct{})
   109  	received := 0
   110  	tokenize := func(s *sampleFile) {
   111  		f, err := os.Open(s.fp)
   112  		checkErr(err)
   113  		contents, err := ioutil.ReadAll(f)
   114  		f.Close()
   115  		checkErr(err)
   116  		s.tokens = tokenizer.Tokenize(contents)
   117  		sampleChan <- s
   118  	}
   119  	dox := map[string][]string{}
   120  	for _, lang := range languages {
   121  		dox[lang] = []string{}
   122  	}
   123  	// this receives the processed files and stores their tokens with their language
   124  	go func() {
   125  		for {
   126  			s := <-sampleChan
   127  			dox[s.lang] = append(dox[s.lang], s.tokens...)
   128  			received++
   129  			progressBar()
   130  			if received == len(sampleFiles) {
   131  				close(readyChan)
   132  				return
   133  			}
   134  		}
   135  	}()
   136  
   137  	// this balances the workload (implementation at end of file)
   138  	requests := getRequestsChan(len(sampleFiles))
   139  	for i := range sampleFiles {
   140  		requests <- &request{
   141  			workFn: tokenize,
   142  			arg:    sampleFiles[i],
   143  		}
   144  		progressBar()
   145  	}
   146  
   147  	// once that's done
   148  	<-readyChan
   149  	close(requests)
   150  	fmt.Println() // for the progress bar
   151  
   152  	// we train the classifier in the arbitrary manner that its API demands
   153  	classes := make([]bayesian.Class, 1)
   154  	documents := make(map[bayesian.Class][]string)
   155  	for _, lang := range languages {
   156  		var class = bayesian.Class(lang)
   157  		classes = append(classes, class)
   158  		documents[class] = dox[lang]
   159  	}
   160  	log.Println("Creating bayesian.Classifier ...")
   161  	clsf := bayesian.NewClassifier(classes...)
   162  	for cls, dox := range documents {
   163  		clsf.Learn(dox, cls)
   164  	}
   165  
   166  	// and write the data to disk
   167  	log.Println("Serializing and exporting bayesian.Classifier to", outfile, "...")
   168  	checkErr(clsf.WriteToFile("classifier"))
   169  
   170  	log.Println("Done.")
   171  }
   172  func checkErr(err error) {
   173  	if err != nil {
   174  		log.Panicln(err)
   175  	}
   176  }
   177  
   178  // simple load balancer from "concurrency is not parallelism" talk
   179  type request struct {
   180  	workFn func(s *sampleFile)
   181  	arg    *sampleFile
   182  }
   183  type worker struct {
   184  	requests       chan *request
   185  	pending, index int
   186  }
   187  
   188  func (w *worker) work(done chan *worker) {
   189  	for {
   190  		req := <-w.requests
   191  		req.workFn(req.arg)
   192  		done <- w
   193  	}
   194  }
   195  
   196  type pool []*worker
   197  
   198  func (p pool) Less(i, j int) bool  { return p[i].pending < p[j].pending }
   199  func (p pool) Len() int            { return len(p) }
   200  func (p pool) Swap(i, j int)       { p[i], p[j] = p[j], p[i] }
   201  func (p *pool) Push(x interface{}) { *p = append(*p, x.(*worker)) }
   202  func (p *pool) Pop() interface{} {
   203  	old := *p
   204  	n := len(old)
   205  	x := old[n-1]
   206  	*p = old[0 : n-1]
   207  	return x
   208  }
   209  
   210  type balancer struct {
   211  	workers pool
   212  	done    chan *worker
   213  }
   214  
   215  func (b *balancer) balance(work chan *request) {
   216  	for {
   217  		select {
   218  		case req, ok := <-work:
   219  			if ok {
   220  				b.dispatch(req)
   221  			} else {
   222  				return
   223  			}
   224  		case w := <-b.done:
   225  			b.completed(w)
   226  		}
   227  	}
   228  }
   229  func (b *balancer) dispatch(req *request) {
   230  	w := heap.Pop(&b.workers).(*worker)
   231  	w.requests <- req
   232  	w.pending++
   233  	heap.Push(&b.workers, w)
   234  }
   235  func (b *balancer) completed(w *worker) {
   236  	w.pending--
   237  	heap.Remove(&b.workers, w.index)
   238  	heap.Push(&b.workers, w)
   239  }
   240  func getRequestsChan(jobs int) chan *request {
   241  	done := make(chan *worker)
   242  	workers := make(pool, runtime.GOMAXPROCS(0)*4) // I don't know how many workers there should be
   243  	for i := 0; i < len(workers); i++ {
   244  		w := &worker{make(chan *request, jobs), 0, i}
   245  		go w.work(done)
   246  		workers[i] = w
   247  	}
   248  	heap.Init(&workers)
   249  	b := &balancer{workers, done}
   250  	requests := make(chan *request)
   251  	go b.balance(requests)
   252  	return requests
   253  }