github.com/tickoalcantara12/micro/v3@v3.0.0-20221007104245-9d75b9bcbab9/service/api/handler/rpc/stream.go (about) 1 // Copyright 2020 Asim Aslam 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Original source: github.com/micro/go-micro/v3/api/handler/rpc/stream.go 16 17 package rpc 18 19 import ( 20 "bytes" 21 "context" 22 "encoding/json" 23 "fmt" 24 "io" 25 "net/http" 26 "strings" 27 "sync" 28 "time" 29 30 "github.com/gorilla/websocket" 31 pbapi "github.com/tickoalcantara12/micro/v3/proto/api" 32 "github.com/tickoalcantara12/micro/v3/service/api" 33 "github.com/tickoalcantara12/micro/v3/service/client" 34 "github.com/tickoalcantara12/micro/v3/service/errors" 35 "github.com/tickoalcantara12/micro/v3/service/logger" 36 raw "github.com/tickoalcantara12/micro/v3/util/codec/bytes" 37 "github.com/tickoalcantara12/micro/v3/util/router" 38 ) 39 40 const ( 41 // Time allowed to write a message to the client. 42 writeWait = 10 * time.Second 43 44 // Time allowed to read the next pong message from the client. 45 pongWait = 60 * time.Second 46 47 // Send pings to client with this period. Must be less than pongWait. 48 pingPeriod = 15 * time.Second 49 50 // Maximum message size allowed from client. 51 maxMessageSize = 512 52 ) 53 54 var upgrader = websocket.Upgrader{ 55 ReadBufferSize: 1024, 56 WriteBufferSize: 1024, 57 CheckOrigin: func(r *http.Request) bool { 58 return true 59 }, 60 } 61 62 func serveStream(ctx context.Context, w http.ResponseWriter, r *http.Request, service *api.Service, c client.Client) { 63 // serve as websocket if thats the case 64 if isWebSocket(r) { 65 serveWebsocket(ctx, w, r, service, c) 66 return 67 } 68 69 ct := r.Header.Get("Content-Type") 70 // Strip charset from Content-Type (like `application/json; charset=UTF-8`) 71 if idx := strings.IndexRune(ct, ';'); idx >= 0 { 72 ct = ct[:idx] 73 } 74 75 payload, err := api.RequestPayload(r) 76 if err != nil { 77 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 78 logger.Error(err) 79 } 80 return 81 } 82 if len(payload) == 0 { 83 // make it valid json 84 payload = []byte("{}") 85 } 86 87 var request interface{} 88 if !bytes.Equal(payload, []byte(`{}`)) { 89 switch ct { 90 case "application/json", "": 91 m := json.RawMessage(payload) 92 request = &m 93 default: 94 request = &raw.Frame{Data: payload} 95 } 96 } 97 98 // we always need to set content type for message 99 if ct == "" { 100 ct = "application/json" 101 } 102 req := c.NewRequest( 103 service.Name, 104 service.Endpoint.Name, 105 request, 106 client.WithContentType(ct), 107 client.StreamingRequest(), 108 ) 109 110 w.Header().Set("Content-Type", ct) 111 112 // create custom router 113 var nodes []string 114 for _, service := range service.Services { 115 for _, node := range service.Nodes { 116 nodes = append(nodes, node.Address) 117 } 118 } 119 120 callOpt := client.WithAddress(nodes...) 121 122 // create a new stream 123 stream, err := c.Stream(ctx, req, callOpt) 124 if err != nil { 125 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 126 logger.Error(err) 127 } 128 merr, ok := err.(*errors.Error) 129 if ok { 130 w.WriteHeader(int(merr.Code)) 131 w.Write([]byte(merr.Error())) 132 } 133 return 134 } 135 defer stream.Close() 136 137 // send request even if nil because it triggers the call in case server expects no input 138 // without this, we establish a connection but don't kick off the stream of communication 139 if err = stream.Send(request); err != nil { 140 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 141 logger.Error(err) 142 } 143 merr, ok := err.(*errors.Error) 144 if ok { 145 w.WriteHeader(int(merr.Code)) 146 w.Write([]byte(merr.Error())) 147 } else { 148 w.WriteHeader(500) 149 w.Write([]byte(err.Error())) 150 } 151 return 152 } 153 154 rsp := stream.Response() 155 156 // receive from stream and send to client 157 for { 158 select { 159 case <-ctx.Done(): 160 return 161 case <-stream.Context().Done(): 162 return 163 default: 164 // read backend response body 165 buf, err := rsp.Read() 166 if err != nil { 167 // clean exit 168 if err == io.EOF { 169 return 170 } 171 // wants to avoid import grpc/status.Status 172 if strings.Contains(err.Error(), "context canceled") { 173 return 174 } 175 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 176 logger.Error(err) 177 } 178 merr, ok := err.(*errors.Error) 179 if ok { 180 w.WriteHeader(int(merr.Code)) 181 w.Write([]byte(merr.Error())) 182 } 183 return 184 } 185 var bufOut string 186 var apiRsp pbapi.Response 187 if err := json.Unmarshal(buf, &apiRsp); err == nil && apiRsp.StatusCode > 0 { 188 // bit of a hack. If the response is actually an api response we want to set the headers and status code 189 for _, v := range apiRsp.Header { 190 for _, s := range v.Values { 191 w.Header().Add(v.Key, s) 192 } 193 } 194 w.WriteHeader(int(apiRsp.StatusCode)) 195 bufOut = apiRsp.Body 196 } else { 197 bufOut = string(buf) 198 } 199 200 // send the buffer 201 _, err = fmt.Fprint(w, bufOut) 202 if err != nil { 203 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 204 logger.Error(err) 205 } 206 } 207 208 // flush it 209 flusher, ok := w.(http.Flusher) 210 if ok { 211 flusher.Flush() 212 } 213 } 214 } 215 } 216 217 type stream struct { 218 // message type requested (binary or text) 219 messageType int 220 // request context 221 ctx context.Context 222 // the websocket connection. 223 conn *websocket.Conn 224 // the downstream connection. 225 stream client.Stream 226 } 227 228 func (s *stream) processWSReadsAndWrites() { 229 defer func() { 230 s.conn.Close() 231 }() 232 233 msgs := make(chan []byte) 234 235 stopCtx, cancel := context.WithCancel(context.Background()) 236 wg := sync.WaitGroup{} 237 wg.Add(3) 238 go s.rspToBufLoop(cancel, &wg, stopCtx, msgs) 239 go s.bufToClientLoop(cancel, &wg, stopCtx, msgs) 240 go s.clientToServerLoop(cancel, &wg, stopCtx) 241 wg.Wait() 242 } 243 244 func (s *stream) clientToServerLoop(cancel context.CancelFunc, wg *sync.WaitGroup, stopCtx context.Context) { 245 defer func() { 246 s.stream.Close() 247 cancel() 248 wg.Done() 249 }() 250 s.conn.SetReadLimit(maxMessageSize) 251 s.conn.SetReadDeadline(time.Now().Add(pongWait)) 252 s.conn.SetPongHandler(func(string) error { s.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) 253 254 for { 255 select { 256 case <-stopCtx.Done(): 257 return 258 default: 259 } 260 261 _, msg, err := s.conn.ReadMessage() 262 if err != nil { 263 if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { 264 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 265 logger.Error(err) 266 } 267 } 268 return 269 } 270 271 var request interface{} 272 switch s.messageType { 273 case websocket.TextMessage: 274 m := json.RawMessage(msg) 275 request = &m 276 default: 277 request = &raw.Frame{Data: msg} 278 } 279 280 if err := s.stream.Send(request); err != nil { 281 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 282 logger.Error(err) 283 } 284 return 285 } 286 } 287 288 } 289 290 func (s *stream) rspToBufLoop(cancel context.CancelFunc, wg *sync.WaitGroup, stopCtx context.Context, msgs chan []byte) { 291 defer func() { 292 cancel() 293 wg.Done() 294 }() 295 rsp := s.stream.Response() 296 for { 297 select { 298 case <-stopCtx.Done(): 299 return 300 default: 301 } 302 bytes, err := rsp.Read() 303 if err != nil { 304 if err == io.EOF { 305 // clean exit 306 return 307 } 308 // write error then close the connection 309 b, _ := json.Marshal(err) 310 s.conn.WriteMessage(s.messageType, b) 311 s.conn.WriteMessage(websocket.CloseAbnormalClosure, []byte{}) 312 return 313 } 314 select { 315 case <-stopCtx.Done(): 316 return 317 case msgs <- bytes: 318 } 319 320 } 321 322 } 323 324 func (s *stream) bufToClientLoop(cancel context.CancelFunc, wg *sync.WaitGroup, stopCtx context.Context, msgs chan []byte) { 325 defer func() { 326 s.conn.Close() 327 cancel() 328 wg.Done() 329 330 }() 331 ticker := time.NewTicker(pingPeriod) 332 defer ticker.Stop() 333 for { 334 select { 335 case <-stopCtx.Done(): 336 return 337 case <-s.ctx.Done(): 338 return 339 case <-s.stream.Context().Done(): 340 s.conn.WriteMessage(websocket.CloseMessage, []byte{}) 341 return 342 case <-ticker.C: 343 s.conn.SetWriteDeadline(time.Now().Add(writeWait)) 344 if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil { 345 return 346 } 347 case msg := <-msgs: 348 // read response body 349 s.conn.SetWriteDeadline(time.Now().Add(writeWait)) 350 w, err := s.conn.NextWriter(s.messageType) 351 if err != nil { 352 return 353 } 354 if _, err := w.Write(msg); err != nil { 355 return 356 } 357 if err := w.Close(); err != nil { 358 return 359 } 360 } 361 } 362 363 } 364 365 // serveWebsocket will stream rpc back over websockets assuming json 366 func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, service *api.Service, c client.Client) { 367 var rspHdr http.Header 368 // we use Sec-Websocket-Protocol to pass auth headers so just accept anything here 369 if prots := r.Header.Values("Sec-WebSocket-Protocol"); len(prots) > 0 { 370 rspHdr = http.Header{} 371 for _, p := range prots { 372 rspHdr.Add("Sec-WebSocket-Protocol", p) 373 } 374 } 375 376 conn, err := upgrader.Upgrade(w, r, rspHdr) 377 if err != nil { 378 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 379 logger.Error(err) 380 } 381 return 382 } 383 384 // determine the content type 385 ct := r.Header.Get("Content-Type") 386 // strip charset from Content-Type (like `application/json; charset=UTF-8`) 387 if idx := strings.IndexRune(ct, ';'); idx >= 0 { 388 ct = ct[:idx] 389 } 390 if len(ct) == 0 { 391 ct = "application/json" 392 } 393 394 // create stream 395 req := c.NewRequest(service.Name, service.Endpoint.Name, nil, client.WithContentType(ct), client.StreamingRequest()) 396 str, err := c.Stream(ctx, req, client.WithRouter(router.New(service.Services))) 397 if err != nil { 398 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 399 logger.Error(err) 400 } 401 return 402 } 403 404 // determine the message type 405 msgType := websocket.BinaryMessage 406 if ct == "application/json" { 407 msgType = websocket.TextMessage 408 } 409 410 s := stream{ctx: ctx, conn: conn, stream: str, messageType: msgType} 411 s.processWSReadsAndWrites() 412 } 413 414 func isStream(r *http.Request, srv *api.Service) bool { 415 // check if the endpoint supports streaming 416 for _, service := range srv.Services { 417 for _, ep := range service.Endpoints { 418 // skip if it doesn't match the name 419 if ep.Name != srv.Endpoint.Name { 420 continue 421 } 422 // matched if the name 423 if v := ep.Metadata["stream"]; v == "true" { 424 return true 425 } 426 } 427 } 428 429 return false 430 } 431 432 func isWebSocket(r *http.Request) bool { 433 contains := func(key, val string) bool { 434 vv := strings.Split(r.Header.Get(key), ",") 435 for _, v := range vv { 436 if val == strings.ToLower(strings.TrimSpace(v)) { 437 return true 438 } 439 } 440 return false 441 } 442 443 if contains("Connection", "upgrade") && contains("Upgrade", "websocket") { 444 return true 445 } 446 447 return false 448 }