go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/experiments/gzip-classify/main.go (about) 1 /* 2 3 Copyright (c) 2023 - Present. Will Charczuk. All rights reserved. 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository. 5 6 */ 7 8 package main 9 10 import ( 11 "bytes" 12 "cmp" 13 "compress/gzip" 14 "encoding/gob" 15 "encoding/json" 16 "fmt" 17 "os" 18 "path/filepath" 19 "runtime" 20 "sort" 21 "strings" 22 "sync" 23 24 "github.com/urfave/cli/v2" 25 "gopkg.in/yaml.v3" 26 27 "go.charczuk.com/sdk/async" 28 "go.charczuk.com/sdk/cliutil" 29 ) 30 31 func main() { 32 app := &cli.App{ 33 Name: "gzip-classify", 34 Usage: "Generate labels for arbitrary text based on a pre-trained model.", 35 Commands: []*cli.Command{ 36 entry, 37 train, 38 test, 39 }, 40 DefaultCommand: "test", 41 } 42 if err := app.Run(os.Args); err != nil { 43 cliutil.Fatal(err) 44 } 45 } 46 47 var entry = &cli.Command{ 48 Name: "entry", 49 Aliases: []string{"e"}, 50 Usage: "emit a training dataset entry", 51 Flags: []cli.Flag{ 52 &cli.StringSliceFlag{ 53 Name: "label", 54 Aliases: []string{"l"}, 55 Usage: "a label", 56 Required: true, 57 }, 58 &cli.StringFlag{ 59 Name: "data", 60 Aliases: []string{"d"}, 61 Usage: "the data itself", 62 Required: true, 63 }, 64 }, 65 Action: func(ctx *cli.Context) error { 66 labels := ctx.StringSlice("label") 67 data, err := cliutil.FileOrStdin(ctx.String("data")) 68 if err != nil { 69 return err 70 } 71 return yaml.NewEncoder(os.Stdout).Encode(TrainingSetEntry{ 72 Labels: labels, 73 Data: string(data), 74 }) 75 }, 76 } 77 78 var train = &cli.Command{ 79 Name: "train", 80 Usage: "train a dataset", 81 Flags: []cli.Flag{ 82 &cli.StringFlag{ 83 Name: "input", 84 Aliases: []string{"i"}, 85 Usage: "The input path in glob pattern form", 86 Required: true, 87 }, 88 &cli.StringFlag{ 89 Name: "output", 90 Aliases: []string{"o"}, 91 Usage: "The output path", 92 Value: "data.gz", 93 }, 94 }, 95 Action: func(ctx *cli.Context) error { 96 inputGlob := ctx.String("input") 97 output := ctx.String("output") 98 99 var dataset []TrainingSetEntry 100 var datasetMu sync.Mutex 101 102 inputGlobMatches, err := filepath.Glob(inputGlob) 103 if err != nil { 104 return err 105 } 106 107 b := new(async.Batch) 108 b.SetLimit(runtime.NumCPU()) 109 for _, inputFile := range inputGlobMatches { 110 b.Go(func() error { 111 var fileData TrainingSet 112 f, ferr := os.Open(inputFile) 113 if ferr != nil { 114 fmt.Fprintf(os.Stderr, "train; could not open input file %s: %v\n", inputFile, ferr) 115 return nil 116 } 117 defer f.Close() 118 119 if ferr = decodeFile(inputFile, f, &fileData); ferr != nil { 120 return fmt.Errorf("train; could not decode input file %s: %w", inputFile, ferr) 121 } 122 for _, entry := range fileData.Entries { 123 compressedData := gzipCompress([]byte(entry.Data)) 124 entry.CompressedDataLength = len(compressedData) 125 datasetMu.Lock() 126 dataset = append(dataset, entry) 127 datasetMu.Unlock() 128 } 129 return nil 130 }) 131 } 132 133 if err = b.Wait(); err != nil { 134 return err 135 } 136 137 of, err := os.Create(output) 138 if err != nil { 139 return fmt.Errorf("train; could not create output file: %w", err) 140 } 141 defer of.Close() 142 143 gzw := gzip.NewWriter(of) 144 if err = gob.NewEncoder(gzw).Encode(TrainingSet{Entries: dataset}); err != nil { 145 return fmt.Errorf("train; could not encode output file: %w", err) 146 } 147 if err := gzw.Flush(); err != nil { 148 return fmt.Errorf("train; error flushing output file writer: %w", err) 149 } 150 fmt.Printf("wrote training file: %q\n", output) 151 return nil 152 }, 153 } 154 155 var test = &cli.Command{ 156 Name: "test", 157 Usage: "test a dataset", 158 Flags: []cli.Flag{ 159 &cli.StringFlag{ 160 Name: "dataset", 161 Aliases: []string{"ds"}, 162 Usage: "The input dataset path.", 163 Value: "data.gz", 164 }, 165 &cli.StringFlag{ 166 Name: "file", 167 Aliases: []string{"f"}, 168 Usage: "The input test data path.", 169 Required: true, 170 }, 171 &cli.IntFlag{ 172 Name: "k", 173 Usage: "the number of results to consider", 174 Value: 2, 175 }, 176 }, 177 Action: func(ctx *cli.Context) error { 178 datasetPath := ctx.String("dataset") 179 var training TrainingSet 180 err := func() error { 181 f, ferr := os.Open(datasetPath) 182 if ferr != nil { 183 return ferr 184 } 185 defer f.Close() 186 gzr, ferr := gzip.NewReader(f) 187 if ferr != nil { 188 return ferr 189 } 190 if ferr := gob.NewDecoder(gzr).Decode(&training); ferr != nil { 191 return ferr 192 } 193 return nil 194 }() 195 if err != nil { 196 return err 197 } 198 contents, err := cliutil.FileOrStdin(ctx.String("file")) 199 if err != nil { 200 return err 201 } 202 203 contentsCleaned := strings.TrimSpace(string(contents)) 204 contentsCleaned = strings.ToLower(string(contentsCleaned)) 205 206 compressedContents := gzipCompress([]byte(contentsCleaned)) 207 cx1 := len(compressedContents) 208 209 distanceFrom := make([]LabelsDistance, 0, len(training.Entries)) 210 var distanceFromMu sync.Mutex 211 212 b := new(async.Batch) 213 b.SetLimit(runtime.NumCPU()) 214 bp := newBufferPool() 215 216 testValue := func(ts TrainingSetEntry) func() error { 217 return func() error { 218 cx2 := ts.CompressedDataLength 219 220 buf := bp.Get() 221 defer bp.Put(buf) 222 223 buf.Write(contents) 224 buf.WriteRune(' ') 225 buf.Write([]byte(ts.Data)) 226 merged := buf.Bytes() 227 228 mergedCompressed := gzipCompress(merged) 229 cx1x2 := len(mergedCompressed) 230 ncd := float64(cx1x2-min(cx1, cx2)) / float64(max(cx1, cx2)) 231 232 distanceFromMu.Lock() 233 distanceFrom = append(distanceFrom, LabelsDistance{Labels: ts.Labels, Distance: ncd}) 234 distanceFromMu.Unlock() 235 return nil 236 } 237 } 238 for _, ts := range training.Entries { 239 b.Go(testValue(ts)) 240 } 241 if err := b.Wait(); err != nil { 242 return err 243 } 244 245 k := ctx.Int("k") 246 output := make(map[string][]float64) 247 for _, ld := range distanceFrom[:k] { 248 for _, label := range ld.Labels { 249 output[label] = append(output[label], ld.Distance) 250 } 251 } 252 var labelDistance = make([]LabelDistance, 0, len(output)) 253 for label, values := range output { 254 labelDistance = append(labelDistance, LabelDistance{Label: label, Distance: Avg(values...)}) 255 } 256 sort.Slice(labelDistance, func(i0, i1 int) bool { 257 return labelDistance[i0].Distance < labelDistance[i1].Distance 258 }) 259 fmt.Println(labelDistance[0].Label) 260 return nil 261 }, 262 } 263 264 // LabelDistance is a helper type. 265 type LabelDistance struct { 266 Label string 267 Distance float64 268 } 269 270 // LabelsDistance is a helper type. 271 type LabelsDistance struct { 272 Labels []string 273 Distance float64 274 } 275 276 func Avg(values ...float64) float64 { 277 var accum float64 278 for _, v := range values { 279 accum += v 280 } 281 return accum / float64(len(values)) 282 } 283 284 func min[T cmp.Ordered](a, b T) T { 285 if a < b { 286 return a 287 } 288 return b 289 } 290 291 func max[T cmp.Ordered](a, b T) T { 292 if a < b { 293 return a 294 } 295 return b 296 } 297 298 // TrainingSet is the fully serialized training set. 299 type TrainingSet struct { 300 Entries []TrainingSetEntry 301 } 302 303 // TrainingSetEntry is an entry in our model set. 304 type TrainingSetEntry struct { 305 Labels []string `yaml:"labels"` 306 Data string `yaml:"data"` 307 CompressedDataLength int `yaml:"-"` 308 } 309 310 func decodeFile(inputPath string, f *os.File, fileData *TrainingSet) error { 311 switch filepath.Ext(inputPath) { 312 case ".json", "json": 313 return json.NewDecoder(f).Decode(fileData) 314 case ".yaml", "yaml", ".yml", "yml": 315 return yaml.NewDecoder(f).Decode(fileData) 316 default: 317 return fmt.Errorf("invalid path for decoding: %s", inputPath) 318 } 319 } 320 321 func gzipCompress(data []byte) []byte { 322 buf := new(bytes.Buffer) 323 gzw := gzip.NewWriter(buf) 324 _, _ = gzw.Write(data) 325 _ = gzw.Flush() 326 return buf.Bytes() 327 } 328 329 func newBufferPool() *BufferPool { 330 return &BufferPool{ 331 Pool: sync.Pool{New: func() interface{} { 332 b := bytes.NewBuffer(make([]byte, 512)) 333 b.Reset() 334 return b 335 }}, 336 } 337 } 338 339 // BufferPool is a sync.Pool of bytes.Buffer. 340 type BufferPool struct { 341 sync.Pool 342 } 343 344 // Get returns a pooled bytes.Buffer instance. 345 func (p *BufferPool) Get() *bytes.Buffer { 346 return p.Pool.Get().(*bytes.Buffer) 347 } 348 349 // Put returns the pooled instance. 350 func (p *BufferPool) Put(b *bytes.Buffer) { 351 b.Reset() 352 p.Pool.Put(b) 353 }