github.com/annwntech/go-micro/v2@v2.9.5/handler/rpc.go (about) 1 package handler 2 3 import ( 4 "encoding/json" 5 "net/http" 6 "strconv" 7 "strings" 8 "time" 9 10 "github.com/annwntech/go-micro/v2/api/handler" 11 "github.com/annwntech/go-micro/v2/api/resolver" 12 "github.com/annwntech/go-micro/v2/api/server/cors" 13 "github.com/annwntech/go-micro/v2/client" 14 "github.com/annwntech/go-micro/v2/errors" 15 "github.com/annwntech/go-micro/v2/helper" 16 ) 17 18 type rpcRequest struct { 19 Service string 20 Endpoint string 21 Method string 22 Address string 23 Request interface{} 24 } 25 26 type rpcHandler struct { 27 resolver resolver.Resolver 28 } 29 30 func (h *rpcHandler) String() string { 31 return "internal/rpc" 32 } 33 34 // ServeHTTP passes on a JSON or form encoded RPC request to a service. 35 func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 36 if r.Method == "OPTIONS" { 37 cors.SetHeaders(w, r) 38 return 39 } 40 41 if r.Method != "POST" { 42 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 43 return 44 } 45 defer r.Body.Close() 46 47 badRequest := func(description string) { 48 e := errors.BadRequest("micro.rpc", description) 49 w.WriteHeader(400) 50 w.Write([]byte(e.Error())) 51 } 52 53 var service, endpoint, address string 54 var request interface{} 55 56 // response content type 57 w.Header().Set("Content-Type", "application/json") 58 59 ct := r.Header.Get("Content-Type") 60 61 // Strip charset from Content-Type (like `application/json; charset=UTF-8`) 62 if idx := strings.IndexRune(ct, ';'); idx >= 0 { 63 ct = ct[:idx] 64 } 65 66 switch ct { 67 case "application/json": 68 var rpcReq rpcRequest 69 70 d := json.NewDecoder(r.Body) 71 d.UseNumber() 72 73 if err := d.Decode(&rpcReq); err != nil { 74 badRequest(err.Error()) 75 return 76 } 77 78 service = rpcReq.Service 79 endpoint = rpcReq.Endpoint 80 address = rpcReq.Address 81 request = rpcReq.Request 82 if len(endpoint) == 0 { 83 endpoint = rpcReq.Method 84 } 85 86 // JSON as string 87 if req, ok := rpcReq.Request.(string); ok { 88 d := json.NewDecoder(strings.NewReader(req)) 89 d.UseNumber() 90 91 if err := d.Decode(&request); err != nil { 92 badRequest("error decoding request string: " + err.Error()) 93 return 94 } 95 } 96 default: 97 r.ParseForm() 98 service = r.Form.Get("service") 99 endpoint = r.Form.Get("endpoint") 100 address = r.Form.Get("address") 101 if len(endpoint) == 0 { 102 endpoint = r.Form.Get("method") 103 } 104 105 d := json.NewDecoder(strings.NewReader(r.Form.Get("request"))) 106 d.UseNumber() 107 108 if err := d.Decode(&request); err != nil { 109 badRequest("error decoding request string: " + err.Error()) 110 return 111 } 112 } 113 114 if len(service) == 0 { 115 badRequest("invalid service") 116 return 117 } 118 119 if len(endpoint) == 0 { 120 badRequest("invalid endpoint") 121 return 122 } 123 124 // create request/response 125 var response json.RawMessage 126 var err error 127 req := client.DefaultClient.NewRequest(service, endpoint, request, client.WithContentType("application/json")) 128 129 // create context 130 ctx := helper.RequestToContext(r) 131 132 var opts []client.CallOption 133 134 timeout, _ := strconv.Atoi(r.Header.Get("Timeout")) 135 // set timeout 136 if timeout > 0 { 137 opts = append(opts, client.WithRequestTimeout(time.Duration(timeout)*time.Second)) 138 } 139 140 // remote call 141 if len(address) > 0 { 142 opts = append(opts, client.WithAddress(address)) 143 } 144 145 // since services can be running in many domains, we'll use the resolver to determine the domain 146 // which should be used on the call 147 // if resolver, ok := h.resolver.(*subdomain.Resolver); ok { 148 // if dom := resolver.Domain(r); len(dom) > 0 { 149 // opts = append(opts, client.WithNetwork(dom)) 150 // } 151 // } 152 153 // remote call 154 err = client.DefaultClient.Call(ctx, req, &response, opts...) 155 if err != nil { 156 ce := errors.Parse(err.Error()) 157 switch ce.Code { 158 case 0: 159 // assuming it's totally screwed 160 ce.Code = 500 161 ce.Id = "micro.rpc" 162 ce.Status = http.StatusText(500) 163 ce.Detail = "error during request: " + ce.Detail 164 w.WriteHeader(500) 165 default: 166 w.WriteHeader(int(ce.Code)) 167 } 168 w.Write([]byte(ce.Error())) 169 return 170 } 171 172 b, _ := response.MarshalJSON() 173 w.Header().Set("Content-Length", strconv.Itoa(len(b))) 174 w.Write(b) 175 } 176 177 // NewRPCHandler returns an initialized RPC handler 178 func NewRPCHandler(r resolver.Resolver) handler.Handler { 179 return &rpcHandler{r} 180 }