// Package main: model types and serialization contracts shared by the
// fit/predict HTTP handlers and the Python worker processes.
package main

import (
	"bytes"
	"encoding/json"
	"errors"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"code.google.com/p/go-uuid/uuid"
	"github.com/coreos/go-log/log"
)

// Prediction is the parsed result from the Python worker.
type Prediction struct {
	ModelID string `json:"model_id"`
	// Labels holds the worker's decoded output; presumably one
	// label->score map per input row — confirm against worker output.
	Labels []map[string]float64 `json:"labels"`
}

// ModelReq represents an incoming request for fit or predict.
type ModelReq struct {
	ModelID string                   `json:"model_id"`
	Name    string                   `json:"name"`
	Date    time.Time                `json:"created_at"`
	Data    []map[string]interface{} `json:"data"`
	Labels  []interface{}            `json:"labels"`
	// isRegression is set server-side and never serialized
	// (unexported, so encoding/json ignores it).
	isRegression bool
}

// Model represents a previously fitted model.
type Model struct {
	ID       string `json:"model_id"`
	Metadata struct {
		Name string    `json:"name"`
		Date time.Time `json:"created_at"`
	} `json:"metadata"`
	Performance struct {
		Algorithm       string                        `json:"algorithm"`
		ConfusionMatrix map[string]map[string]float64 `json:"confusion_matrix,omitempty"`
		Score           float64                       `json:"score"`
	} `json:"performance"`
	runLock sync.RWMutex // protect running attribute
	Running bool         `json:"running"`
	Trained bool         `json:"trained"`
	// req and rep follow the zmq semantics for REQ/REP socket pairs,
	// data sent to the req channel is piped to the REQ socket connected
	// to the running Python process, replies from Python are piped to the
	// rep channel
	req, rep chan []byte
	dir      string    // path to the directory containing <model_id>.pkl and <model_id>.json
	cmd      *exec.Cmd // the running process
}

// Predict encodes the client supplied data, passes it to the Python process for
// the model via zmq, parses and returns the response.
60 func (m *Model) Predict(r ModelReq) Prediction { 61 // should find a way to do this w/o re-encoding 62 var buf bytes.Buffer 63 err := json.NewEncoder(&buf).Encode(r) 64 if err != nil { 65 log.Error("error encoding prediction ", err) 66 return Prediction{} 67 } 68 69 if m.req == nil { 70 log.Errorf("request chan for model %v is nil", m.ID) 71 return Prediction{} 72 } 73 m.req <- buf.Bytes() 74 resp := <-m.rep 75 76 var pred []map[string]float64 77 err = json.NewDecoder(bytes.NewReader(resp)).Decode(&pred) 78 if err != nil { 79 log.Error("error decoding prediction ", err) 80 } 81 82 prediction := Prediction{ 83 ModelID: r.ModelID, 84 Labels: pred, 85 } 86 87 return prediction 88 } 89 90 // Stop sends SIGINT to the underlying process running the model 91 func (m *Model) Stop() error { 92 if m.cmd != nil { 93 return m.cmd.Process.Signal(os.Interrupt) 94 } 95 return nil 96 } 97 98 // ModelRepo represents a collection of models 99 type ModelRepo struct { 100 sync.RWMutex 101 collection map[string]*Model 102 path string 103 } 104 105 // NewModelRepo initializes and returns a pointer to a ModelRepo, the supplied 106 // path argument refers to the directory where pickled models will be saved. 
107 func NewModelRepo(path string) *ModelRepo { 108 return &ModelRepo{ 109 collection: make(map[string]*Model), 110 path: path, 111 } 112 } 113 114 // Add inserts a model into the model collection 115 func (r *ModelRepo) Add(m *Model) { 116 r.Lock() 117 defer r.Unlock() 118 r.collection[m.ID] = m 119 } 120 121 // Remove deletes a model from the model collection 122 func (r *ModelRepo) Remove(id string) { 123 r.Lock() 124 defer r.Unlock() 125 // TODO: make sure the python process has exited or kill 126 // prior to delete 127 delete(r.collection, id) 128 } 129 130 // NewModel initializes a model with a generated ID and dir 131 func (r *ModelRepo) NewModel() *Model { 132 id := uuid.New() 133 m := Model{ID: id, dir: filepath.Join(r.path, id)} 134 return &m 135 } 136 137 // All returns a slice of all models currently in the collection 138 func (r *ModelRepo) All() []*Model { 139 var models []*Model 140 141 r.RLock() 142 for _, model := range r.collection { 143 models = append(models, model) 144 } 145 r.RUnlock() 146 147 return models 148 } 149 150 // ErrModelNotFound is returned when the model can't be found in the model dir 151 var ErrModelNotFound = errors.New("model not found") 152 153 // Get fetches a model by id, if the model is not present in the collection, it 154 // will attempt to load from disk adding it to the collection. If the model is 155 // not in the model directory, Get will return ErrModelNotFound. 
156 func (r *ModelRepo) Get(id string) (*Model, error) { 157 r.RLock() 158 m, ok := r.collection[id] 159 r.RUnlock() 160 161 var err error 162 if !ok { 163 m, err = r.LoadModelData(id) 164 if err != nil { 165 return nil, err 166 } 167 } 168 169 r.Add(m) // add to cache 170 171 // start/restart if not running 172 m.runLock.Lock() // make sure we don't start twice 173 defer m.runLock.Unlock() 174 if !m.Running { 175 err = startModel(m) 176 if err != nil { 177 return nil, err 178 } 179 } 180 181 return m, nil 182 } 183 184 // LoadModelData loads the model metadata from the file 185 // <path>/<model_id>/<model_id>.json, if the file does not exist, ErrModelNotFound 186 // is returned. The json file is expected to contain the model score, confusion matrix, 187 // and algorithm used, see Model.Metadata. The loaded model is added to the collection. 188 func (r *ModelRepo) LoadModelData(id string) (*Model, error) { 189 // check the collection first 190 r.RLock() 191 m, ok := r.collection[id] 192 r.RUnlock() 193 if !ok { // not currently loaded 194 modelDir := filepath.Join(r.path, id) 195 f, err := os.Open(filepath.Join(modelDir, id+".json")) 196 if os.IsNotExist(err) { 197 return nil, ErrModelNotFound 198 } 199 200 if err != nil { 201 return nil, err 202 } 203 204 var m Model 205 err = json.NewDecoder(f).Decode(&m) 206 if err != nil { 207 return nil, err 208 } 209 m.dir = modelDir 210 m.Trained = true 211 212 r.Add(&m) // add to cache 213 } 214 215 return m, nil 216 } 217 218 func (r *ModelRepo) IndexModelDir() error { 219 models, err := filepath.Glob(filepath.Join(r.path, "/*")) 220 if err != nil { 221 return err 222 } 223 224 for _, model := range models { 225 modelID := strings.TrimPrefix(model, r.path+"/") 226 r.LoadModelData(modelID) 227 } 228 return nil 229 }