
     1  // Package rpc is a go-micro rpc handler.
     2  package rpc
     4  import (
     5  	"encoding/json"
     6  	"io"
     7  	"net/http"
     8  	"net/textproto"
     9  	"strconv"
    10  	"strings"
    12  	jsonpatch ""
    13  	""
    14  	""
    15  	""
    16  	""
    17  	""
    18  	""
    19  	""
    20  	""
    21  	""
    22  	""
    23  	""
    24  	""
    25  	""
    26  	""
    27  	""
    28  )
    30  const (
    31  	Handler = "rpc"
    32  )
    34  var (
    35  	// supported json codecs
    36  	jsonCodecs = []string{
    37  		"application/grpc+json",
    38  		"application/json",
    39  		"application/json-rpc",
    40  	}
    42  	// support proto codecs
    43  	protoCodecs = []string{
    44  		"application/grpc",
    45  		"application/grpc+proto",
    46  		"application/proto",
    47  		"application/protobuf",
    48  		"application/proto-rpc",
    49  		"application/octet-stream",
    50  	}
    52  	bufferPool = bpool.NewSizedBufferPool(1024, 8)
    53  )
// rpcHandler is an http.Handler that forwards HTTP requests to micro
// services as RPC calls.
type rpcHandler struct {
	// opts supplies the client, router and receive-size limit used by
	// ServeHTTP.
	opts handler.Options
	// s, when non-nil, is the fixed target service (set by WithService);
	// otherwise ServeHTTP resolves the service through opts.Router.
	s    *api.Service
}
    60  type buffer struct {
    61  	io.ReadCloser
    62  }
    64  func (b *buffer) Write(_ []byte) (int, error) {
    65  	return 0, nil
    66  }
    68  // strategy is a hack for selection
    69  func strategy(services []*registry.Service) selector.Strategy {
    70  	return func(_ []*registry.Service) selector.Next {
    71  		// ignore input to this function, use services above
    72  		return selector.Random(services)
    73  	}
    74  }
    76  func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    77  	bsize := handler.DefaultMaxRecvSize
    78  	if h.opts.MaxRecvSize > 0 {
    79  		bsize = h.opts.MaxRecvSize
    80  	}
    82  	r.Body = http.MaxBytesReader(w, r.Body, bsize)
    84  	defer r.Body.Close()
    85  	var service *api.Service
    87  	if h.s != nil {
    88  		// we were given the service
    89  		service = h.s
    90  	} else if h.opts.Router != nil {
    91  		// try get service from router
    92  		s, err := h.opts.Router.Route(r)
    93  		if err != nil {
    94  			writeError(w, r, errors.InternalServerError("go.micro.api", err.Error()))
    95  			return
    96  		}
    97  		service = s
    98  	} else {
    99  		// we have no way of routing the request
   100  		writeError(w, r, errors.InternalServerError("go.micro.api", "no route found"))
   101  		return
   102  	}
   104  	ct := r.Header.Get("Content-Type")
   106  	// Strip charset from Content-Type (like `application/json; charset=UTF-8`)
   107  	if idx := strings.IndexRune(ct, ';'); idx >= 0 {
   108  		ct = ct[:idx]
   109  	}
   111  	// micro client
   112  	c := h.opts.Client
   114  	// create context
   115  	cx := ctx.FromRequest(r)
   116  	// get context from http handler wrappers
   117  	md, ok := metadata.FromContext(r.Context())
   118  	if !ok {
   119  		md = make(metadata.Metadata)
   120  	}
   121  	// fill contex with http headers
   122  	md["Host"] = r.Host
   123  	md["Method"] = r.Method
   124  	// get canonical headers
   125  	for k, _ := range r.Header {
   126  		// may be need to get all values for key like r.Header.Values() provide in go 1.14
   127  		md[textproto.CanonicalMIMEHeaderKey(k)] = r.Header.Get(k)
   128  	}
   130  	// merge context with overwrite
   131  	cx = metadata.MergeContext(cx, md, true)
   133  	// set merged context to request
   134  	*r = *r.Clone(cx)
   135  	// if stream we currently only support json
   136  	if isStream(r, service) {
   137  		// drop older context as it can have timeouts and create new
   138  		//		md, _ := metadata.FromContext(cx)
   139  		//serveWebsocket(context.TODO(), w, r, service, c)
   140  		serveWebsocket(cx, w, r, service, c)
   141  		return
   142  	}
   144  	// create strategy
   145  	so := selector.WithStrategy(strategy(service.Services))
   147  	// walk the standard call path
   148  	// get payload
   149  	br, err := requestPayload(r)
   150  	if err != nil {
   151  		writeError(w, r, err)
   152  		return
   153  	}
   155  	var rsp []byte
   157  	switch {
   158  	// proto codecs
   159  	case hasCodec(ct, protoCodecs):
   160  		request := &proto.Message{}
   161  		// if the extracted payload isn't empty lets use it
   162  		if len(br) > 0 {
   163  			request = proto.NewMessage(br)
   164  		}
   166  		// create request/response
   167  		response := &proto.Message{}
   169  		req := c.NewRequest(
   170  			service.Name,
   171  			service.Endpoint.Name,
   172  			request,
   173  			client.WithContentType(ct),
   174  		)
   176  		// make the call
   177  		if err := c.Call(cx, req, response, client.WithSelectOption(so)); err != nil {
   178  			writeError(w, r, err)
   179  			return
   180  		}
   182  		// marshall response
   183  		rsp, err = response.Marshal()
   184  		if err != nil {
   185  			writeError(w, r, err)
   186  			return
   187  		}
   189  	default:
   190  		// if json codec is not present set to json
   191  		if !hasCodec(ct, jsonCodecs) {
   192  			ct = "application/json"
   193  		}
   195  		// default to trying json
   196  		var request json.RawMessage
   197  		// if the extracted payload isn't empty lets use it
   198  		if len(br) > 0 {
   199  			request = json.RawMessage(br)
   200  		}
   202  		// create request/response
   203  		var response json.RawMessage
   205  		req := c.NewRequest(
   206  			service.Name,
   207  			service.Endpoint.Name,
   208  			&request,
   209  			client.WithContentType(ct),
   210  		)
   211  		// make the call
   212  		if err := c.Call(cx, req, &response, client.WithSelectOption(so)); err != nil {
   213  			writeError(w, r, err)
   214  			return
   215  		}
   217  		// marshall response
   218  		rsp, err = response.MarshalJSON()
   219  		if err != nil {
   220  			writeError(w, r, err)
   221  			return
   222  		}
   223  	}
   225  	// write the response
   226  	writeResponse(w, r, rsp)
   227  }
   229  func (rh *rpcHandler) String() string {
   230  	return "rpc"
   231  }
   233  func hasCodec(ct string, codecs []string) bool {
   234  	for _, codec := range codecs {
   235  		if ct == codec {
   236  			return true
   237  		}
   238  	}
   239  	return false
   240  }
   242  // requestPayload takes a *http.Request.
   243  // If the request is a GET the query string parameters are extracted and marshaled to JSON and the raw bytes are returned.
   244  // If the request method is a POST the request body is read and returned
   245  func requestPayload(r *http.Request) ([]byte, error) {
   246  	var err error
   248  	// we have to decode json-rpc and proto-rpc because we suck
   249  	// well actually because there's no proxy codec right now
   251  	ct := r.Header.Get("Content-Type")
   252  	switch {
   253  	case strings.Contains(ct, "application/json-rpc"):
   254  		msg := codec.Message{
   255  			Type:   codec.Request,
   256  			Header: make(map[string]string),
   257  		}
   258  		c := jsonrpc.NewCodec(&buffer{r.Body})
   259  		if err = c.ReadHeader(&msg, codec.Request); err != nil {
   260  			return nil, err
   261  		}
   262  		var raw json.RawMessage
   263  		if err = c.ReadBody(&raw); err != nil {
   264  			return nil, err
   265  		}
   266  		return ([]byte)(raw), nil
   267  	case strings.Contains(ct, "application/proto-rpc"), strings.Contains(ct, "application/octet-stream"):
   268  		msg := codec.Message{
   269  			Type:   codec.Request,
   270  			Header: make(map[string]string),
   271  		}
   272  		c := protorpc.NewCodec(&buffer{r.Body})
   273  		if err = c.ReadHeader(&msg, codec.Request); err != nil {
   274  			return nil, err
   275  		}
   276  		var raw proto.Message
   277  		if err = c.ReadBody(&raw); err != nil {
   278  			return nil, err
   279  		}
   280  		return raw.Marshal()
   281  	case strings.Contains(ct, "application/www-x-form-urlencoded"):
   282  		r.ParseForm()
   284  		// generate a new set of values from the form
   285  		vals := make(map[string]string)
   286  		for k, v := range r.Form {
   287  			vals[k] = strings.Join(v, ",")
   288  		}
   290  		// marshal
   291  		return json.Marshal(vals)
   292  		// TODO: application/grpc
   293  	}
   295  	// otherwise as per usual
   296  	ctx := r.Context()
   297  	// dont user meadata.FromContext as it mangles names
   298  	md, ok := metadata.FromContext(ctx)
   299  	if !ok {
   300  		md = make(map[string]string)
   301  	}
   303  	// allocate maximum
   304  	matches := make(map[string]interface{}, len(md))
   305  	bodydst := ""
   307  	// get fields from url path
   308  	for k, v := range md {
   309  		k = strings.ToLower(k)
   310  		// filter own keys
   311  		if strings.HasPrefix(k, "x-api-field-") {
   312  			matches[strings.TrimPrefix(k, "x-api-field-")] = v
   313  			delete(md, k)
   314  		} else if k == "x-api-body" {
   315  			bodydst = v
   316  			delete(md, k)
   317  		}
   318  	}
   320  	// map of all fields
   321  	req := make(map[string]interface{}, len(md))
   323  	// get fields from url values
   324  	if len(r.URL.RawQuery) > 0 {
   325  		umd := make(map[string]interface{})
   326  		err = qson.Unmarshal(&umd, r.URL.RawQuery)
   327  		if err != nil {
   328  			return nil, err
   329  		}
   330  		for k, v := range umd {
   331  			matches[k] = v
   332  		}
   333  	}
   335  	// restore context without fields
   336  	*r = *r.Clone(metadata.NewContext(ctx, md))
   338  	for k, v := range matches {
   339  		ps := strings.Split(k, ".")
   340  		if len(ps) == 1 {
   341  			req[k] = v
   342  			continue
   343  		}
   344  		em := make(map[string]interface{})
   345  		em[ps[len(ps)-1]] = v
   346  		for i := len(ps) - 2; i > 0; i-- {
   347  			nm := make(map[string]interface{})
   348  			nm[ps[i]] = em
   349  			em = nm
   350  		}
   351  		if vm, ok := req[ps[0]]; ok {
   352  			// nested map
   353  			nm := vm.(map[string]interface{})
   354  			for vk, vv := range em {
   355  				nm[vk] = vv
   356  			}
   357  			req[ps[0]] = nm
   358  		} else {
   359  			req[ps[0]] = em
   360  		}
   361  	}
   362  	pathbuf := []byte("{}")
   363  	if len(req) > 0 {
   364  		pathbuf, err = json.Marshal(req)
   365  		if err != nil {
   366  			return nil, err
   367  		}
   368  	}
   370  	urlbuf := []byte("{}")
   371  	out, err := jsonpatch.MergeMergePatches(urlbuf, pathbuf)
   372  	if err != nil {
   373  		return nil, err
   374  	}
   376  	switch r.Method {
   377  	case "GET":
   378  		// empty response
   379  		if strings.Contains(ct, "application/json") && string(out) == "{}" {
   380  			return out, nil
   381  		} else if string(out) == "{}" && !strings.Contains(ct, "application/json") {
   382  			return []byte{}, nil
   383  		}
   384  		return out, nil
   385  	case "PATCH", "POST", "PUT", "DELETE":
   386  		bodybuf := []byte("{}")
   387  		buf := bufferPool.Get()
   388  		defer bufferPool.Put(buf)
   389  		if _, err := buf.ReadFrom(r.Body); err != nil {
   390  			return nil, err
   391  		}
   392  		if b := buf.Bytes(); len(b) > 0 {
   393  			bodybuf = b
   394  		}
   395  		if bodydst == "" || bodydst == "*" {
   396  			if out, err = jsonpatch.MergeMergePatches(out, bodybuf); err == nil {
   397  				return out, nil
   398  			}
   399  		}
   400  		var jsonbody map[string]interface{}
   401  		if json.Valid(bodybuf) {
   402  			if err = json.Unmarshal(bodybuf, &jsonbody); err != nil {
   403  				return nil, err
   404  			}
   405  		}
   406  		dstmap := make(map[string]interface{})
   407  		ps := strings.Split(bodydst, ".")
   408  		if len(ps) == 1 {
   409  			if jsonbody != nil {
   410  				dstmap[ps[0]] = jsonbody
   411  			} else {
   412  				// old unexpected behaviour
   413  				dstmap[ps[0]] = bodybuf
   414  			}
   415  		} else {
   416  			em := make(map[string]interface{})
   417  			if jsonbody != nil {
   418  				em[ps[len(ps)-1]] = jsonbody
   419  			} else {
   420  				// old unexpected behaviour
   421  				em[ps[len(ps)-1]] = bodybuf
   422  			}
   423  			for i := len(ps) - 2; i > 0; i-- {
   424  				nm := make(map[string]interface{})
   425  				nm[ps[i]] = em
   426  				em = nm
   427  			}
   428  			dstmap[ps[0]] = em
   429  		}
   431  		bodyout, err := json.Marshal(dstmap)
   432  		if err != nil {
   433  			return nil, err
   434  		}
   436  		if out, err = jsonpatch.MergeMergePatches(out, bodyout); err == nil {
   437  			return out, nil
   438  		}
   440  		//fallback to previous unknown behaviour
   441  		return bodybuf, nil
   443  	}
   445  	return []byte{}, nil
   446  }
   448  func writeError(w http.ResponseWriter, r *http.Request, err error) {
   449  	ce := errors.Parse(err.Error())
   451  	switch ce.Code {
   452  	case 0:
   453  		// assuming it's totally screwed
   454  		ce.Code = 500
   455  		ce.Id = "go.micro.api"
   456  		ce.Status = http.StatusText(500)
   457  		ce.Detail = "error during request: " + ce.Detail
   458  		w.WriteHeader(500)
   459  	default:
   460  		w.WriteHeader(int(ce.Code))
   461  	}
   463  	// response content type
   464  	w.Header().Set("Content-Type", "application/json")
   466  	// Set trailers
   467  	if strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
   468  		w.Header().Set("Trailer", "grpc-status")
   469  		w.Header().Set("Trailer", "grpc-message")
   470  		w.Header().Set("grpc-status", "13")
   471  		w.Header().Set("grpc-message", ce.Detail)
   472  	}
   474  	_, werr := w.Write([]byte(ce.Error()))
   475  	if werr != nil {
   476  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   477  			logger.Error(werr)
   478  		}
   479  	}
   480  }
   482  func writeResponse(w http.ResponseWriter, r *http.Request, rsp []byte) {
   483  	w.Header().Set("Content-Type", r.Header.Get("Content-Type"))
   484  	w.Header().Set("Content-Length", strconv.Itoa(len(rsp)))
   486  	// Set trailers
   487  	if strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
   488  		w.Header().Set("Trailer", "grpc-status")
   489  		w.Header().Set("Trailer", "grpc-message")
   490  		w.Header().Set("grpc-status", "0")
   491  		w.Header().Set("grpc-message", "")
   492  	}
   494  	// write 204 status if rsp is nil
   495  	if len(rsp) == 0 {
   496  		w.WriteHeader(http.StatusNoContent)
   497  	}
   499  	// write response
   500  	_, err := w.Write(rsp)
   501  	if err != nil {
   502  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   503  			logger.Error(err)
   504  		}
   505  	}
   507  }
   509  func NewHandler(opts ...handler.Option) handler.Handler {
   510  	options := handler.NewOptions(opts...)
   511  	return &rpcHandler{
   512  		opts: options,
   513  	}
   514  }
   516  func WithService(s *api.Service, opts ...handler.Option) handler.Handler {
   517  	options := handler.NewOptions(opts...)
   518  	return &rpcHandler{
   519  		opts: options,
   520  		s:    s,
   521  	}
   522  }