go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/model/stash.go (about)

     1  package model
     2  
     3  import (
     4  	"go-ml.dev/pkg/base/fu"
     5  	"go-ml.dev/pkg/iokit"
     6  	"go-ml.dev/pkg/zorros"
     7  	"io"
     8  )
     9  
// ModelStash is a fixed-size ring of temporary files that holds the outputs
// of the most recent iterations. Iteration i is stored in slot
// i % len(files), so a slot is reused once the ring wraps around.
type ModelStash struct {
	iteration int                   // latest iteration recorded by Output; Reader never accepts a later one
	pattern   string                // passed to iokit.Tempfile when a slot's file is first created
	files     []iokit.TemporaryFile // ring of slots; an entry stays nil until Output first uses it
}
    15  
    16  func NewStash(histlen int, pattern string) *ModelStash {
    17  	return &ModelStash{
    18  		pattern: pattern,
    19  		files:   make([]iokit.TemporaryFile, histlen+1),
    20  	}
    21  }
    22  
    23  func (ms *ModelStash) Length() int {
    24  	return fu.Mini(ms.iteration+1, len(ms.files))
    25  }
    26  
    27  func (ms *ModelStash) Output(iteration int) (out iokit.Output, err error) {
    28  	ms.iteration = iteration
    29  	f := ms.files[ms.iteration%len(ms.files)]
    30  	if f == nil {
    31  		if f, err = iokit.Tempfile(ms.pattern); err != nil {
    32  			return
    33  		}
    34  		ms.files[ms.iteration%len(ms.files)] = f
    35  	} else {
    36  		if err = f.Truncate(); err != nil {
    37  			return
    38  		}
    39  	}
    40  	return iokit.Writer(f.(io.Writer)), nil
    41  }
    42  
    43  func (ms *ModelStash) Reader(iteration int) (rd io.Reader, err error) {
    44  	if iteration > ms.iteration || (ms.iteration-iteration) > len(ms.files) {
    45  		return nil, zorros.Errorf("iteration %v is out of stash [%v,%v]",
    46  			iteration,
    47  			fu.Maxi(ms.iteration-len(ms.files), 0),
    48  			ms.iteration)
    49  	}
    50  	f := ms.files[iteration%len(ms.files)]
    51  	if err = f.Reset(); err != nil {
    52  		return
    53  	}
    54  	return f, nil
    55  }
    56  
    57  func (ms *ModelStash) Close() error {
    58  	for _, f := range ms.files {
    59  		if f != nil {
    60  			f.Close()
    61  		}
    62  	}
    63  	return nil
    64  }