go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/model/collection.go (about) 1 package model 2 3 import ( 4 "archive/zip" 5 "github.com/ulikunitz/xz" 6 "go-ml.dev/pkg/iokit" 7 "go-ml.dev/pkg/zorros" 8 "io" 9 "path/filepath" 10 ) 11 12 /* 13 Mnemosyne is a Serialization interface for an ML model parts 14 */ 15 type Mnemosyne interface { 16 Memorize(*CollectionWriter) error 17 } 18 19 /* 20 MemorizeMap maps names of models in directory to Mnemosyne abstraction 21 */ 22 type MemorizeMap map[string]Mnemosyne 23 24 /* 25 ObjectifyMap mpas names of models in directory to objectification functions 26 */ 27 type ObjectifyMap map[string]func(map[string]iokit.Input) (PredictionModel, error) 28 29 /* 30 Memorize writes models directory to single output 31 */ 32 func Memorize(output iokit.Output, m MemorizeMap) error { 33 if output == nil { 34 return nil 35 } 36 f, err := output.Create() 37 if err != nil { 38 return zorros.Trace(err) 39 } 40 defer f.End() 41 wz := zip.NewWriter(f) 42 for k, w := range m { 43 if err = w.Memorize(&CollectionWriter{wz, k}); err != nil { 44 return zorros.Trace(err) 45 } 46 } 47 err = wz.Close() 48 if err != nil { 49 return zorros.Trace(err) 50 } 51 err = f.Commit() 52 if err != nil { 53 return zorros.Trace(err) 54 } 55 return nil 56 } 57 58 /* 59 LuckyMemorize writes models directory to single output 60 */ 61 func LuckyMemorize(output iokit.Output, m MemorizeMap) { 62 if err := Memorize(output, m); err != nil { 63 panic(zorros.Panic(zorros.Trace(err))) 64 } 65 } 66 67 /* 68 CollectionWriter is an abstraction of a collection creator 69 */ 70 type CollectionWriter struct { 71 wz *zip.Writer 72 k string 73 } 74 75 /* 76 Add an element to collection 77 */ 78 func (c *CollectionWriter) Add(name string, write func(io.Writer) error) error { 79 return c.add(name, false, write) 80 } 81 82 /* 83 Add an Lzma2 compressed element to collection 84 */ 85 func (c *CollectionWriter) AddLzma2(name string, write func(io.Writer) error) error { 86 return c.add(name, true, write) 87 } 88 89 func (c *CollectionWriter) add(name string, lzma2 bool, write func(io.Writer) error) error { 90 fname := c.k + "/" + name 91 fh := &zip.FileHeader{Name: fname, Method: zip.Deflate} 92 if lzma2 { 93 fh.Method = zip.Store 94 } 95 wr, err := c.wz.CreateHeader(fh) 96 if err != nil { 97 return zorros.Trace(err) 98 } 99 if lzma2 { 100 xw, err := xz.NewWriter(wr) 101 if err != nil { 102 return zorros.Trace(err) 103 } 104 if err = write(xw); err != nil { 105 return zorros.Trace(err) 106 } 107 if err = xw.Close(); err != nil { 108 return zorros.Trace(err) 109 } 110 } else { 111 if err = write(wr); err != nil { 112 return zorros.Trace(err) 113 } 114 } 115 return nil 116 } 117 118 /* 119 Objectify reads and reconstructs prediction models from a directory 120 */ 121 func Objectify(input iokit.Input, m ObjectifyMap) (pm map[string]PredictionModel, err error) { 122 var r *zip.Reader 123 f, err := input.Open() 124 if err != nil { 125 return 126 } 127 defer f.Close() 128 if r, err = zip.NewReader(f.(io.ReaderAt), iokit.FileSize(f)); err != nil { 129 return nil, zorros.Trace(err) 130 } 131 dict := map[string]map[string]iokit.Input{} 132 order := []string{} 133 for _, j := range r.File { 134 dir := filepath.Dir(j.Name) 135 if dir != "" && m[dir] != nil { 136 d, ok := dict[dir] 137 if !ok { 138 d = map[string]iokit.Input{} 139 dict[dir] = d 140 order = append(order, dir) 141 } 142 if j.Method == zip.Store { 143 d[filepath.Base(j.Name)] = iokit.Compressed(iokit.ZipFile(j.Name, input)) 144 } else { 145 d[filepath.Base(j.Name)] = iokit.ZipFile(j.Name, input) 146 } 147 } 148 } 149 pm = map[string]PredictionModel{} 150 for _, n := range order { 151 var v PredictionModel 152 f := m[n] 153 if v, err = f(dict[n]); err != nil { 154 return 155 } 156 pm[n] = v 157 } 158 return 159 }