github.com/wlattner/mlserver@v0.0.0-20141113171038-895f261d2bfd/workers.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "io/ioutil" 8 "os" 9 "os/exec" 10 "path/filepath" 11 "strings" 12 13 "github.com/coreos/go-log/log" 14 zmq "github.com/pebbe/zmq4" 15 ) 16 17 // fitModel writes the training data in json format to a temporary file. Next 18 // it launches the fit.py in a child process, passing the filename of the trainig 19 // data and the location where the model should be saved as arguments. Since we 20 // do not know the path the app will be run, we instruct python to read the fit.py 21 // source from stdin instead of executing a file. This would be equivalent to: 22 // 23 // $ python3 - < fit.py tmp.json models/model-id 24 // 25 // The source for fit.py as encoded as a raw/formatted string in the file 26 // fit_py.go 27 // 28 // When the command completes, go checks the exit status, anything other than exit(0) 29 // will result in a non-nil value for the error returned by cmd.Run(). 30 31 func fitModel(m *Model, d ModelReq, r *ModelRepo) { 32 log.Infof("started fitting model %v", m.ID) 33 // write data to temp file 34 f, err := ioutil.TempFile("", m.ID) 35 if err != nil { 36 log.Error("unable to open temp file for fitting model ", err) 37 return 38 } 39 defer os.Remove(f.Name()) 40 41 err = json.NewEncoder(f).Encode(d) 42 if err != nil { 43 log.Error("error encoding training data ", err) 44 return 45 } 46 f.Close() 47 48 cmd := exec.Command("python3", "-", m.dir, f.Name()) 49 cmd.Stdin = strings.NewReader(fitPy) 50 var stderr bytes.Buffer 51 cmd.Stderr = &stderr 52 53 err = cmd.Run() 54 if err != nil { 55 log.Errorf("error fitting model %v: %v %v", m.ID, err.Error(), stderr.String()) 56 } 57 58 // load the model into the index after fitted 59 _, err = r.LoadModelData(m.ID) 60 if err != nil { 61 log.Errorf("error loading model %v: %v", m.ID, err.Error()) 62 } 63 log.Infof("finished fitting model %v", m.ID) 64 } 65 66 // startModel launches the prediction script for a model in a child process. 67 // 68 // Requests and responses between Go and the prediction process occur via a zmq 69 // REQ/REP socket pair. The ipc socket path and model file name are passed to the 70 // python script as command line args. On startup, predicy.py loads the model and 71 // binds a REP socket to the provided ipc path. The script than starts a loop, 72 // reading data from the the socket, returning predicitons back over the socket. 73 // On the Go side, one goroutine manages the running python process, (it doesn't 74 // really do much, just sets the Running attribute to false on exit), another 75 // goroutine accepts requests via the model's req chan, forwards these to the REQ 76 // socket, reads the python response, and forwards these to the model's rep chan. 77 func startModel(m *Model) error { 78 79 // create channels and set running flag 80 m.req = make(chan []byte) 81 m.rep = make(chan []byte) 82 m.Running = true 83 84 socketPath := fmt.Sprint("ipc:///tmp/", m.ID) 85 86 socket, err := zmq.NewSocket(zmq.REQ) 87 if err != nil { 88 return err 89 } 90 91 fileName := fmt.Sprintf("%s.pkl", m.ID) 92 93 cmd := exec.Command("python3", "-", socketPath, filepath.Join(m.dir, fileName)) 94 cmd.Stdin = strings.NewReader(predictPy) 95 var stderr bytes.Buffer 96 cmd.Stderr = &stderr 97 98 m.cmd = cmd // attach cmd to the model object 99 100 // run the predict.py in a dedicated goroutine, this function will return 101 // when predict.py exits 102 go func() { 103 defer func() { 104 m.runLock.Lock() 105 m.Running = false 106 m.runLock.Unlock() 107 close(m.req) // no more requests after process exits 108 }() 109 110 log.Infof("starting model %v", m.ID) 111 err := cmd.Run() 112 if err != nil { 113 log.Errorf("model %v exited: %v %v", m.ID, err.Error(), stderr.String()) 114 return 115 } 116 // fit.py exited normally 117 log.Infof("model %v exited", m.ID) 118 }() 119 120 err = socket.Connect(socketPath) 121 if err != nil { 122 return err 123 } 124 125 // forward requests sent to model.req channel to the zeromq REQ socket, 126 // read the response from zeromq and push to model.rep channel, the loop 127 // will run until model.req is closed by the goroutine running predict.py 128 go func() { 129 for request := range m.req { 130 _, err := socket.SendBytes(request, 0) 131 if err != nil { 132 log.Errorf("error sending data to model %v: %v", m.ID, err.Error()) 133 } 134 resp, err := socket.RecvBytes(0) 135 if err != nil { 136 log.Errorf("error receiving data from model %v: %v", m.ID, err.Error()) 137 } 138 m.rep <- resp 139 } 140 close(m.rep) // no more replies after req closed 141 }() 142 143 return nil 144 }