github.com/btccom/go-micro/v2@v2.9.3/api/handler/rpc/stream.go (about)

     1  package rpc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"io"
     8  	"net/http"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/gobwas/httphead"
    13  	"github.com/gobwas/ws"
    14  	"github.com/gobwas/ws/wsutil"
    15  	"github.com/btccom/go-micro/v2/api"
    16  	"github.com/btccom/go-micro/v2/client"
    17  	"github.com/btccom/go-micro/v2/client/selector"
    18  	raw "github.com/btccom/go-micro/v2/codec/bytes"
    19  	"github.com/btccom/go-micro/v2/logger"
    20  )
    21  
    22  // serveWebsocket will stream rpc back over websockets assuming json
    23  func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, service *api.Service, c client.Client) {
    24  	var op ws.OpCode
    25  
    26  	ct := r.Header.Get("Content-Type")
    27  	// Strip charset from Content-Type (like `application/json; charset=UTF-8`)
    28  	if idx := strings.IndexRune(ct, ';'); idx >= 0 {
    29  		ct = ct[:idx]
    30  	}
    31  
    32  	// check proto from request
    33  	switch ct {
    34  	case "application/json":
    35  		op = ws.OpText
    36  	default:
    37  		op = ws.OpBinary
    38  	}
    39  
    40  	hdr := make(http.Header)
    41  	if proto, ok := r.Header["Sec-WebSocket-Protocol"]; ok {
    42  		for _, p := range proto {
    43  			switch p {
    44  			case "binary":
    45  				hdr["Sec-WebSocket-Protocol"] = []string{"binary"}
    46  				op = ws.OpBinary
    47  			}
    48  		}
    49  	}
    50  	payload, err := requestPayload(r)
    51  	if err != nil {
    52  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
    53  			logger.Error(err)
    54  		}
    55  		return
    56  	}
    57  
    58  	upgrader := ws.HTTPUpgrader{Timeout: 5 * time.Second,
    59  		Protocol: func(proto string) bool {
    60  			if strings.Contains(proto, "binary") {
    61  				return true
    62  			}
    63  			// fallback to support all protocols now
    64  			return true
    65  		},
    66  		Extension: func(httphead.Option) bool {
    67  			// disable extensions for compatibility
    68  			return false
    69  		},
    70  		Header: hdr,
    71  	}
    72  
    73  	conn, rw, _, err := upgrader.Upgrade(r, w)
    74  	if err != nil {
    75  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
    76  			logger.Error(err)
    77  		}
    78  		return
    79  	}
    80  
    81  	defer func() {
    82  		if err := conn.Close(); err != nil {
    83  			if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
    84  				logger.Error(err)
    85  			}
    86  			return
    87  		}
    88  	}()
    89  
    90  	var request interface{}
    91  	if !bytes.Equal(payload, []byte(`{}`)) {
    92  		switch ct {
    93  		case "application/json", "":
    94  			m := json.RawMessage(payload)
    95  			request = &m
    96  		default:
    97  			request = &raw.Frame{Data: payload}
    98  		}
    99  	}
   100  
   101  	// we always need to set content type for message
   102  	if ct == "" {
   103  		ct = "application/json"
   104  	}
   105  	req := c.NewRequest(
   106  		service.Name,
   107  		service.Endpoint.Name,
   108  		request,
   109  		client.WithContentType(ct),
   110  		client.StreamingRequest(),
   111  	)
   112  
   113  	so := selector.WithStrategy(strategy(service.Services))
   114  	// create a new stream
   115  	stream, err := c.Stream(ctx, req, client.WithSelectOption(so))
   116  	if err != nil {
   117  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   118  			logger.Error(err)
   119  		}
   120  		return
   121  	}
   122  
   123  	if request != nil {
   124  		if err = stream.Send(request); err != nil {
   125  			if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   126  				logger.Error(err)
   127  			}
   128  			return
   129  		}
   130  	}
   131  
   132  	go writeLoop(rw, stream)
   133  
   134  	rsp := stream.Response()
   135  
   136  	// receive from stream and send to client
   137  	for {
   138  		select {
   139  		case <-ctx.Done():
   140  			return
   141  		case <-stream.Context().Done():
   142  			return
   143  		default:
   144  			// read backend response body
   145  			buf, err := rsp.Read()
   146  			if err != nil {
   147  				// wants to avoid import  grpc/status.Status
   148  				if strings.Contains(err.Error(), "context canceled") {
   149  					return
   150  				}
   151  				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   152  					logger.Error(err)
   153  				}
   154  				return
   155  			}
   156  
   157  			// write the response
   158  			if err := wsutil.WriteServerMessage(rw, op, buf); err != nil {
   159  				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   160  					logger.Error(err)
   161  				}
   162  				return
   163  			}
   164  			if err = rw.Flush(); err != nil {
   165  				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   166  					logger.Error(err)
   167  				}
   168  				return
   169  			}
   170  		}
   171  	}
   172  }
   173  
   174  // writeLoop
   175  func writeLoop(rw io.ReadWriter, stream client.Stream) {
   176  	// close stream when done
   177  	defer stream.Close()
   178  
   179  	for {
   180  		select {
   181  		case <-stream.Context().Done():
   182  			return
   183  		default:
   184  			buf, op, err := wsutil.ReadClientData(rw)
   185  			if err != nil {
   186  				if wserr, ok := err.(wsutil.ClosedError); ok {
   187  					switch wserr.Code {
   188  					case ws.StatusGoingAway:
   189  						// this happens when user leave the page
   190  						return
   191  					case ws.StatusNormalClosure, ws.StatusNoStatusRcvd:
   192  						// this happens when user close ws connection, or we don't get any status
   193  						return
   194  					}
   195  				}
   196  				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   197  					logger.Error(err)
   198  				}
   199  				return
   200  			}
   201  			switch op {
   202  			default:
   203  				// not relevant
   204  				continue
   205  			case ws.OpText, ws.OpBinary:
   206  				break
   207  			}
   208  			// send to backend
   209  			// default to trying json
   210  			// if the extracted payload isn't empty lets use it
   211  			request := &raw.Frame{Data: buf}
   212  			if err := stream.Send(request); err != nil {
   213  				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   214  					logger.Error(err)
   215  				}
   216  				return
   217  			}
   218  		}
   219  	}
   220  }
   221  
   222  func isStream(r *http.Request, srv *api.Service) bool {
   223  	// check if it's a web socket
   224  	if !isWebSocket(r) {
   225  		return false
   226  	}
   227  	// check if the endpoint supports streaming
   228  	for _, service := range srv.Services {
   229  		for _, ep := range service.Endpoints {
   230  			// skip if it doesn't match the name
   231  			if ep.Name != srv.Endpoint.Name {
   232  				continue
   233  			}
   234  			// matched if the name
   235  			if v := ep.Metadata["stream"]; v == "true" {
   236  				return true
   237  			}
   238  		}
   239  	}
   240  	return false
   241  }
   242  
   243  func isWebSocket(r *http.Request) bool {
   244  	contains := func(key, val string) bool {
   245  		vv := strings.Split(r.Header.Get(key), ",")
   246  		for _, v := range vv {
   247  			if val == strings.ToLower(strings.TrimSpace(v)) {
   248  				return true
   249  			}
   250  		}
   251  		return false
   252  	}
   253  
   254  	if contains("Connection", "upgrade") && contains("Upgrade", "websocket") {
   255  		return true
   256  	}
   257  
   258  	return false
   259  }