github.com/wlattner/mlserver@v0.0.0-20141113171038-895f261d2bfd/parse.go (about)

     1  package main
     2  
     3  import (
     4  	"encoding/csv"
     5  	"encoding/json"
     6  	"errors"
     7  	"io"
     8  	"net/http"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/coreos/go-log/log"
    13  )
    14  
    15  // ParseJSON parses a JSON encoded request:
    16  //
    17  //		{
    18  //			"name": "iris model",
    19  //			"data": [
    20  //				{
    21  //					"var_1": 2.5,
    22  //					"var_2": 3.6,
    23  //					...
    24  //				},
    25  //				...
    26  //			],
    27  //			"labels": [
    28  //				"yes",
    29  //				"no",
    30  //				...
    31  //			]
    32  //		}
    33  //
    34  // into a ModelReq struct. If the hasTarget arg is true, ParseJSON will also set
    35  // the isRegression attribute if the returned ModelReq if all the values in the
    36  // label slice can be parsed as floats.
    37  func ParseJSON(r io.Reader, hasTarget bool) (ModelReq, error) {
    38  	var d ModelReq
    39  	err := json.NewDecoder(r).Decode(&d)
    40  	if err != nil {
    41  		return ModelReq{}, err
    42  	}
    43  
    44  	// the json decoder will correctly parse string vs float for the label slice
    45  	// check a few values to determine if this is a regression or classification
    46  	// task
    47  	if hasTarget {
    48  		allFloats := true
    49  		for _, val := range d.Labels {
    50  			_, ok := val.(float64)
    51  			if !ok {
    52  				allFloats = false
    53  				break
    54  			}
    55  		}
    56  		d.isRegression = allFloats
    57  	}
    58  
    59  	return d, nil
    60  }
    61  
    62  // ParseCSV parses a csv file with the following format:
    63  //
    64  //		<target_var>,<var_1>,<var_2>,...<var_n>
    65  //		"true",1.5,"red",...
    66  //
    67  // returning a slice of maps representing the feature:value pairs for each row,
    68  // a slice of labels, and an error. If the hasTarget flag is true, the first
    69  // column of input data will be copied to the label slice and excluded from the
    70  // feature:value pairs. If hasTarget is false, the label slice will be empty and
    71  // all columns will be included in the feature:value pairs.
    72  func ParseCSV(r io.Reader, hasTarget bool) (ModelReq, error) {
    73  	reader := csv.NewReader(r)
    74  
    75  	// grab the var names from the first row
    76  	fieldNames, err := reader.Read()
    77  	if err != nil {
    78  		return ModelReq{}, err
    79  	}
    80  
    81  	xStart := 0 // column where feature data starts
    82  	if hasTarget {
    83  		xStart = 1
    84  	}
    85  
    86  	var d ModelReq
    87  	allFloats := true // regression if all labels are floats, classification otherwise
    88  
    89  	for {
    90  		row, err := reader.Read()
    91  		if err == io.EOF {
    92  			break
    93  		}
    94  		if err != nil {
    95  			return ModelReq{}, err
    96  		}
    97  
    98  		if len(row) != len(fieldNames) {
    99  			return ModelReq{}, errors.New("mlserver: csv header and row length mismatch")
   100  		}
   101  
   102  		if hasTarget { // first column is the target variable
   103  			// check if float
   104  			numVal, err := strconv.ParseFloat(row[0], 64)
   105  			if err != nil {
   106  				d.Labels = append(d.Labels, row[0])
   107  				allFloats = false
   108  			} else {
   109  				d.Labels = append(d.Labels, numVal)
   110  			}
   111  		}
   112  
   113  		// save the rest as <feature_name>:<value> pairs
   114  		features := make(map[string]interface{})
   115  		for i := xStart; i < len(row); i++ {
   116  			val := row[i]
   117  			// check for numeric value
   118  			numVal, err := strconv.ParseFloat(row[i], 64)
   119  			if err != nil {
   120  				features[fieldNames[i]] = val // use string val
   121  			} else {
   122  				features[fieldNames[i]] = numVal // use numeric val
   123  			}
   124  		}
   125  		d.Data = append(d.Data, features)
   126  	}
   127  
   128  	if hasTarget {
   129  		d.isRegression = allFloats
   130  	}
   131  
   132  	return d, nil
   133  }
   134  
   135  // parseFileUpload parses ModelReq from a csv file uploaded in a POST request.
   136  // the hasTarget arg should be true when the uploaded csv file has the target
   137  // variable in the first column (i.e. when parsing a request for fitting a model).
   138  // ErrCSVFileMissing will be returned if there is no file associated with the key
   139  // 'file'.
   140  func parseFileUpload(r *http.Request, hasTarget bool) (ModelReq, error) {
   141  
   142  	err := r.ParseMultipartForm(1 << 28)
   143  	if err != nil {
   144  		return ModelReq{}, err
   145  	}
   146  
   147  	defer func() {
   148  		err := r.MultipartForm.RemoveAll()
   149  		if err != nil {
   150  			log.Error("error removing file uploads ", err)
   151  		}
   152  	}()
   153  
   154  	files, ok := r.MultipartForm.File["file"]
   155  	if !ok || len(files) < 1 {
   156  		return ModelReq{}, errors.New("csv file missing")
   157  	}
   158  
   159  	f, err := files[0].Open()
   160  	if err != nil {
   161  		return ModelReq{}, err
   162  	}
   163  	defer f.Close()
   164  
   165  	d, err := ParseCSV(f, hasTarget)
   166  	if err != nil {
   167  		return ModelReq{}, err
   168  	}
   169  
   170  	d.Name = strings.Join(r.MultipartForm.Value["name"], " ")
   171  
   172  	return d, nil
   173  }
   174  
   175  // parseFitPredictRequest parses an http request into a ModelReq struct. The appropriate
   176  // parser (json or csv) is determined from the content-type.
   177  func parseFitPredictRequest(r *http.Request, isFitReq bool) (ModelReq, error) {
   178  	if r.Header.Get("Content-Type") == "application/json" {
   179  		return ParseJSON(r.Body, isFitReq)
   180  	} else {
   181  		return parseFileUpload(r, isFitReq)
   182  	}
   183  }