github.com/wlattner/mlserver@v0.0.0-20141113171038-895f261d2bfd/api_handler.go (about)

     1  package main
     2  
     3  import (
     4  	"encoding/json"
     5  	"net/http"
     6  	"path/filepath"
     7  )
     8  
     9  type server struct {
    10  	*ModelRepo
    11  }
    12  
    13  // NewAPIHandler returns an http.Handler for responding to api requests to
    14  // mlserver. The ModelRepo parameter should be a pointer to an initialized
    15  // and indexed ModelRepo.
    16  func NewAPIHandler(r *ModelRepo) http.Handler {
    17  	s := &server{r}
    18  
    19  	m := http.NewServeMux()
    20  	m.HandleFunc("/models", s.HandleModels)
    21  	m.HandleFunc("/models/", s.HandleModel)
    22  	m.HandleFunc("/models/running", s.HandleRunningModels)
    23  	m.HandleFunc("/models/running/", s.HandleStopModel)
    24  
    25  	return m
    26  }
    27  
    28  // HandleModel is the http handler for requests made to /models/<id>, GET
    29  // returns the model status, PUT/POST return predictions by the model. Other
    30  // HTTP methods result in a Method Not Allowed response.
    31  func (s *server) HandleModel(w http.ResponseWriter, r *http.Request) {
    32  	modelID := filepath.Base(r.URL.Path)
    33  
    34  	switch r.Method {
    35  	case "GET": // status
    36  		m, err := s.LoadModelData(modelID)
    37  		if err == ErrModelNotFound {
    38  			http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
    39  			return
    40  		}
    41  		if err != nil {
    42  			http.Error(w, err.Error(), http.StatusInternalServerError)
    43  			return
    44  		}
    45  
    46  		writeJSONOK(w, m)
    47  
    48  	case "PUT", "POST": // predict
    49  		var err error
    50  
    51  		m, err := s.Get(modelID)
    52  		if err == ErrModelNotFound {
    53  			http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
    54  			return
    55  		}
    56  		if err != nil {
    57  			http.Error(w, err.Error(), http.StatusInternalServerError)
    58  			return
    59  		}
    60  
    61  		newData, err := parseFitPredictRequest(r, false)
    62  		if err != nil {
    63  			http.Error(w, err.Error(), http.StatusBadRequest)
    64  			return
    65  		}
    66  		newData.ModelID = modelID
    67  
    68  		pred := m.Predict(newData)
    69  		writeJSONOK(w, pred)
    70  
    71  	default:
    72  		notAllowed(w)
    73  	}
    74  
    75  }
    76  
    77  // HandleModels is the http handler for requests made to /models, POST
    78  // fits a new model with the supplied data. Data for fitting the model can
    79  // be encoded as JSON in the request body or uploaded as a csv file. GET responds
    80  // with a list of all models in the index. Other HTTP methods result in a
    81  // Method Not Allowed response.
    82  func (s *server) HandleModels(w http.ResponseWriter, r *http.Request) {
    83  	switch r.Method {
    84  	case "GET": // list models
    85  		writeJSONOK(w, s.All())
    86  
    87  	case "POST": // new model
    88  
    89  		trainData, err := parseFitPredictRequest(r, true)
    90  		if err != nil {
    91  			http.Error(w, err.Error(), http.StatusBadRequest)
    92  			return
    93  		}
    94  
    95  		m := s.NewModel()
    96  		go fitModel(m, trainData, s.ModelRepo)
    97  
    98  		resp := struct {
    99  			ModelID string `json:"model_id"`
   100  		}{
   101  			m.ID,
   102  		}
   103  
   104  		writeJSON(w, resp, http.StatusAccepted)
   105  
   106  	default:
   107  		notAllowed(w)
   108  	}
   109  }
   110  
   111  // HandleRunningModels accepts GET, PUT, POST requests made to /models/running
   112  // GET response with all running models, PUT/POST will start a model, the modelID
   113  // should be passed as a json encoded object in the body of the request. All other
   114  // methods result in a Method Not Allowed response.
   115  func (s *server) HandleRunningModels(w http.ResponseWriter, r *http.Request) {
   116  	switch r.Method {
   117  	case "GET": // list running models
   118  		models := s.All()
   119  		runningModels := []*Model{}
   120  		for _, model := range models {
   121  			model.runLock.RLock()
   122  			if model.Running {
   123  				runningModels = append(runningModels, model)
   124  			}
   125  			model.runLock.RUnlock()
   126  		}
   127  		writeJSONOK(w, runningModels)
   128  
   129  	case "PUT", "POST": // start a model
   130  		var msg struct {
   131  			ModelID string `json:"model_id"`
   132  		}
   133  		err := json.NewDecoder(r.Body).Decode(&msg)
   134  		if err != nil {
   135  			http.Error(w, err.Error(), http.StatusBadRequest)
   136  			return
   137  		}
   138  
   139  		_, err = s.Get(msg.ModelID)
   140  		if err == ErrModelNotFound {
   141  			http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
   142  			return
   143  		}
   144  		if err != nil {
   145  			http.Error(w, err.Error(), http.StatusInternalServerError)
   146  			return
   147  		}
   148  
   149  		w.WriteHeader(http.StatusCreated)
   150  
   151  	default:
   152  		notAllowed(w)
   153  	}
   154  }
   155  
   156  // HandleStopModel accepts DELETE requests made to /models/running/<id> and stops
   157  // the model if it's currently running. All other request methods result in a
   158  // Method Not Allowed response. If the model is not found, it will return 404
   159  func (s *server) HandleStopModel(w http.ResponseWriter, r *http.Request) {
   160  	if r.Method != "DELETE" {
   161  		notAllowed(w)
   162  		return
   163  	}
   164  
   165  	modelID := filepath.Base(r.URL.Path)
   166  	m, err := s.LoadModelData(modelID)
   167  	if err == ErrModelNotFound {
   168  		http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
   169  		return
   170  	}
   171  	if err != nil {
   172  		http.Error(w, err.Error(), http.StatusInternalServerError)
   173  		return
   174  	}
   175  
   176  	err = m.Stop()
   177  	if err != nil {
   178  		http.Error(w, err.Error(), http.StatusInternalServerError)
   179  		return
   180  	}
   181  
   182  	w.WriteHeader(http.StatusAccepted)
   183  }