go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/experiments/text-chain/main.go (about)

     1  /*
     2  
     3  Copyright (c) 2024 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package main
     9  
    10  import (
    11  	"bufio"
    12  	"encoding/json"
    13  	"fmt"
    14  	"math/rand"
    15  	"os"
    16  	"sort"
    17  	"strings"
    18  
    19  	"github.com/urfave/cli/v2"
    20  	"go.charczuk.com/sdk/slant"
    21  )
    22  
    23  func main() {
    24  	textchain := &cli.App{
    25  		Name:  "text-chain",
    26  		Usage: "text-chain trains a model from existing text and writes sentences.",
    27  		Commands: []*cli.Command{
    28  			train(),
    29  			generate(),
    30  		},
    31  	}
    32  	if err := textchain.Run(os.Args); err != nil {
    33  		fmt.Fprintf(os.Stderr, "%+v\n", err)
    34  		os.Exit(1)
    35  	}
    36  }
    37  
    38  func train() *cli.Command {
    39  	return &cli.Command{
    40  		Name:  "train",
    41  		Usage: "Train a model file",
    42  		Flags: []cli.Flag{
    43  			&cli.StringFlag{
    44  				Name:  "input-file",
    45  				Usage: "The input file in lines format",
    46  			},
    47  			&cli.StringFlag{
    48  				Name:  "output-file",
    49  				Usage: "The output file path",
    50  			},
    51  		},
    52  		Action: func(ctx *cli.Context) error {
    53  			slant.Print(os.Stdout, "TEXT CHAIN")
    54  
    55  			inputFile := ctx.String("input-file")
    56  			if inputFile == "" {
    57  				return cli.Exit("--input-file is required", 10)
    58  			}
    59  			outputFile := ctx.String("output-file")
    60  			if outputFile == "" {
    61  				return cli.Exit("--output-file is required", 10)
    62  			}
    63  
    64  			inputFileReader, err := os.Open(inputFile)
    65  			if err != nil {
    66  				return err
    67  			}
    68  			defer inputFileReader.Close()
    69  
    70  			wordCounts := make(map[string]map[string]int)
    71  			scanner := bufio.NewScanner(inputFileReader)
    72  			for scanner.Scan() {
    73  				line := scanner.Text()
    74  				lineFields := strings.Fields(line)
    75  				if len(lineFields) < 3 {
    76  					continue
    77  				}
    78  				previous := null
    79  				previous2 := null
    80  				for x := 0; x < len(lineFields); x++ {
    81  					word := lineFields[x]
    82  					if word == "file:" {
    83  						continue
    84  					}
    85  					addNextWord(wordCounts, previous, previous2, word)
    86  					previous = previous2
    87  					previous2 = word
    88  				}
    89  
    90  				// make sure to add the trailing word
    91  				addNextWord(wordCounts, previous, previous2, null)
    92  			}
    93  
    94  			fmt.Printf("word density map has %d words\n", len(wordCounts))
    95  
    96  			wordDensity := make(map[string]map[string]float64)
    97  			outputFileWriter, err := os.Create(outputFile)
    98  			if err != nil {
    99  				return err
   100  			}
   101  			defer outputFileWriter.Close()
   102  
   103  			fmt.Printf("writing word densities to: %s\n", outputFile)
   104  
   105  			for word, counts := range wordCounts {
   106  				var total int
   107  				for _, count := range counts {
   108  					total += count
   109  				}
   110  				wordDensity[word] = make(map[string]float64)
   111  				for next, count := range counts {
   112  					wordDensity[word][next] = float64(count) / float64(total)
   113  				}
   114  			}
   115  			return json.NewEncoder(outputFileWriter).Encode(wordDensity)
   116  		},
   117  	}
   118  }
   119  
   120  type modelDensities map[string]map[string]float64
   121  
   122  func generate() *cli.Command {
   123  	return &cli.Command{
   124  		Name:  "generate",
   125  		Usage: "Generate a statement from a model file",
   126  		Flags: []cli.Flag{
   127  			&cli.StringFlag{
   128  				Name:  "model-file",
   129  				Usage: "The model file in json formt",
   130  			},
   131  		},
   132  		Action: func(ctx *cli.Context) error {
   133  			model, err := readModelFile(ctx.String("model-file"))
   134  			if err != nil {
   135  				return err
   136  			}
   137  
   138  			token := null
   139  			token2 := null
   140  			next := predict(model, token, token2)
   141  			for next != null {
   142  				fmt.Print(next)
   143  				token = token2
   144  				token2 = next
   145  				next = predict(model, token, token2)
   146  				if next != null {
   147  					fmt.Print(" ")
   148  				}
   149  			}
   150  			fmt.Println()
   151  			return nil
   152  		},
   153  	}
   154  }
   155  
   156  type WordDensity struct {
   157  	Word    string
   158  	Density float64
   159  }
   160  
   161  func predict(model modelDensities, token, token2 string) string {
   162  	key := token + " " + token2
   163  	nextCounts, ok := model[key]
   164  	if !ok {
   165  		return null
   166  	}
   167  	var wordDensities []WordDensity
   168  	for word, density := range nextCounts {
   169  		if word == token2 {
   170  			continue
   171  		}
   172  		wordDensities = append(wordDensities, WordDensity{Word: word, Density: density})
   173  	}
   174  
   175  	sort.Slice(wordDensities, func(a, b int) bool {
   176  		return wordDensities[a].Density > wordDensities[b].Density
   177  	})
   178  
   179  	cutoff := rand.Float64()
   180  	var total float64
   181  	for _, wd := range wordDensities {
   182  		total += wd.Density
   183  		if total > cutoff {
   184  			return wd.Word
   185  		}
   186  	}
   187  	return null
   188  }
   189  
   190  func readModelFile(path string) (modelDensities, error) {
   191  	input, err := os.Open(path)
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  	defer input.Close()
   196  	wordDensity := make(modelDensities)
   197  	if err := json.NewDecoder(input).Decode(&wordDensity); err != nil {
   198  		return nil, err
   199  	}
   200  	return wordDensity, nil
   201  }
   202  
   203  const null = "__null__"
   204  
   205  func addNextWord(wordCounts map[string]map[string]int, from, from2, to string) {
   206  	key := from + " " + from2
   207  	_, ok := wordCounts[key]
   208  	if !ok {
   209  		wordCounts[key] = make(map[string]int)
   210  	}
   211  	wordCounts[key][to]++
   212  }