go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/model/collection.go (about)

     1  package model
     2  
import (
	"archive/zip"
	"errors"
	"io"
	"path/filepath"
	"sort"

	"github.com/ulikunitz/xz"
	"go-ml.dev/pkg/iokit"
	"go-ml.dev/pkg/zorros"
)
    11  
    12  /*
    13  Mnemosyne is a Serialization interface for an ML model parts
    14  */
    15  type Mnemosyne interface {
    16  	Memorize(*CollectionWriter) error
    17  }
    18  
    19  /*
    20  MemorizeMap maps names of models in directory to Mnemosyne abstraction
    21  */
    22  type MemorizeMap map[string]Mnemosyne
    23  
    24  /*
    25  ObjectifyMap mpas names of models in directory to objectification functions
    26  */
    27  type ObjectifyMap map[string]func(map[string]iokit.Input) (PredictionModel, error)
    28  
    29  /*
    30  Memorize writes models directory to single output
    31  */
    32  func Memorize(output iokit.Output, m MemorizeMap) error {
    33  	if output == nil {
    34  		return nil
    35  	}
    36  	f, err := output.Create()
    37  	if err != nil {
    38  		return zorros.Trace(err)
    39  	}
    40  	defer f.End()
    41  	wz := zip.NewWriter(f)
    42  	for k, w := range m {
    43  		if err = w.Memorize(&CollectionWriter{wz, k}); err != nil {
    44  			return zorros.Trace(err)
    45  		}
    46  	}
    47  	err = wz.Close()
    48  	if err != nil {
    49  		return zorros.Trace(err)
    50  	}
    51  	err = f.Commit()
    52  	if err != nil {
    53  		return zorros.Trace(err)
    54  	}
    55  	return nil
    56  }
    57  
    58  /*
    59  LuckyMemorize writes models directory to single output
    60  */
    61  func LuckyMemorize(output iokit.Output, m MemorizeMap) {
    62  	if err := Memorize(output, m); err != nil {
    63  		panic(zorros.Panic(zorros.Trace(err)))
    64  	}
    65  }
    66  
    67  /*
    68  CollectionWriter is an abstraction of a collection creator
    69  */
    70  type CollectionWriter struct {
    71  	wz *zip.Writer
    72  	k  string
    73  }
    74  
    75  /*
    76  Add an element to collection
    77  */
    78  func (c *CollectionWriter) Add(name string, write func(io.Writer) error) error {
    79  	return c.add(name, false, write)
    80  }
    81  
    82  /*
    83  Add an Lzma2 compressed element to collection
    84  */
    85  func (c *CollectionWriter) AddLzma2(name string, write func(io.Writer) error) error {
    86  	return c.add(name, true, write)
    87  }
    88  
    89  func (c *CollectionWriter) add(name string, lzma2 bool, write func(io.Writer) error) error {
    90  	fname := c.k + "/" + name
    91  	fh := &zip.FileHeader{Name: fname, Method: zip.Deflate}
    92  	if lzma2 {
    93  		fh.Method = zip.Store
    94  	}
    95  	wr, err := c.wz.CreateHeader(fh)
    96  	if err != nil {
    97  		return zorros.Trace(err)
    98  	}
    99  	if lzma2 {
   100  		xw, err := xz.NewWriter(wr)
   101  		if err != nil {
   102  			return zorros.Trace(err)
   103  		}
   104  		if err = write(xw); err != nil {
   105  			return zorros.Trace(err)
   106  		}
   107  		if err = xw.Close(); err != nil {
   108  			return zorros.Trace(err)
   109  		}
   110  	} else {
   111  		if err = write(wr); err != nil {
   112  			return zorros.Trace(err)
   113  		}
   114  	}
   115  	return nil
   116  }
   117  
   118  /*
   119  Objectify reads and reconstructs prediction models from a directory
   120  */
   121  func Objectify(input iokit.Input, m ObjectifyMap) (pm map[string]PredictionModel, err error) {
   122  	var r *zip.Reader
   123  	f, err := input.Open()
   124  	if err != nil {
   125  		return
   126  	}
   127  	defer f.Close()
   128  	if r, err = zip.NewReader(f.(io.ReaderAt), iokit.FileSize(f)); err != nil {
   129  		return nil, zorros.Trace(err)
   130  	}
   131  	dict := map[string]map[string]iokit.Input{}
   132  	order := []string{}
   133  	for _, j := range r.File {
   134  		dir := filepath.Dir(j.Name)
   135  		if dir != "" && m[dir] != nil {
   136  			d, ok := dict[dir]
   137  			if !ok {
   138  				d = map[string]iokit.Input{}
   139  				dict[dir] = d
   140  				order = append(order, dir)
   141  			}
   142  			if j.Method == zip.Store {
   143  				d[filepath.Base(j.Name)] = iokit.Compressed(iokit.ZipFile(j.Name, input))
   144  			} else {
   145  				d[filepath.Base(j.Name)] = iokit.ZipFile(j.Name, input)
   146  			}
   147  		}
   148  	}
   149  	pm = map[string]PredictionModel{}
   150  	for _, n := range order {
   151  		var v PredictionModel
   152  		f := m[n]
   153  		if v, err = f(dict[n]); err != nil {
   154  			return
   155  		}
   156  		pm[n] = v
   157  	}
   158  	return
   159  }