go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/experiments/text-chain/main.go (about) 1 /* 2 3 Copyright (c) 2024 - Present. Will Charczuk. All rights reserved. 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository. 5 6 */ 7 8 package main 9 10 import ( 11 "bufio" 12 "encoding/json" 13 "fmt" 14 "math/rand" 15 "os" 16 "sort" 17 "strings" 18 19 "github.com/urfave/cli/v2" 20 "go.charczuk.com/sdk/slant" 21 ) 22 23 func main() { 24 textchain := &cli.App{ 25 Name: "text-chain", 26 Usage: "text-chain trains a model from existing text and writes sentences.", 27 Commands: []*cli.Command{ 28 train(), 29 generate(), 30 }, 31 } 32 if err := textchain.Run(os.Args); err != nil { 33 fmt.Fprintf(os.Stderr, "%+v\n", err) 34 os.Exit(1) 35 } 36 } 37 38 func train() *cli.Command { 39 return &cli.Command{ 40 Name: "train", 41 Usage: "Train a model file", 42 Flags: []cli.Flag{ 43 &cli.StringFlag{ 44 Name: "input-file", 45 Usage: "The input file in lines format", 46 }, 47 &cli.StringFlag{ 48 Name: "output-file", 49 Usage: "The output file path", 50 }, 51 }, 52 Action: func(ctx *cli.Context) error { 53 slant.Print(os.Stdout, "TEXT CHAIN") 54 55 inputFile := ctx.String("input-file") 56 if inputFile == "" { 57 return cli.Exit("--input-file is required", 10) 58 } 59 outputFile := ctx.String("output-file") 60 if outputFile == "" { 61 return cli.Exit("--output-file is required", 10) 62 } 63 64 inputFileReader, err := os.Open(inputFile) 65 if err != nil { 66 return err 67 } 68 defer inputFileReader.Close() 69 70 wordCounts := make(map[string]map[string]int) 71 scanner := bufio.NewScanner(inputFileReader) 72 for scanner.Scan() { 73 line := scanner.Text() 74 lineFields := strings.Fields(line) 75 if len(lineFields) < 3 { 76 continue 77 } 78 previous := null 79 previous2 := null 80 for x := 0; x < len(lineFields); x++ { 81 word := lineFields[x] 82 if word == "file:" { 83 continue 84 } 85 addNextWord(wordCounts, previous, previous2, word) 86 previous = previous2 87 previous2 = word 88 } 89 90 // make sure to add the trailing word 91 addNextWord(wordCounts, previous, previous2, null) 92 } 93 94 fmt.Printf("word density map has %d words\n", len(wordCounts)) 95 96 wordDensity := make(map[string]map[string]float64) 97 outputFileWriter, err := os.Create(outputFile) 98 if err != nil { 99 return err 100 } 101 defer outputFileWriter.Close() 102 103 fmt.Printf("writing word densities to: %s\n", outputFile) 104 105 for word, counts := range wordCounts { 106 var total int 107 for _, count := range counts { 108 total += count 109 } 110 wordDensity[word] = make(map[string]float64) 111 for next, count := range counts { 112 wordDensity[word][next] = float64(count) / float64(total) 113 } 114 } 115 return json.NewEncoder(outputFileWriter).Encode(wordDensity) 116 }, 117 } 118 } 119 120 type modelDensities map[string]map[string]float64 121 122 func generate() *cli.Command { 123 return &cli.Command{ 124 Name: "generate", 125 Usage: "Generate a statement from a model file", 126 Flags: []cli.Flag{ 127 &cli.StringFlag{ 128 Name: "model-file", 129 Usage: "The model file in json formt", 130 }, 131 }, 132 Action: func(ctx *cli.Context) error { 133 model, err := readModelFile(ctx.String("model-file")) 134 if err != nil { 135 return err 136 } 137 138 token := null 139 token2 := null 140 next := predict(model, token, token2) 141 for next != null { 142 fmt.Print(next) 143 token = token2 144 token2 = next 145 next = predict(model, token, token2) 146 if next != null { 147 fmt.Print(" ") 148 } 149 } 150 fmt.Println() 151 return nil 152 }, 153 } 154 } 155 156 type WordDensity struct { 157 Word string 158 Density float64 159 } 160 161 func predict(model modelDensities, token, token2 string) string { 162 key := token + " " + token2 163 nextCounts, ok := model[key] 164 if !ok { 165 return null 166 } 167 var wordDensities []WordDensity 168 for word, density := range nextCounts { 169 if word == token2 { 170 continue 171 } 172 wordDensities = append(wordDensities, WordDensity{Word: word, Density: density}) 173 } 174 175 sort.Slice(wordDensities, func(a, b int) bool { 176 return wordDensities[a].Density > wordDensities[b].Density 177 }) 178 179 cutoff := rand.Float64() 180 var total float64 181 for _, wd := range wordDensities { 182 total += wd.Density 183 if total > cutoff { 184 return wd.Word 185 } 186 } 187 return null 188 } 189 190 func readModelFile(path string) (modelDensities, error) { 191 input, err := os.Open(path) 192 if err != nil { 193 return nil, err 194 } 195 defer input.Close() 196 wordDensity := make(modelDensities) 197 if err := json.NewDecoder(input).Decode(&wordDensity); err != nil { 198 return nil, err 199 } 200 return wordDensity, nil 201 } 202 203 const null = "__null__" 204 205 func addNextWord(wordCounts map[string]map[string]int, from, from2, to string) { 206 key := from + " " + from2 207 _, ok := wordCounts[key] 208 if !ok { 209 wordCounts[key] = make(map[string]int) 210 } 211 wordCounts[key][to]++ 212 }