github.com/annwntech/go-micro/v2@v2.9.5/handler/rpc.go (about)

     1  package handler
     2  
     3  import (
     4  	"encoding/json"
     5  	"net/http"
     6  	"strconv"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/annwntech/go-micro/v2/api/handler"
    11  	"github.com/annwntech/go-micro/v2/api/resolver"
    12  	"github.com/annwntech/go-micro/v2/api/server/cors"
    13  	"github.com/annwntech/go-micro/v2/client"
    14  	"github.com/annwntech/go-micro/v2/errors"
    15  	"github.com/annwntech/go-micro/v2/helper"
    16  )
    17  
    18  type rpcRequest struct {
    19  	Service  string
    20  	Endpoint string
    21  	Method   string
    22  	Address  string
    23  	Request  interface{}
    24  }
    25  
    26  type rpcHandler struct {
    27  	resolver resolver.Resolver
    28  }
    29  
    30  func (h *rpcHandler) String() string {
    31  	return "internal/rpc"
    32  }
    33  
    34  // ServeHTTP passes on a JSON or form encoded RPC request to a service.
    35  func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    36  	if r.Method == "OPTIONS" {
    37  		cors.SetHeaders(w, r)
    38  		return
    39  	}
    40  
    41  	if r.Method != "POST" {
    42  		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
    43  		return
    44  	}
    45  	defer r.Body.Close()
    46  
    47  	badRequest := func(description string) {
    48  		e := errors.BadRequest("micro.rpc", description)
    49  		w.WriteHeader(400)
    50  		w.Write([]byte(e.Error()))
    51  	}
    52  
    53  	var service, endpoint, address string
    54  	var request interface{}
    55  
    56  	// response content type
    57  	w.Header().Set("Content-Type", "application/json")
    58  
    59  	ct := r.Header.Get("Content-Type")
    60  
    61  	// Strip charset from Content-Type (like `application/json; charset=UTF-8`)
    62  	if idx := strings.IndexRune(ct, ';'); idx >= 0 {
    63  		ct = ct[:idx]
    64  	}
    65  
    66  	switch ct {
    67  	case "application/json":
    68  		var rpcReq rpcRequest
    69  
    70  		d := json.NewDecoder(r.Body)
    71  		d.UseNumber()
    72  
    73  		if err := d.Decode(&rpcReq); err != nil {
    74  			badRequest(err.Error())
    75  			return
    76  		}
    77  
    78  		service = rpcReq.Service
    79  		endpoint = rpcReq.Endpoint
    80  		address = rpcReq.Address
    81  		request = rpcReq.Request
    82  		if len(endpoint) == 0 {
    83  			endpoint = rpcReq.Method
    84  		}
    85  
    86  		// JSON as string
    87  		if req, ok := rpcReq.Request.(string); ok {
    88  			d := json.NewDecoder(strings.NewReader(req))
    89  			d.UseNumber()
    90  
    91  			if err := d.Decode(&request); err != nil {
    92  				badRequest("error decoding request string: " + err.Error())
    93  				return
    94  			}
    95  		}
    96  	default:
    97  		r.ParseForm()
    98  		service = r.Form.Get("service")
    99  		endpoint = r.Form.Get("endpoint")
   100  		address = r.Form.Get("address")
   101  		if len(endpoint) == 0 {
   102  			endpoint = r.Form.Get("method")
   103  		}
   104  
   105  		d := json.NewDecoder(strings.NewReader(r.Form.Get("request")))
   106  		d.UseNumber()
   107  
   108  		if err := d.Decode(&request); err != nil {
   109  			badRequest("error decoding request string: " + err.Error())
   110  			return
   111  		}
   112  	}
   113  
   114  	if len(service) == 0 {
   115  		badRequest("invalid service")
   116  		return
   117  	}
   118  
   119  	if len(endpoint) == 0 {
   120  		badRequest("invalid endpoint")
   121  		return
   122  	}
   123  
   124  	// create request/response
   125  	var response json.RawMessage
   126  	var err error
   127  	req := client.DefaultClient.NewRequest(service, endpoint, request, client.WithContentType("application/json"))
   128  
   129  	// create context
   130  	ctx := helper.RequestToContext(r)
   131  
   132  	var opts []client.CallOption
   133  
   134  	timeout, _ := strconv.Atoi(r.Header.Get("Timeout"))
   135  	// set timeout
   136  	if timeout > 0 {
   137  		opts = append(opts, client.WithRequestTimeout(time.Duration(timeout)*time.Second))
   138  	}
   139  
   140  	// remote call
   141  	if len(address) > 0 {
   142  		opts = append(opts, client.WithAddress(address))
   143  	}
   144  
   145  	// since services can be running in many domains, we'll use the resolver to determine the domain
   146  	// which should be used on the call
   147  	// if resolver, ok := h.resolver.(*subdomain.Resolver); ok {
   148  	// 	if dom := resolver.Domain(r); len(dom) > 0 {
   149  	// 		opts = append(opts, client.WithNetwork(dom))
   150  	// 	}
   151  	// }
   152  
   153  	// remote call
   154  	err = client.DefaultClient.Call(ctx, req, &response, opts...)
   155  	if err != nil {
   156  		ce := errors.Parse(err.Error())
   157  		switch ce.Code {
   158  		case 0:
   159  			// assuming it's totally screwed
   160  			ce.Code = 500
   161  			ce.Id = "micro.rpc"
   162  			ce.Status = http.StatusText(500)
   163  			ce.Detail = "error during request: " + ce.Detail
   164  			w.WriteHeader(500)
   165  		default:
   166  			w.WriteHeader(int(ce.Code))
   167  		}
   168  		w.Write([]byte(ce.Error()))
   169  		return
   170  	}
   171  
   172  	b, _ := response.MarshalJSON()
   173  	w.Header().Set("Content-Length", strconv.Itoa(len(b)))
   174  	w.Write(b)
   175  }
   176  
   177  // NewRPCHandler returns an initialized RPC handler
   178  func NewRPCHandler(r resolver.Resolver) handler.Handler {
   179  	return &rpcHandler{r}
   180  }