go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/model/stash.go (about) 1 package model 2 3 import ( 4 "go-ml.dev/pkg/base/fu" 5 "go-ml.dev/pkg/iokit" 6 "go-ml.dev/pkg/zorros" 7 "io" 8 ) 9 10 type ModelStash struct { 11 iteration int 12 pattern string 13 files []iokit.TemporaryFile 14 } 15 16 func NewStash(histlen int, pattern string) *ModelStash { 17 return &ModelStash{ 18 pattern: pattern, 19 files: make([]iokit.TemporaryFile, histlen+1), 20 } 21 } 22 23 func (ms *ModelStash) Length() int { 24 return fu.Mini(ms.iteration+1, len(ms.files)) 25 } 26 27 func (ms *ModelStash) Output(iteration int) (out iokit.Output, err error) { 28 ms.iteration = iteration 29 f := ms.files[ms.iteration%len(ms.files)] 30 if f == nil { 31 if f, err = iokit.Tempfile(ms.pattern); err != nil { 32 return 33 } 34 ms.files[ms.iteration%len(ms.files)] = f 35 } else { 36 if err = f.Truncate(); err != nil { 37 return 38 } 39 } 40 return iokit.Writer(f.(io.Writer)), nil 41 } 42 43 func (ms *ModelStash) Reader(iteration int) (rd io.Reader, err error) { 44 if iteration > ms.iteration || (ms.iteration-iteration) > len(ms.files) { 45 return nil, zorros.Errorf("iteration %v is out of stash [%v,%v]", 46 iteration, 47 fu.Maxi(ms.iteration-len(ms.files), 0), 48 ms.iteration) 49 } 50 f := ms.files[iteration%len(ms.files)] 51 if err = f.Reset(); err != nil { 52 return 53 } 54 return f, nil 55 } 56 57 func (ms *ModelStash) Close() error { 58 for _, f := range ms.files { 59 if f != nil { 60 f.Close() 61 } 62 } 63 return nil 64 }