github.com/btccom/go-micro/v2@v2.9.3/api/handler/rpc/stream.go (about) 1 package rpc 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "io" 8 "net/http" 9 "strings" 10 "time" 11 12 "github.com/gobwas/httphead" 13 "github.com/gobwas/ws" 14 "github.com/gobwas/ws/wsutil" 15 "github.com/btccom/go-micro/v2/api" 16 "github.com/btccom/go-micro/v2/client" 17 "github.com/btccom/go-micro/v2/client/selector" 18 raw "github.com/btccom/go-micro/v2/codec/bytes" 19 "github.com/btccom/go-micro/v2/logger" 20 ) 21 22 // serveWebsocket will stream rpc back over websockets assuming json 23 func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, service *api.Service, c client.Client) { 24 var op ws.OpCode 25 26 ct := r.Header.Get("Content-Type") 27 // Strip charset from Content-Type (like `application/json; charset=UTF-8`) 28 if idx := strings.IndexRune(ct, ';'); idx >= 0 { 29 ct = ct[:idx] 30 } 31 32 // check proto from request 33 switch ct { 34 case "application/json": 35 op = ws.OpText 36 default: 37 op = ws.OpBinary 38 } 39 40 hdr := make(http.Header) 41 if proto, ok := r.Header["Sec-WebSocket-Protocol"]; ok { 42 for _, p := range proto { 43 switch p { 44 case "binary": 45 hdr["Sec-WebSocket-Protocol"] = []string{"binary"} 46 op = ws.OpBinary 47 } 48 } 49 } 50 payload, err := requestPayload(r) 51 if err != nil { 52 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 53 logger.Error(err) 54 } 55 return 56 } 57 58 upgrader := ws.HTTPUpgrader{Timeout: 5 * time.Second, 59 Protocol: func(proto string) bool { 60 if strings.Contains(proto, "binary") { 61 return true 62 } 63 // fallback to support all protocols now 64 return true 65 }, 66 Extension: func(httphead.Option) bool { 67 // disable extensions for compatibility 68 return false 69 }, 70 Header: hdr, 71 } 72 73 conn, rw, _, err := upgrader.Upgrade(r, w) 74 if err != nil { 75 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 76 logger.Error(err) 77 } 78 return 79 } 80 81 defer func() { 82 if err := conn.Close(); err != nil { 83 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 84 logger.Error(err) 85 } 86 return 87 } 88 }() 89 90 var request interface{} 91 if !bytes.Equal(payload, []byte(`{}`)) { 92 switch ct { 93 case "application/json", "": 94 m := json.RawMessage(payload) 95 request = &m 96 default: 97 request = &raw.Frame{Data: payload} 98 } 99 } 100 101 // we always need to set content type for message 102 if ct == "" { 103 ct = "application/json" 104 } 105 req := c.NewRequest( 106 service.Name, 107 service.Endpoint.Name, 108 request, 109 client.WithContentType(ct), 110 client.StreamingRequest(), 111 ) 112 113 so := selector.WithStrategy(strategy(service.Services)) 114 // create a new stream 115 stream, err := c.Stream(ctx, req, client.WithSelectOption(so)) 116 if err != nil { 117 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 118 logger.Error(err) 119 } 120 return 121 } 122 123 if request != nil { 124 if err = stream.Send(request); err != nil { 125 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 126 logger.Error(err) 127 } 128 return 129 } 130 } 131 132 go writeLoop(rw, stream) 133 134 rsp := stream.Response() 135 136 // receive from stream and send to client 137 for { 138 select { 139 case <-ctx.Done(): 140 return 141 case <-stream.Context().Done(): 142 return 143 default: 144 // read backend response body 145 buf, err := rsp.Read() 146 if err != nil { 147 // wants to avoid import grpc/status.Status 148 if strings.Contains(err.Error(), "context canceled") { 149 return 150 } 151 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 152 logger.Error(err) 153 } 154 return 155 } 156 157 // write the response 158 if err := wsutil.WriteServerMessage(rw, op, buf); err != nil { 159 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 160 logger.Error(err) 161 } 162 return 163 } 164 if err = rw.Flush(); err != nil { 165 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 166 logger.Error(err) 167 } 168 return 169 } 170 } 171 } 172 } 173 174 // writeLoop 175 func writeLoop(rw io.ReadWriter, stream client.Stream) { 176 // close stream when done 177 defer stream.Close() 178 179 for { 180 select { 181 case <-stream.Context().Done(): 182 return 183 default: 184 buf, op, err := wsutil.ReadClientData(rw) 185 if err != nil { 186 if wserr, ok := err.(wsutil.ClosedError); ok { 187 switch wserr.Code { 188 case ws.StatusGoingAway: 189 // this happens when user leave the page 190 return 191 case ws.StatusNormalClosure, ws.StatusNoStatusRcvd: 192 // this happens when user close ws connection, or we don't get any status 193 return 194 } 195 } 196 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 197 logger.Error(err) 198 } 199 return 200 } 201 switch op { 202 default: 203 // not relevant 204 continue 205 case ws.OpText, ws.OpBinary: 206 break 207 } 208 // send to backend 209 // default to trying json 210 // if the extracted payload isn't empty lets use it 211 request := &raw.Frame{Data: buf} 212 if err := stream.Send(request); err != nil { 213 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 214 logger.Error(err) 215 } 216 return 217 } 218 } 219 } 220 } 221 222 func isStream(r *http.Request, srv *api.Service) bool { 223 // check if it's a web socket 224 if !isWebSocket(r) { 225 return false 226 } 227 // check if the endpoint supports streaming 228 for _, service := range srv.Services { 229 for _, ep := range service.Endpoints { 230 // skip if it doesn't match the name 231 if ep.Name != srv.Endpoint.Name { 232 continue 233 } 234 // matched if the name 235 if v := ep.Metadata["stream"]; v == "true" { 236 return true 237 } 238 } 239 } 240 return false 241 } 242 243 func isWebSocket(r *http.Request) bool { 244 contains := func(key, val string) bool { 245 vv := strings.Split(r.Header.Get(key), ",") 246 for _, v := range vv { 247 if val == strings.ToLower(strings.TrimSpace(v)) { 248 return true 249 } 250 } 251 return false 252 } 253 254 if contains("Connection", "upgrade") && contains("Upgrade", "websocket") { 255 return true 256 } 257 258 return false 259 }