github.com/btccom/go-micro/v2@v2.9.3/api/handler/rpc/rpc.go (about)

     1  // Package rpc is a go-micro rpc handler.
     2  package rpc
     3  
     4  import (
     5  	"encoding/json"
     6  	"io"
     7  	"net/http"
     8  	"net/textproto"
     9  	"strconv"
    10  	"strings"
    11  
    12  	jsonpatch "github.com/evanphx/json-patch/v5"
    13  	"github.com/btccom/go-micro/v2/api"
    14  	"github.com/btccom/go-micro/v2/api/handler"
    15  	"github.com/btccom/go-micro/v2/api/internal/proto"
    16  	"github.com/btccom/go-micro/v2/client"
    17  	"github.com/btccom/go-micro/v2/client/selector"
    18  	"github.com/btccom/go-micro/v2/codec"
    19  	"github.com/btccom/go-micro/v2/codec/jsonrpc"
    20  	"github.com/btccom/go-micro/v2/codec/protorpc"
    21  	"github.com/btccom/go-micro/v2/errors"
    22  	"github.com/btccom/go-micro/v2/logger"
    23  	"github.com/btccom/go-micro/v2/metadata"
    24  	"github.com/btccom/go-micro/v2/registry"
    25  	"github.com/btccom/go-micro/v2/util/ctx"
    26  	"github.com/btccom/go-micro/v2/util/qson"
    27  	"github.com/oxtoacart/bpool"
    28  )
    29  
    30  const (
    31  	Handler = "rpc"
    32  )
    33  
    34  var (
    35  	// supported json codecs
    36  	jsonCodecs = []string{
    37  		"application/grpc+json",
    38  		"application/json",
    39  		"application/json-rpc",
    40  	}
    41  
    42  	// support proto codecs
    43  	protoCodecs = []string{
    44  		"application/grpc",
    45  		"application/grpc+proto",
    46  		"application/proto",
    47  		"application/protobuf",
    48  		"application/proto-rpc",
    49  		"application/octet-stream",
    50  	}
    51  
    52  	bufferPool = bpool.NewSizedBufferPool(1024, 8)
    53  )
    54  
// rpcHandler is the HTTP handler that forwards API requests to go-micro
// services over RPC.
type rpcHandler struct {
	// opts carries the handler configuration read by ServeHTTP: the
	// client used for calls, the router, and the maximum body size.
	opts handler.Options
	// s, when non-nil, is the fixed target service for every request;
	// when nil the target is resolved per request through opts.Router.
	s    *api.Service
}
    59  
    60  type buffer struct {
    61  	io.ReadCloser
    62  }
    63  
    64  func (b *buffer) Write(_ []byte) (int, error) {
    65  	return 0, nil
    66  }
    67  
    68  // strategy is a hack for selection
    69  func strategy(services []*registry.Service) selector.Strategy {
    70  	return func(_ []*registry.Service) selector.Next {
    71  		// ignore input to this function, use services above
    72  		return selector.Random(services)
    73  	}
    74  }
    75  
    76  func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    77  	bsize := handler.DefaultMaxRecvSize
    78  	if h.opts.MaxRecvSize > 0 {
    79  		bsize = h.opts.MaxRecvSize
    80  	}
    81  
    82  	r.Body = http.MaxBytesReader(w, r.Body, bsize)
    83  
    84  	defer r.Body.Close()
    85  	var service *api.Service
    86  
    87  	if h.s != nil {
    88  		// we were given the service
    89  		service = h.s
    90  	} else if h.opts.Router != nil {
    91  		// try get service from router
    92  		s, err := h.opts.Router.Route(r)
    93  		if err != nil {
    94  			writeError(w, r, errors.InternalServerError("go.micro.api", err.Error()))
    95  			return
    96  		}
    97  		service = s
    98  	} else {
    99  		// we have no way of routing the request
   100  		writeError(w, r, errors.InternalServerError("go.micro.api", "no route found"))
   101  		return
   102  	}
   103  
   104  	ct := r.Header.Get("Content-Type")
   105  
   106  	// Strip charset from Content-Type (like `application/json; charset=UTF-8`)
   107  	if idx := strings.IndexRune(ct, ';'); idx >= 0 {
   108  		ct = ct[:idx]
   109  	}
   110  
   111  	// micro client
   112  	c := h.opts.Client
   113  
   114  	// create context
   115  	cx := ctx.FromRequest(r)
   116  	// get context from http handler wrappers
   117  	md, ok := metadata.FromContext(r.Context())
   118  	if !ok {
   119  		md = make(metadata.Metadata)
   120  	}
   121  	// fill contex with http headers
   122  	md["Host"] = r.Host
   123  	md["Method"] = r.Method
   124  	// get canonical headers
   125  	for k, _ := range r.Header {
   126  		// may be need to get all values for key like r.Header.Values() provide in go 1.14
   127  		md[textproto.CanonicalMIMEHeaderKey(k)] = r.Header.Get(k)
   128  	}
   129  
   130  	// merge context with overwrite
   131  	cx = metadata.MergeContext(cx, md, true)
   132  
   133  	// set merged context to request
   134  	*r = *r.Clone(cx)
   135  	// if stream we currently only support json
   136  	if isStream(r, service) {
   137  		// drop older context as it can have timeouts and create new
   138  		//		md, _ := metadata.FromContext(cx)
   139  		//serveWebsocket(context.TODO(), w, r, service, c)
   140  		serveWebsocket(cx, w, r, service, c)
   141  		return
   142  	}
   143  
   144  	// create strategy
   145  	so := selector.WithStrategy(strategy(service.Services))
   146  
   147  	// walk the standard call path
   148  	// get payload
   149  	br, err := requestPayload(r)
   150  	if err != nil {
   151  		writeError(w, r, err)
   152  		return
   153  	}
   154  
   155  	var rsp []byte
   156  
   157  	switch {
   158  	// proto codecs
   159  	case hasCodec(ct, protoCodecs):
   160  		request := &proto.Message{}
   161  		// if the extracted payload isn't empty lets use it
   162  		if len(br) > 0 {
   163  			request = proto.NewMessage(br)
   164  		}
   165  
   166  		// create request/response
   167  		response := &proto.Message{}
   168  
   169  		req := c.NewRequest(
   170  			service.Name,
   171  			service.Endpoint.Name,
   172  			request,
   173  			client.WithContentType(ct),
   174  		)
   175  
   176  		// make the call
   177  		if err := c.Call(cx, req, response, client.WithSelectOption(so)); err != nil {
   178  			writeError(w, r, err)
   179  			return
   180  		}
   181  
   182  		// marshall response
   183  		rsp, err = response.Marshal()
   184  		if err != nil {
   185  			writeError(w, r, err)
   186  			return
   187  		}
   188  
   189  	default:
   190  		// if json codec is not present set to json
   191  		if !hasCodec(ct, jsonCodecs) {
   192  			ct = "application/json"
   193  		}
   194  
   195  		// default to trying json
   196  		var request json.RawMessage
   197  		// if the extracted payload isn't empty lets use it
   198  		if len(br) > 0 {
   199  			request = json.RawMessage(br)
   200  		}
   201  
   202  		// create request/response
   203  		var response json.RawMessage
   204  
   205  		req := c.NewRequest(
   206  			service.Name,
   207  			service.Endpoint.Name,
   208  			&request,
   209  			client.WithContentType(ct),
   210  		)
   211  		// make the call
   212  		if err := c.Call(cx, req, &response, client.WithSelectOption(so)); err != nil {
   213  			writeError(w, r, err)
   214  			return
   215  		}
   216  
   217  		// marshall response
   218  		rsp, err = response.MarshalJSON()
   219  		if err != nil {
   220  			writeError(w, r, err)
   221  			return
   222  		}
   223  	}
   224  
   225  	// write the response
   226  	writeResponse(w, r, rsp)
   227  }
   228  
   229  func (rh *rpcHandler) String() string {
   230  	return "rpc"
   231  }
   232  
   233  func hasCodec(ct string, codecs []string) bool {
   234  	for _, codec := range codecs {
   235  		if ct == codec {
   236  			return true
   237  		}
   238  	}
   239  	return false
   240  }
   241  
   242  // requestPayload takes a *http.Request.
   243  // If the request is a GET the query string parameters are extracted and marshaled to JSON and the raw bytes are returned.
   244  // If the request method is a POST the request body is read and returned
   245  func requestPayload(r *http.Request) ([]byte, error) {
   246  	var err error
   247  
   248  	// we have to decode json-rpc and proto-rpc because we suck
   249  	// well actually because there's no proxy codec right now
   250  
   251  	ct := r.Header.Get("Content-Type")
   252  	switch {
   253  	case strings.Contains(ct, "application/json-rpc"):
   254  		msg := codec.Message{
   255  			Type:   codec.Request,
   256  			Header: make(map[string]string),
   257  		}
   258  		c := jsonrpc.NewCodec(&buffer{r.Body})
   259  		if err = c.ReadHeader(&msg, codec.Request); err != nil {
   260  			return nil, err
   261  		}
   262  		var raw json.RawMessage
   263  		if err = c.ReadBody(&raw); err != nil {
   264  			return nil, err
   265  		}
   266  		return ([]byte)(raw), nil
   267  	case strings.Contains(ct, "application/proto-rpc"), strings.Contains(ct, "application/octet-stream"):
   268  		msg := codec.Message{
   269  			Type:   codec.Request,
   270  			Header: make(map[string]string),
   271  		}
   272  		c := protorpc.NewCodec(&buffer{r.Body})
   273  		if err = c.ReadHeader(&msg, codec.Request); err != nil {
   274  			return nil, err
   275  		}
   276  		var raw proto.Message
   277  		if err = c.ReadBody(&raw); err != nil {
   278  			return nil, err
   279  		}
   280  		return raw.Marshal()
   281  	case strings.Contains(ct, "application/www-x-form-urlencoded"):
   282  		r.ParseForm()
   283  
   284  		// generate a new set of values from the form
   285  		vals := make(map[string]string)
   286  		for k, v := range r.Form {
   287  			vals[k] = strings.Join(v, ",")
   288  		}
   289  
   290  		// marshal
   291  		return json.Marshal(vals)
   292  		// TODO: application/grpc
   293  	}
   294  
   295  	// otherwise as per usual
   296  	ctx := r.Context()
   297  	// dont user meadata.FromContext as it mangles names
   298  	md, ok := metadata.FromContext(ctx)
   299  	if !ok {
   300  		md = make(map[string]string)
   301  	}
   302  
   303  	// allocate maximum
   304  	matches := make(map[string]interface{}, len(md))
   305  	bodydst := ""
   306  
   307  	// get fields from url path
   308  	for k, v := range md {
   309  		k = strings.ToLower(k)
   310  		// filter own keys
   311  		if strings.HasPrefix(k, "x-api-field-") {
   312  			matches[strings.TrimPrefix(k, "x-api-field-")] = v
   313  			delete(md, k)
   314  		} else if k == "x-api-body" {
   315  			bodydst = v
   316  			delete(md, k)
   317  		}
   318  	}
   319  
   320  	// map of all fields
   321  	req := make(map[string]interface{}, len(md))
   322  
   323  	// get fields from url values
   324  	if len(r.URL.RawQuery) > 0 {
   325  		umd := make(map[string]interface{})
   326  		err = qson.Unmarshal(&umd, r.URL.RawQuery)
   327  		if err != nil {
   328  			return nil, err
   329  		}
   330  		for k, v := range umd {
   331  			matches[k] = v
   332  		}
   333  	}
   334  
   335  	// restore context without fields
   336  	*r = *r.Clone(metadata.NewContext(ctx, md))
   337  
   338  	for k, v := range matches {
   339  		ps := strings.Split(k, ".")
   340  		if len(ps) == 1 {
   341  			req[k] = v
   342  			continue
   343  		}
   344  		em := make(map[string]interface{})
   345  		em[ps[len(ps)-1]] = v
   346  		for i := len(ps) - 2; i > 0; i-- {
   347  			nm := make(map[string]interface{})
   348  			nm[ps[i]] = em
   349  			em = nm
   350  		}
   351  		if vm, ok := req[ps[0]]; ok {
   352  			// nested map
   353  			nm := vm.(map[string]interface{})
   354  			for vk, vv := range em {
   355  				nm[vk] = vv
   356  			}
   357  			req[ps[0]] = nm
   358  		} else {
   359  			req[ps[0]] = em
   360  		}
   361  	}
   362  	pathbuf := []byte("{}")
   363  	if len(req) > 0 {
   364  		pathbuf, err = json.Marshal(req)
   365  		if err != nil {
   366  			return nil, err
   367  		}
   368  	}
   369  
   370  	urlbuf := []byte("{}")
   371  	out, err := jsonpatch.MergeMergePatches(urlbuf, pathbuf)
   372  	if err != nil {
   373  		return nil, err
   374  	}
   375  
   376  	switch r.Method {
   377  	case "GET":
   378  		// empty response
   379  		if strings.Contains(ct, "application/json") && string(out) == "{}" {
   380  			return out, nil
   381  		} else if string(out) == "{}" && !strings.Contains(ct, "application/json") {
   382  			return []byte{}, nil
   383  		}
   384  		return out, nil
   385  	case "PATCH", "POST", "PUT", "DELETE":
   386  		bodybuf := []byte("{}")
   387  		buf := bufferPool.Get()
   388  		defer bufferPool.Put(buf)
   389  		if _, err := buf.ReadFrom(r.Body); err != nil {
   390  			return nil, err
   391  		}
   392  		if b := buf.Bytes(); len(b) > 0 {
   393  			bodybuf = b
   394  		}
   395  		if bodydst == "" || bodydst == "*" {
   396  			if out, err = jsonpatch.MergeMergePatches(out, bodybuf); err == nil {
   397  				return out, nil
   398  			}
   399  		}
   400  		var jsonbody map[string]interface{}
   401  		if json.Valid(bodybuf) {
   402  			if err = json.Unmarshal(bodybuf, &jsonbody); err != nil {
   403  				return nil, err
   404  			}
   405  		}
   406  		dstmap := make(map[string]interface{})
   407  		ps := strings.Split(bodydst, ".")
   408  		if len(ps) == 1 {
   409  			if jsonbody != nil {
   410  				dstmap[ps[0]] = jsonbody
   411  			} else {
   412  				// old unexpected behaviour
   413  				dstmap[ps[0]] = bodybuf
   414  			}
   415  		} else {
   416  			em := make(map[string]interface{})
   417  			if jsonbody != nil {
   418  				em[ps[len(ps)-1]] = jsonbody
   419  			} else {
   420  				// old unexpected behaviour
   421  				em[ps[len(ps)-1]] = bodybuf
   422  			}
   423  			for i := len(ps) - 2; i > 0; i-- {
   424  				nm := make(map[string]interface{})
   425  				nm[ps[i]] = em
   426  				em = nm
   427  			}
   428  			dstmap[ps[0]] = em
   429  		}
   430  
   431  		bodyout, err := json.Marshal(dstmap)
   432  		if err != nil {
   433  			return nil, err
   434  		}
   435  
   436  		if out, err = jsonpatch.MergeMergePatches(out, bodyout); err == nil {
   437  			return out, nil
   438  		}
   439  
   440  		//fallback to previous unknown behaviour
   441  		return bodybuf, nil
   442  
   443  	}
   444  
   445  	return []byte{}, nil
   446  }
   447  
   448  func writeError(w http.ResponseWriter, r *http.Request, err error) {
   449  	ce := errors.Parse(err.Error())
   450  
   451  	switch ce.Code {
   452  	case 0:
   453  		// assuming it's totally screwed
   454  		ce.Code = 500
   455  		ce.Id = "go.micro.api"
   456  		ce.Status = http.StatusText(500)
   457  		ce.Detail = "error during request: " + ce.Detail
   458  		w.WriteHeader(500)
   459  	default:
   460  		w.WriteHeader(int(ce.Code))
   461  	}
   462  
   463  	// response content type
   464  	w.Header().Set("Content-Type", "application/json")
   465  
   466  	// Set trailers
   467  	if strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
   468  		w.Header().Set("Trailer", "grpc-status")
   469  		w.Header().Set("Trailer", "grpc-message")
   470  		w.Header().Set("grpc-status", "13")
   471  		w.Header().Set("grpc-message", ce.Detail)
   472  	}
   473  
   474  	_, werr := w.Write([]byte(ce.Error()))
   475  	if werr != nil {
   476  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   477  			logger.Error(werr)
   478  		}
   479  	}
   480  }
   481  
   482  func writeResponse(w http.ResponseWriter, r *http.Request, rsp []byte) {
   483  	w.Header().Set("Content-Type", r.Header.Get("Content-Type"))
   484  	w.Header().Set("Content-Length", strconv.Itoa(len(rsp)))
   485  
   486  	// Set trailers
   487  	if strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
   488  		w.Header().Set("Trailer", "grpc-status")
   489  		w.Header().Set("Trailer", "grpc-message")
   490  		w.Header().Set("grpc-status", "0")
   491  		w.Header().Set("grpc-message", "")
   492  	}
   493  
   494  	// write 204 status if rsp is nil
   495  	if len(rsp) == 0 {
   496  		w.WriteHeader(http.StatusNoContent)
   497  	}
   498  
   499  	// write response
   500  	_, err := w.Write(rsp)
   501  	if err != nil {
   502  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   503  			logger.Error(err)
   504  		}
   505  	}
   506  
   507  }
   508  
   509  func NewHandler(opts ...handler.Option) handler.Handler {
   510  	options := handler.NewOptions(opts...)
   511  	return &rpcHandler{
   512  		opts: options,
   513  	}
   514  }
   515  
   516  func WithService(s *api.Service, opts ...handler.Option) handler.Handler {
   517  	options := handler.NewOptions(opts...)
   518  	return &rpcHandler{
   519  		opts: options,
   520  		s:    s,
   521  	}
   522  }