github.com/annwntech/go-micro/v2@v2.9.5/api/handler/rpc/rpc.go (about)

     1  // Package rpc is a go-micro rpc handler.
     2  package rpc
     3  
     4  import (
     5  	"encoding/json"
     6  	"io"
     7  	"net/http"
     8  	"net/textproto"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/annwntech/go-micro/v2/api"
    14  	"github.com/annwntech/go-micro/v2/api/handler"
    15  	"github.com/annwntech/go-micro/v2/api/internal/proto"
    16  	"github.com/annwntech/go-micro/v2/client"
    17  	"github.com/annwntech/go-micro/v2/client/selector"
    18  	"github.com/annwntech/go-micro/v2/codec"
    19  	"github.com/annwntech/go-micro/v2/codec/jsonrpc"
    20  	"github.com/annwntech/go-micro/v2/codec/protorpc"
    21  	"github.com/annwntech/go-micro/v2/errors"
    22  	"github.com/annwntech/go-micro/v2/logger"
    23  	"github.com/annwntech/go-micro/v2/metadata"
    24  	"github.com/annwntech/go-micro/v2/registry"
    25  	"github.com/annwntech/go-micro/v2/util/ctx"
    26  	"github.com/annwntech/go-micro/v2/util/qson"
    27  	jsonpatch "github.com/evanphx/json-patch/v5"
    28  	"github.com/oxtoacart/bpool"
    29  )
    30  
    31  const (
    32  	Handler = "rpc"
    33  )
    34  
    35  var (
    36  	// supported json codecs
    37  	jsonCodecs = []string{
    38  		"application/grpc+json",
    39  		"application/json",
    40  		"application/json-rpc",
    41  	}
    42  
    43  	// support proto codecs
    44  	protoCodecs = []string{
    45  		"application/grpc",
    46  		"application/grpc+proto",
    47  		"application/proto",
    48  		"application/protobuf",
    49  		"application/proto-rpc",
    50  		"application/octet-stream",
    51  	}
    52  
    53  	bufferPool = bpool.NewSizedBufferPool(1024, 8)
    54  )
    55  
    56  type rpcHandler struct {
    57  	opts handler.Options
    58  	s    *api.Service
    59  }
    60  
    61  type buffer struct {
    62  	io.ReadCloser
    63  }
    64  
    65  func (b *buffer) Write(_ []byte) (int, error) {
    66  	return 0, nil
    67  }
    68  
    69  // strategy is a hack for selection
    70  func strategy(services []*registry.Service) selector.Strategy {
    71  	return func(_ []*registry.Service) selector.Next {
    72  		// ignore input to this function, use services above
    73  		return selector.Random(services)
    74  	}
    75  }
    76  
    77  func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    78  	bsize := handler.DefaultMaxRecvSize
    79  	if h.opts.MaxRecvSize > 0 {
    80  		bsize = h.opts.MaxRecvSize
    81  	}
    82  
    83  	r.Body = http.MaxBytesReader(w, r.Body, bsize)
    84  
    85  	defer r.Body.Close()
    86  	var service *api.Service
    87  
    88  	if h.s != nil {
    89  		// we were given the service
    90  		service = h.s
    91  	} else if h.opts.Router != nil {
    92  		// try get service from router
    93  		s, err := h.opts.Router.Route(r)
    94  		if err != nil {
    95  			writeError(w, r, errors.InternalServerError("go.micro.api", err.Error()))
    96  			return
    97  		}
    98  		service = s
    99  	} else {
   100  		// we have no way of routing the request
   101  		writeError(w, r, errors.InternalServerError("go.micro.api", "no route found"))
   102  		return
   103  	}
   104  
   105  	ct := r.Header.Get("Content-Type")
   106  
   107  	// Strip charset from Content-Type (like `application/json; charset=UTF-8`)
   108  	if idx := strings.IndexRune(ct, ';'); idx >= 0 {
   109  		ct = ct[:idx]
   110  	}
   111  
   112  	// micro client
   113  	c := h.opts.Client
   114  
   115  	// create context
   116  	cx := ctx.FromRequest(r)
   117  	// get context from http handler wrappers
   118  	md, ok := metadata.FromContext(r.Context())
   119  	if !ok {
   120  		md = make(metadata.Metadata)
   121  	}
   122  	// fill contex with http headers
   123  	md["Host"] = r.Host
   124  	md["Method"] = r.Method
   125  	// get canonical headers
   126  	for k, _ := range r.Header {
   127  		// may be need to get all values for key like r.Header.Values() provide in go 1.14
   128  		md[textproto.CanonicalMIMEHeaderKey(k)] = r.Header.Get(k)
   129  	}
   130  
   131  	// merge context with overwrite
   132  	cx = metadata.MergeContext(cx, md, true)
   133  
   134  	// set merged context to request
   135  	*r = *r.Clone(cx)
   136  	// if stream we currently only support json
   137  	if isStream(r, service) {
   138  		// drop older context as it can have timeouts and create new
   139  		//		md, _ := metadata.FromContext(cx)
   140  		//serveWebsocket(context.TODO(), w, r, service, c)
   141  		serveWebsocket(cx, w, r, service, c)
   142  		return
   143  	}
   144  
   145  	var callOpts []client.CallOption
   146  
   147  	// create strategy
   148  	so := selector.WithStrategy(strategy(service.Services))
   149  	callOpts = append(callOpts, client.WithSelectOption(so))
   150  	if t := r.Header.Get("Timeout"); t != "" {
   151  		// assume timeout integer seconds now
   152  		// assume timeout in format 10s
   153  		if td, err := time.ParseDuration(t); err == nil {
   154  			callOpts = append(callOpts, client.WithRequestTimeout(td))
   155  		}
   156  	}
   157  	// walk the standard call path
   158  	// get payload
   159  	br, err := requestPayload(r)
   160  	if err != nil {
   161  		writeError(w, r, err)
   162  		return
   163  	}
   164  
   165  	var rsp []byte
   166  
   167  	switch {
   168  	// proto codecs
   169  	case hasCodec(ct, protoCodecs):
   170  		request := &proto.Message{}
   171  		// if the extracted payload isn't empty lets use it
   172  		if len(br) > 0 {
   173  			request = proto.NewMessage(br)
   174  		}
   175  
   176  		// create request/response
   177  		response := &proto.Message{}
   178  
   179  		req := c.NewRequest(
   180  			service.Name,
   181  			service.Endpoint.Name,
   182  			request,
   183  			client.WithContentType(ct),
   184  		)
   185  
   186  		// make the call
   187  		if err := c.Call(cx, req, response, callOpts...); err != nil {
   188  			writeError(w, r, err)
   189  			return
   190  		}
   191  
   192  		// marshall response
   193  		rsp, err = response.Marshal()
   194  		if err != nil {
   195  			writeError(w, r, err)
   196  			return
   197  		}
   198  
   199  	default:
   200  		// if json codec is not present set to json
   201  		if !hasCodec(ct, jsonCodecs) {
   202  			ct = "application/json"
   203  		}
   204  
   205  		// default to trying json
   206  		var request json.RawMessage
   207  		// if the extracted payload isn't empty lets use it
   208  		if len(br) > 0 {
   209  			request = json.RawMessage(br)
   210  		}
   211  
   212  		// create request/response
   213  		var response json.RawMessage
   214  
   215  		req := c.NewRequest(
   216  			service.Name,
   217  			service.Endpoint.Name,
   218  			&request,
   219  			client.WithContentType(ct),
   220  		)
   221  		// make the call
   222  		if err := c.Call(cx, req, &response, callOpts...); err != nil {
   223  			writeError(w, r, err)
   224  			return
   225  		}
   226  
   227  		// marshall response
   228  		rsp, err = response.MarshalJSON()
   229  		if err != nil {
   230  			writeError(w, r, err)
   231  			return
   232  		}
   233  	}
   234  
   235  	// write the response
   236  	writeResponse(w, r, rsp)
   237  }
   238  
   239  func (rh *rpcHandler) String() string {
   240  	return "rpc"
   241  }
   242  
   243  func hasCodec(ct string, codecs []string) bool {
   244  	for _, codec := range codecs {
   245  		if ct == codec {
   246  			return true
   247  		}
   248  	}
   249  	return false
   250  }
   251  
   252  // requestPayload takes a *http.Request.
   253  // If the request is a GET the query string parameters are extracted and marshaled to JSON and the raw bytes are returned.
   254  // If the request method is a POST the request body is read and returned
   255  func requestPayload(r *http.Request) ([]byte, error) {
   256  	var err error
   257  
   258  	// we have to decode json-rpc and proto-rpc because we suck
   259  	// well actually because there's no proxy codec right now
   260  
   261  	ct := r.Header.Get("Content-Type")
   262  	switch {
   263  	case strings.Contains(ct, "application/json-rpc"):
   264  		msg := codec.Message{
   265  			Type:   codec.Request,
   266  			Header: make(map[string]string),
   267  		}
   268  		c := jsonrpc.NewCodec(&buffer{r.Body})
   269  		if err = c.ReadHeader(&msg, codec.Request); err != nil {
   270  			return nil, err
   271  		}
   272  		var raw json.RawMessage
   273  		if err = c.ReadBody(&raw); err != nil {
   274  			return nil, err
   275  		}
   276  		return ([]byte)(raw), nil
   277  	case strings.Contains(ct, "application/proto-rpc"), strings.Contains(ct, "application/octet-stream"):
   278  		msg := codec.Message{
   279  			Type:   codec.Request,
   280  			Header: make(map[string]string),
   281  		}
   282  		c := protorpc.NewCodec(&buffer{r.Body})
   283  		if err = c.ReadHeader(&msg, codec.Request); err != nil {
   284  			return nil, err
   285  		}
   286  		var raw proto.Message
   287  		if err = c.ReadBody(&raw); err != nil {
   288  			return nil, err
   289  		}
   290  		return raw.Marshal()
   291  	case strings.Contains(ct, "application/www-x-form-urlencoded"):
   292  		r.ParseForm()
   293  
   294  		// generate a new set of values from the form
   295  		vals := make(map[string]string)
   296  		for k, v := range r.Form {
   297  			vals[k] = strings.Join(v, ",")
   298  		}
   299  
   300  		// marshal
   301  		return json.Marshal(vals)
   302  		// TODO: application/grpc
   303  	}
   304  
   305  	// otherwise as per usual
   306  	ctx := r.Context()
   307  	// dont user meadata.FromContext as it mangles names
   308  	md, ok := metadata.FromContext(ctx)
   309  	if !ok {
   310  		md = make(map[string]string)
   311  	}
   312  
   313  	// allocate maximum
   314  	matches := make(map[string]interface{}, len(md))
   315  	bodydst := ""
   316  
   317  	// get fields from url path
   318  	for k, v := range md {
   319  		k = strings.ToLower(k)
   320  		// filter own keys
   321  		if strings.HasPrefix(k, "x-api-field-") {
   322  			matches[strings.TrimPrefix(k, "x-api-field-")] = v
   323  			delete(md, k)
   324  		} else if k == "x-api-body" {
   325  			bodydst = v
   326  			delete(md, k)
   327  		}
   328  	}
   329  
   330  	// map of all fields
   331  	req := make(map[string]interface{}, len(md))
   332  
   333  	// get fields from url values
   334  	if len(r.URL.RawQuery) > 0 {
   335  		umd := make(map[string]interface{})
   336  		err = qson.Unmarshal(&umd, r.URL.RawQuery)
   337  		if err != nil {
   338  			return nil, err
   339  		}
   340  		for k, v := range umd {
   341  			matches[k] = v
   342  		}
   343  	}
   344  
   345  	// restore context without fields
   346  	*r = *r.Clone(metadata.NewContext(ctx, md))
   347  
   348  	for k, v := range matches {
   349  		ps := strings.Split(k, ".")
   350  		if len(ps) == 1 {
   351  			req[k] = v
   352  			continue
   353  		}
   354  		em := make(map[string]interface{})
   355  		em[ps[len(ps)-1]] = v
   356  		for i := len(ps) - 2; i > 0; i-- {
   357  			nm := make(map[string]interface{})
   358  			nm[ps[i]] = em
   359  			em = nm
   360  		}
   361  		if vm, ok := req[ps[0]]; ok {
   362  			// nested map
   363  			nm := vm.(map[string]interface{})
   364  			for vk, vv := range em {
   365  				nm[vk] = vv
   366  			}
   367  			req[ps[0]] = nm
   368  		} else {
   369  			req[ps[0]] = em
   370  		}
   371  	}
   372  	pathbuf := []byte("{}")
   373  	if len(req) > 0 {
   374  		pathbuf, err = json.Marshal(req)
   375  		if err != nil {
   376  			return nil, err
   377  		}
   378  	}
   379  
   380  	urlbuf := []byte("{}")
   381  	out, err := jsonpatch.MergeMergePatches(urlbuf, pathbuf)
   382  	if err != nil {
   383  		return nil, err
   384  	}
   385  
   386  	switch r.Method {
   387  	case "GET":
   388  		// empty response
   389  		if strings.Contains(ct, "application/json") && string(out) == "{}" {
   390  			return out, nil
   391  		} else if string(out) == "{}" && !strings.Contains(ct, "application/json") {
   392  			return []byte{}, nil
   393  		}
   394  		return out, nil
   395  	case "PATCH", "POST", "PUT", "DELETE":
   396  		bodybuf := []byte("{}")
   397  		buf := bufferPool.Get()
   398  		defer bufferPool.Put(buf)
   399  		if _, err := buf.ReadFrom(r.Body); err != nil {
   400  			return nil, err
   401  		}
   402  		if b := buf.Bytes(); len(b) > 0 {
   403  			bodybuf = b
   404  		}
   405  		if bodydst == "" || bodydst == "*" {
   406  			if out, err = jsonpatch.MergeMergePatches(out, bodybuf); err == nil {
   407  				return out, nil
   408  			}
   409  		}
   410  		var jsonbody map[string]interface{}
   411  		if json.Valid(bodybuf) {
   412  			if err = json.Unmarshal(bodybuf, &jsonbody); err != nil {
   413  				return nil, err
   414  			}
   415  		}
   416  		dstmap := make(map[string]interface{})
   417  		ps := strings.Split(bodydst, ".")
   418  		if len(ps) == 1 {
   419  			if jsonbody != nil {
   420  				dstmap[ps[0]] = jsonbody
   421  			} else {
   422  				// old unexpected behaviour
   423  				dstmap[ps[0]] = bodybuf
   424  			}
   425  		} else {
   426  			em := make(map[string]interface{})
   427  			if jsonbody != nil {
   428  				em[ps[len(ps)-1]] = jsonbody
   429  			} else {
   430  				// old unexpected behaviour
   431  				em[ps[len(ps)-1]] = bodybuf
   432  			}
   433  			for i := len(ps) - 2; i > 0; i-- {
   434  				nm := make(map[string]interface{})
   435  				nm[ps[i]] = em
   436  				em = nm
   437  			}
   438  			dstmap[ps[0]] = em
   439  		}
   440  
   441  		bodyout, err := json.Marshal(dstmap)
   442  		if err != nil {
   443  			return nil, err
   444  		}
   445  
   446  		if out, err = jsonpatch.MergeMergePatches(out, bodyout); err == nil {
   447  			return out, nil
   448  		}
   449  
   450  		//fallback to previous unknown behaviour
   451  		return bodybuf, nil
   452  
   453  	}
   454  
   455  	return []byte{}, nil
   456  }
   457  
   458  func writeError(w http.ResponseWriter, r *http.Request, err error) {
   459  	ce := errors.Parse(err.Error())
   460  
   461  	switch ce.Code {
   462  	case 0:
   463  		// assuming it's totally screwed
   464  		ce.Code = 500
   465  		ce.Id = "go.micro.api"
   466  		ce.Status = http.StatusText(500)
   467  		ce.Detail = "error during request: " + ce.Detail
   468  		w.WriteHeader(500)
   469  	default:
   470  		w.WriteHeader(int(ce.Code))
   471  	}
   472  
   473  	// response content type
   474  	w.Header().Set("Content-Type", "application/json")
   475  
   476  	// Set trailers
   477  	if strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
   478  		w.Header().Set("Trailer", "grpc-status")
   479  		w.Header().Set("Trailer", "grpc-message")
   480  		w.Header().Set("grpc-status", "13")
   481  		w.Header().Set("grpc-message", ce.Detail)
   482  	}
   483  
   484  	_, werr := w.Write([]byte(ce.Error()))
   485  	if werr != nil {
   486  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   487  			logger.Error(werr)
   488  		}
   489  	}
   490  }
   491  
   492  func writeResponse(w http.ResponseWriter, r *http.Request, rsp []byte) {
   493  	ct := r.Header.Get("Content-Type")
   494  	if ct == "" {
   495  		ct = "application/json"
   496  	}
   497  	w.Header().Set("Content-Type", ct)
   498  	w.Header().Set("Content-Length", strconv.Itoa(len(rsp)))
   499  
   500  	// Set trailers
   501  	if strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
   502  		w.Header().Set("Trailer", "grpc-status")
   503  		w.Header().Set("Trailer", "grpc-message")
   504  		w.Header().Set("grpc-status", "0")
   505  		w.Header().Set("grpc-message", "")
   506  	}
   507  
   508  	// write 204 status if rsp is nil
   509  	if len(rsp) == 0 {
   510  		w.WriteHeader(http.StatusNoContent)
   511  	}
   512  
   513  	// write response
   514  	_, err := w.Write(rsp)
   515  	if err != nil {
   516  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   517  			logger.Error(err)
   518  		}
   519  	}
   520  
   521  }
   522  
   523  func NewHandler(opts ...handler.Option) handler.Handler {
   524  	options := handler.NewOptions(opts...)
   525  	return &rpcHandler{
   526  		opts: options,
   527  	}
   528  }
   529  
   530  func WithService(s *api.Service, opts ...handler.Option) handler.Handler {
   531  	options := handler.NewOptions(opts...)
   532  	return &rpcHandler{
   533  		opts: options,
   534  		s:    s,
   535  	}
   536  }