github.com/wlattner/mlserver@v0.0.0-20141113171038-895f261d2bfd/models.go

package main

import (
	"bytes"
	"encoding/json"
	"errors"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"code.google.com/p/go-uuid/uuid"
	"github.com/coreos/go-log/log"
)

// Prediction is the parsed result from the Python worker
type Prediction struct {
	ModelID string               `json:"model_id"`
	Labels  []map[string]float64 `json:"labels"`
}

// ModelReq represents an incoming request for fit or predict
type ModelReq struct {
	ModelID      string                   `json:"model_id"`
	Name         string                   `json:"name"`
	Date         time.Time                `json:"created_at"`
	Data         []map[string]interface{} `json:"data"`
	Labels       []interface{}            `json:"labels"`
	isRegression bool // unexported, so never encoded to or decoded from JSON
}

// Model represents a previously fitted model
type Model struct {
	ID       string `json:"model_id"`
	Metadata struct {
		Name string    `json:"name"`
		Date time.Time `json:"created_at"`
	} `json:"metadata"`
	Performance struct {
		Algorithm       string                        `json:"algorithm"`
		ConfusionMatrix map[string]map[string]float64 `json:"confusion_matrix,omitempty"`
		Score           float64                       `json:"score"`
	} `json:"performance"`
	runLock sync.RWMutex // protects the Running field
	Running bool         `json:"running"`
	Trained bool         `json:"trained"`
	// req and rep follow the zmq semantics for REQ/REP socket pairs:
	// data sent to the req channel is piped to the REQ socket connected
	// to the running Python process; replies from Python are piped to the
	// rep channel
	req, rep chan []byte
	dir      string    // path to the directory containing <model_id>.pkl and <model_id>.json
	cmd      *exec.Cmd // the running process
}
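
// Illustrative sketch (not part of the original file): the req/rep channels
// above assume a strict request/reply pairing, where each message sent on req
// is answered by exactly one message on rep. The hypothetical loop below
// stands in for the zmq-backed Python worker and returns a canned reply with
// made-up label names, which is enough to drive Predict without a worker.
func fakeWorkerLoop(m *Model) {
	for range m.req {
		// the real worker pipes the request to the REQ socket and waits for
		// the Python process; here a fixed single-row prediction is returned
		m.rep <- []byte(`[{"yes": 0.9, "no": 0.1}]`)
	}
}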

// Predict encodes the client supplied data, passes it to the Python process for
// the model via zmq, parses and returns the response.
func (m *Model) Predict(r ModelReq) Prediction {
	// should find a way to do this w/o re-encoding
	var buf bytes.Buffer
	err := json.NewEncoder(&buf).Encode(r)
	if err != nil {
		log.Error("error encoding prediction ", err)
		return Prediction{}
	}

	if m.req == nil {
		log.Errorf("request chan for model %v is nil", m.ID)
		return Prediction{}
	}
	m.req <- buf.Bytes()
	resp := <-m.rep

	var pred []map[string]float64
	err = json.NewDecoder(bytes.NewReader(resp)).Decode(&pred)
	if err != nil {
		log.Error("error decoding prediction ", err)
	}

	prediction := Prediction{
		ModelID: r.ModelID,
		Labels:  pred,
	}

	return prediction
}
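
// Illustrative call to Predict (not part of the original file). The feature
// names and values below are hypothetical, and the sketch assumes the model's
// worker is already running so the req/rep channels are wired up.
func examplePredictCall(m *Model) Prediction {
	req := ModelReq{
		ModelID: m.ID,
		Data: []map[string]interface{}{
			{"sepal_length": 5.1, "sepal_width": 3.5}, // one row of features
		},
	}
	return m.Predict(req)
}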

// Stop sends SIGINT to the underlying process running the model
func (m *Model) Stop() error {
	if m.cmd != nil {
		return m.cmd.Process.Signal(os.Interrupt)
	}
	return nil
}
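
// Illustrative shutdown sketch (not part of the original file): interrupt the
// worker process of every model in a repo (ModelRepo and All are defined
// below in this file).
func exampleStopAll(repo *ModelRepo) {
	for _, m := range repo.All() {
		if err := m.Stop(); err != nil {
			log.Error("error stopping model ", err)
		}
	}
}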

// ModelRepo represents a collection of models
type ModelRepo struct {
	sync.RWMutex
	collection map[string]*Model
	path       string
}

// NewModelRepo initializes and returns a pointer to a ModelRepo; the supplied
// path argument refers to the directory where pickled models will be saved.
func NewModelRepo(path string) *ModelRepo {
	return &ModelRepo{
		collection: make(map[string]*Model),
		path:       path,
	}
}

// Add inserts a model into the model collection
func (r *ModelRepo) Add(m *Model) {
	r.Lock()
	defer r.Unlock()
	r.collection[m.ID] = m
}

// Remove deletes a model from the model collection
func (r *ModelRepo) Remove(id string) {
	r.Lock()
	defer r.Unlock()
	// TODO: make sure the python process has exited or kill
	// prior to delete
	delete(r.collection, id)
}

// NewModel initializes a model with a generated ID and dir
func (r *ModelRepo) NewModel() *Model {
	id := uuid.New()
	m := Model{ID: id, dir: filepath.Join(r.path, id)}
	return &m
}
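
// Illustrative setup sketch (not part of the original file): create a repo
// rooted at a hypothetical model directory, generate a fresh model, and
// register it in the collection.
func exampleRepoSetup() *Model {
	repo := NewModelRepo("/var/lib/mlserver/models") // hypothetical path
	m := repo.NewModel()
	repo.Add(m)
	return m
}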

// All returns a slice of all models currently in the collection
func (r *ModelRepo) All() []*Model {
	var models []*Model

	r.RLock()
	for _, model := range r.collection {
		models = append(models, model)
	}
	r.RUnlock()

	return models
}

// ErrModelNotFound is returned when the model can't be found in the model dir
var ErrModelNotFound = errors.New("model not found")

// Get fetches a model by id. If the model is not present in the collection, it
// will attempt to load the model from disk and add it to the collection. If the
// model is not in the model directory, Get will return ErrModelNotFound.
func (r *ModelRepo) Get(id string) (*Model, error) {
	r.RLock()
	m, ok := r.collection[id]
	r.RUnlock()

	var err error
	if !ok {
		m, err = r.LoadModelData(id)
		if err != nil {
			return nil, err
		}
	}

	r.Add(m) // add to cache

	// start/restart if not running
	m.runLock.Lock() // make sure we don't start twice
	defer m.runLock.Unlock()
	if !m.Running {
		err = startModel(m)
		if err != nil {
			return nil, err
		}
	}

	return m, nil
}
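
// Illustrative lookup (not part of the original file): fetch a model by id,
// treating a missing model differently from other failures.
func exampleGetModel(repo *ModelRepo, id string) (*Model, bool) {
	m, err := repo.Get(id)
	if err == ErrModelNotFound {
		return nil, false // no such model in the collection or on disk
	}
	if err != nil {
		log.Error("error loading model ", err)
		return nil, false
	}
	return m, true
}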

// LoadModelData loads the model metadata from the file
// <path>/<model_id>/<model_id>.json; if the file does not exist, ErrModelNotFound
// is returned. The json file is expected to contain the model metadata along with
// the score, confusion matrix, and algorithm used, see Model.Metadata and
// Model.Performance. The loaded model is added to the collection.
func (r *ModelRepo) LoadModelData(id string) (*Model, error) {
	// check the collection first
	r.RLock()
	m, ok := r.collection[id]
	r.RUnlock()
	if !ok { // not currently loaded
		modelDir := filepath.Join(r.path, id)
		f, err := os.Open(filepath.Join(modelDir, id+".json"))
		if os.IsNotExist(err) {
			return nil, ErrModelNotFound
		}

		if err != nil {
			return nil, err
		}
		defer f.Close()

		var loaded Model
		err = json.NewDecoder(f).Decode(&loaded)
		if err != nil {
			return nil, err
		}
		loaded.dir = modelDir
		loaded.Trained = true

		r.Add(&loaded) // add to cache
		m = &loaded
	}

	return m, nil
}

// IndexModelDir scans the model directory and loads metadata for every model
// saved there, adding each to the collection.
func (r *ModelRepo) IndexModelDir() error {
	models, err := filepath.Glob(filepath.Join(r.path, "/*"))
	if err != nil {
		return err
	}

	for _, model := range models {
		modelID := strings.TrimPrefix(model, r.path+"/")
		if _, err := r.LoadModelData(modelID); err != nil {
			log.Errorf("error loading model %v: %v", modelID, err)
		}
	}
	return nil
}
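
// Illustrative startup sketch (not part of the original file): build a repo
// over a hypothetical model directory and pre-load metadata for every model
// already saved there.
func exampleIndexAtStartup() (*ModelRepo, error) {
	repo := NewModelRepo("/var/lib/mlserver/models") // hypothetical path
	if err := repo.IndexModelDir(); err != nil {
		return nil, err
	}
	return repo, nil
}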