github.com/wlattner/mlserver@v0.0.0-20141113171038-895f261d2bfd/workers.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"os"
     9  	"os/exec"
    10  	"path/filepath"
    11  	"strings"
    12  
    13  	"github.com/coreos/go-log/log"
    14  	zmq "github.com/pebbe/zmq4"
    15  )
    16  
    17  // fitModel writes the training data in json format to a temporary file. Next
    18  // it launches the fit.py in a child process, passing the filename of the trainig
    19  // data and the location where the model should be saved as arguments. Since we
    20  // do not know the path the app will be run, we instruct python to read the fit.py
    21  // source from stdin instead of executing a file. This would be equivalent to:
    22  //
    23  // 	$ python3 - < fit.py tmp.json models/model-id
    24  //
    25  // The source for fit.py as encoded as a raw/formatted string in the file
    26  // fit_py.go
    27  //
    28  // When the command completes, go checks the exit status, anything other than exit(0)
    29  // will result in a non-nil value for the error returned by cmd.Run().
    30  
    31  func fitModel(m *Model, d ModelReq, r *ModelRepo) {
    32  	log.Infof("started fitting model %v", m.ID)
    33  	// write data to temp file
    34  	f, err := ioutil.TempFile("", m.ID)
    35  	if err != nil {
    36  		log.Error("unable to open temp file for fitting model ", err)
    37  		return
    38  	}
    39  	defer os.Remove(f.Name())
    40  
    41  	err = json.NewEncoder(f).Encode(d)
    42  	if err != nil {
    43  		log.Error("error encoding training data ", err)
    44  		return
    45  	}
    46  	f.Close()
    47  
    48  	cmd := exec.Command("python3", "-", m.dir, f.Name())
    49  	cmd.Stdin = strings.NewReader(fitPy)
    50  	var stderr bytes.Buffer
    51  	cmd.Stderr = &stderr
    52  
    53  	err = cmd.Run()
    54  	if err != nil {
    55  		log.Errorf("error fitting model %v: %v %v", m.ID, err.Error(), stderr.String())
    56  	}
    57  
    58  	// load the model into the index after fitted
    59  	_, err = r.LoadModelData(m.ID)
    60  	if err != nil {
    61  		log.Errorf("error loading model %v: %v", m.ID, err.Error())
    62  	}
    63  	log.Infof("finished fitting model %v", m.ID)
    64  }
    65  
    66  // startModel launches the prediction script for a model in a child process.
    67  //
    68  // Requests and responses between Go and the prediction process occur via a zmq
    69  // REQ/REP socket pair. The ipc socket path and model file name are passed to the
    70  // python script as command line args. On startup, predicy.py loads the model and
    71  // binds a REP socket to the provided ipc path. The script than starts a loop,
    72  // reading data from the the socket, returning predicitons back over the socket.
    73  // On the Go side, one goroutine manages the running python process, (it doesn't
    74  // really do much, just sets the Running attribute to false on exit), another
    75  // goroutine accepts requests via the model's req chan, forwards these to the REQ
    76  // socket, reads the python response, and forwards these to the model's rep chan.
    77  func startModel(m *Model) error {
    78  
    79  	// create channels and set running flag
    80  	m.req = make(chan []byte)
    81  	m.rep = make(chan []byte)
    82  	m.Running = true
    83  
    84  	socketPath := fmt.Sprint("ipc:///tmp/", m.ID)
    85  
    86  	socket, err := zmq.NewSocket(zmq.REQ)
    87  	if err != nil {
    88  		return err
    89  	}
    90  
    91  	fileName := fmt.Sprintf("%s.pkl", m.ID)
    92  
    93  	cmd := exec.Command("python3", "-", socketPath, filepath.Join(m.dir, fileName))
    94  	cmd.Stdin = strings.NewReader(predictPy)
    95  	var stderr bytes.Buffer
    96  	cmd.Stderr = &stderr
    97  
    98  	m.cmd = cmd // attach cmd to the model object
    99  
   100  	// run the predict.py in a dedicated goroutine, this function will return
   101  	// when predict.py exits
   102  	go func() {
   103  		defer func() {
   104  			m.runLock.Lock()
   105  			m.Running = false
   106  			m.runLock.Unlock()
   107  			close(m.req) // no more requests after process exits
   108  		}()
   109  
   110  		log.Infof("starting model %v", m.ID)
   111  		err := cmd.Run()
   112  		if err != nil {
   113  			log.Errorf("model %v exited: %v %v", m.ID, err.Error(), stderr.String())
   114  			return
   115  		}
   116  		// fit.py exited normally
   117  		log.Infof("model %v exited", m.ID)
   118  	}()
   119  
   120  	err = socket.Connect(socketPath)
   121  	if err != nil {
   122  		return err
   123  	}
   124  
   125  	// forward requests sent to model.req channel to the zeromq REQ socket,
   126  	// read the response from zeromq and push to model.rep channel, the loop
   127  	// will run until model.req is closed by the goroutine running predict.py
   128  	go func() {
   129  		for request := range m.req {
   130  			_, err := socket.SendBytes(request, 0)
   131  			if err != nil {
   132  				log.Errorf("error sending data to model %v: %v", m.ID, err.Error())
   133  			}
   134  			resp, err := socket.RecvBytes(0)
   135  			if err != nil {
   136  				log.Errorf("error receiving data from model %v: %v", m.ID, err.Error())
   137  			}
   138  			m.rep <- resp
   139  		}
   140  		close(m.rep) // no more replies after req closed
   141  	}()
   142  
   143  	return nil
   144  }