go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/experiments/gzip-classify/main.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package main
     9  
    10  import (
    11  	"bytes"
    12  	"cmp"
    13  	"compress/gzip"
    14  	"encoding/gob"
    15  	"encoding/json"
    16  	"fmt"
    17  	"os"
    18  	"path/filepath"
    19  	"runtime"
    20  	"sort"
    21  	"strings"
    22  	"sync"
    23  
    24  	"github.com/urfave/cli/v2"
    25  	"gopkg.in/yaml.v3"
    26  
    27  	"go.charczuk.com/sdk/async"
    28  	"go.charczuk.com/sdk/cliutil"
    29  )
    30  
    31  func main() {
    32  	app := &cli.App{
    33  		Name:  "gzip-classify",
    34  		Usage: "Generate labels for arbitrary text based on a pre-trained model.",
    35  		Commands: []*cli.Command{
    36  			entry,
    37  			train,
    38  			test,
    39  		},
    40  		DefaultCommand: "test",
    41  	}
    42  	if err := app.Run(os.Args); err != nil {
    43  		cliutil.Fatal(err)
    44  	}
    45  }
    46  
    47  var entry = &cli.Command{
    48  	Name:    "entry",
    49  	Aliases: []string{"e"},
    50  	Usage:   "emit a training dataset entry",
    51  	Flags: []cli.Flag{
    52  		&cli.StringSliceFlag{
    53  			Name:     "label",
    54  			Aliases:  []string{"l"},
    55  			Usage:    "a label",
    56  			Required: true,
    57  		},
    58  		&cli.StringFlag{
    59  			Name:     "data",
    60  			Aliases:  []string{"d"},
    61  			Usage:    "the data itself",
    62  			Required: true,
    63  		},
    64  	},
    65  	Action: func(ctx *cli.Context) error {
    66  		labels := ctx.StringSlice("label")
    67  		data, err := cliutil.FileOrStdin(ctx.String("data"))
    68  		if err != nil {
    69  			return err
    70  		}
    71  		return yaml.NewEncoder(os.Stdout).Encode(TrainingSetEntry{
    72  			Labels: labels,
    73  			Data:   string(data),
    74  		})
    75  	},
    76  }
    77  
    78  var train = &cli.Command{
    79  	Name:  "train",
    80  	Usage: "train a dataset",
    81  	Flags: []cli.Flag{
    82  		&cli.StringFlag{
    83  			Name:     "input",
    84  			Aliases:  []string{"i"},
    85  			Usage:    "The input path in glob pattern form",
    86  			Required: true,
    87  		},
    88  		&cli.StringFlag{
    89  			Name:    "output",
    90  			Aliases: []string{"o"},
    91  			Usage:   "The output path",
    92  			Value:   "data.gz",
    93  		},
    94  	},
    95  	Action: func(ctx *cli.Context) error {
    96  		inputGlob := ctx.String("input")
    97  		output := ctx.String("output")
    98  
    99  		var dataset []TrainingSetEntry
   100  		var datasetMu sync.Mutex
   101  
   102  		inputGlobMatches, err := filepath.Glob(inputGlob)
   103  		if err != nil {
   104  			return err
   105  		}
   106  
   107  		b := new(async.Batch)
   108  		b.SetLimit(runtime.NumCPU())
   109  		for _, inputFile := range inputGlobMatches {
   110  			b.Go(func() error {
   111  				var fileData TrainingSet
   112  				f, ferr := os.Open(inputFile)
   113  				if ferr != nil {
   114  					fmt.Fprintf(os.Stderr, "train; could not open input file %s: %v\n", inputFile, ferr)
   115  					return nil
   116  				}
   117  				defer f.Close()
   118  
   119  				if ferr = decodeFile(inputFile, f, &fileData); ferr != nil {
   120  					return fmt.Errorf("train; could not decode input file %s: %w", inputFile, ferr)
   121  				}
   122  				for _, entry := range fileData.Entries {
   123  					compressedData := gzipCompress([]byte(entry.Data))
   124  					entry.CompressedDataLength = len(compressedData)
   125  					datasetMu.Lock()
   126  					dataset = append(dataset, entry)
   127  					datasetMu.Unlock()
   128  				}
   129  				return nil
   130  			})
   131  		}
   132  
   133  		if err = b.Wait(); err != nil {
   134  			return err
   135  		}
   136  
   137  		of, err := os.Create(output)
   138  		if err != nil {
   139  			return fmt.Errorf("train; could not create output file: %w", err)
   140  		}
   141  		defer of.Close()
   142  
   143  		gzw := gzip.NewWriter(of)
   144  		if err = gob.NewEncoder(gzw).Encode(TrainingSet{Entries: dataset}); err != nil {
   145  			return fmt.Errorf("train; could not encode output file: %w", err)
   146  		}
   147  		if err := gzw.Flush(); err != nil {
   148  			return fmt.Errorf("train; error flushing output file writer: %w", err)
   149  		}
   150  		fmt.Printf("wrote training file: %q\n", output)
   151  		return nil
   152  	},
   153  }
   154  
   155  var test = &cli.Command{
   156  	Name:  "test",
   157  	Usage: "test a dataset",
   158  	Flags: []cli.Flag{
   159  		&cli.StringFlag{
   160  			Name:    "dataset",
   161  			Aliases: []string{"ds"},
   162  			Usage:   "The input dataset path.",
   163  			Value:   "data.gz",
   164  		},
   165  		&cli.StringFlag{
   166  			Name:     "file",
   167  			Aliases:  []string{"f"},
   168  			Usage:    "The input test data path.",
   169  			Required: true,
   170  		},
   171  		&cli.IntFlag{
   172  			Name:  "k",
   173  			Usage: "the number of results to consider",
   174  			Value: 2,
   175  		},
   176  	},
   177  	Action: func(ctx *cli.Context) error {
   178  		datasetPath := ctx.String("dataset")
   179  		var training TrainingSet
   180  		err := func() error {
   181  			f, ferr := os.Open(datasetPath)
   182  			if ferr != nil {
   183  				return ferr
   184  			}
   185  			defer f.Close()
   186  			gzr, ferr := gzip.NewReader(f)
   187  			if ferr != nil {
   188  				return ferr
   189  			}
   190  			if ferr := gob.NewDecoder(gzr).Decode(&training); ferr != nil {
   191  				return ferr
   192  			}
   193  			return nil
   194  		}()
   195  		if err != nil {
   196  			return err
   197  		}
   198  		contents, err := cliutil.FileOrStdin(ctx.String("file"))
   199  		if err != nil {
   200  			return err
   201  		}
   202  
   203  		contentsCleaned := strings.TrimSpace(string(contents))
   204  		contentsCleaned = strings.ToLower(string(contentsCleaned))
   205  
   206  		compressedContents := gzipCompress([]byte(contentsCleaned))
   207  		cx1 := len(compressedContents)
   208  
   209  		distanceFrom := make([]LabelsDistance, 0, len(training.Entries))
   210  		var distanceFromMu sync.Mutex
   211  
   212  		b := new(async.Batch)
   213  		b.SetLimit(runtime.NumCPU())
   214  		bp := newBufferPool()
   215  
   216  		testValue := func(ts TrainingSetEntry) func() error {
   217  			return func() error {
   218  				cx2 := ts.CompressedDataLength
   219  
   220  				buf := bp.Get()
   221  				defer bp.Put(buf)
   222  
   223  				buf.Write(contents)
   224  				buf.WriteRune(' ')
   225  				buf.Write([]byte(ts.Data))
   226  				merged := buf.Bytes()
   227  
   228  				mergedCompressed := gzipCompress(merged)
   229  				cx1x2 := len(mergedCompressed)
   230  				ncd := float64(cx1x2-min(cx1, cx2)) / float64(max(cx1, cx2))
   231  
   232  				distanceFromMu.Lock()
   233  				distanceFrom = append(distanceFrom, LabelsDistance{Labels: ts.Labels, Distance: ncd})
   234  				distanceFromMu.Unlock()
   235  				return nil
   236  			}
   237  		}
   238  		for _, ts := range training.Entries {
   239  			b.Go(testValue(ts))
   240  		}
   241  		if err := b.Wait(); err != nil {
   242  			return err
   243  		}
   244  
   245  		k := ctx.Int("k")
   246  		output := make(map[string][]float64)
   247  		for _, ld := range distanceFrom[:k] {
   248  			for _, label := range ld.Labels {
   249  				output[label] = append(output[label], ld.Distance)
   250  			}
   251  		}
   252  		var labelDistance = make([]LabelDistance, 0, len(output))
   253  		for label, values := range output {
   254  			labelDistance = append(labelDistance, LabelDistance{Label: label, Distance: Avg(values...)})
   255  		}
   256  		sort.Slice(labelDistance, func(i0, i1 int) bool {
   257  			return labelDistance[i0].Distance < labelDistance[i1].Distance
   258  		})
   259  		fmt.Println(labelDistance[0].Label)
   260  		return nil
   261  	},
   262  }
   263  
   264  // LabelDistance is a helper type.
   265  type LabelDistance struct {
   266  	Label    string
   267  	Distance float64
   268  }
   269  
   270  // LabelsDistance is a helper type.
   271  type LabelsDistance struct {
   272  	Labels   []string
   273  	Distance float64
   274  }
   275  
   276  func Avg(values ...float64) float64 {
   277  	var accum float64
   278  	for _, v := range values {
   279  		accum += v
   280  	}
   281  	return accum / float64(len(values))
   282  }
   283  
   284  func min[T cmp.Ordered](a, b T) T {
   285  	if a < b {
   286  		return a
   287  	}
   288  	return b
   289  }
   290  
   291  func max[T cmp.Ordered](a, b T) T {
   292  	if a < b {
   293  		return a
   294  	}
   295  	return b
   296  }
   297  
   298  // TrainingSet is the fully serialized training set.
   299  type TrainingSet struct {
   300  	Entries []TrainingSetEntry
   301  }
   302  
   303  // TrainingSetEntry is an entry in our model set.
   304  type TrainingSetEntry struct {
   305  	Labels               []string `yaml:"labels"`
   306  	Data                 string   `yaml:"data"`
   307  	CompressedDataLength int      `yaml:"-"`
   308  }
   309  
   310  func decodeFile(inputPath string, f *os.File, fileData *TrainingSet) error {
   311  	switch filepath.Ext(inputPath) {
   312  	case ".json", "json":
   313  		return json.NewDecoder(f).Decode(fileData)
   314  	case ".yaml", "yaml", ".yml", "yml":
   315  		return yaml.NewDecoder(f).Decode(fileData)
   316  	default:
   317  		return fmt.Errorf("invalid path for decoding: %s", inputPath)
   318  	}
   319  }
   320  
   321  func gzipCompress(data []byte) []byte {
   322  	buf := new(bytes.Buffer)
   323  	gzw := gzip.NewWriter(buf)
   324  	_, _ = gzw.Write(data)
   325  	_ = gzw.Flush()
   326  	return buf.Bytes()
   327  }
   328  
   329  func newBufferPool() *BufferPool {
   330  	return &BufferPool{
   331  		Pool: sync.Pool{New: func() interface{} {
   332  			b := bytes.NewBuffer(make([]byte, 512))
   333  			b.Reset()
   334  			return b
   335  		}},
   336  	}
   337  }
   338  
   339  // BufferPool is a sync.Pool of bytes.Buffer.
   340  type BufferPool struct {
   341  	sync.Pool
   342  }
   343  
   344  // Get returns a pooled bytes.Buffer instance.
   345  func (p *BufferPool) Get() *bytes.Buffer {
   346  	return p.Pool.Get().(*bytes.Buffer)
   347  }
   348  
   349  // Put returns the pooled instance.
   350  func (p *BufferPool) Put(b *bytes.Buffer) {
   351  	b.Reset()
   352  	p.Pool.Put(b)
   353  }