github.com/tickoalcantara12/micro/v3@v3.0.0-20221007104245-9d75b9bcbab9/service/api/handler/rpc/stream.go (about)

     1  // Copyright 2020 Asim Aslam
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // Original source: github.com/micro/go-micro/v3/api/handler/rpc/stream.go
    16  
    17  package rpc
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"encoding/json"
    23  	"fmt"
    24  	"io"
    25  	"net/http"
    26  	"strings"
    27  	"sync"
    28  	"time"
    29  
    30  	"github.com/gorilla/websocket"
    31  	pbapi "github.com/tickoalcantara12/micro/v3/proto/api"
    32  	"github.com/tickoalcantara12/micro/v3/service/api"
    33  	"github.com/tickoalcantara12/micro/v3/service/client"
    34  	"github.com/tickoalcantara12/micro/v3/service/errors"
    35  	"github.com/tickoalcantara12/micro/v3/service/logger"
    36  	raw "github.com/tickoalcantara12/micro/v3/util/codec/bytes"
    37  	"github.com/tickoalcantara12/micro/v3/util/router"
    38  )
    39  
    40  const (
    41  	// Time allowed to write a message to the client.
    42  	writeWait = 10 * time.Second
    43  
    44  	// Time allowed to read the next pong message from the client.
    45  	pongWait = 60 * time.Second
    46  
    47  	// Send pings to client with this period. Must be less than pongWait.
    48  	pingPeriod = 15 * time.Second
    49  
    50  	// Maximum message size allowed from client.
    51  	maxMessageSize = 512
    52  )
    53  
    54  var upgrader = websocket.Upgrader{
    55  	ReadBufferSize:  1024,
    56  	WriteBufferSize: 1024,
    57  	CheckOrigin: func(r *http.Request) bool {
    58  		return true
    59  	},
    60  }
    61  
    62  func serveStream(ctx context.Context, w http.ResponseWriter, r *http.Request, service *api.Service, c client.Client) {
    63  	// serve as websocket if thats the case
    64  	if isWebSocket(r) {
    65  		serveWebsocket(ctx, w, r, service, c)
    66  		return
    67  	}
    68  
    69  	ct := r.Header.Get("Content-Type")
    70  	// Strip charset from Content-Type (like `application/json; charset=UTF-8`)
    71  	if idx := strings.IndexRune(ct, ';'); idx >= 0 {
    72  		ct = ct[:idx]
    73  	}
    74  
    75  	payload, err := api.RequestPayload(r)
    76  	if err != nil {
    77  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
    78  			logger.Error(err)
    79  		}
    80  		return
    81  	}
    82  	if len(payload) == 0 {
    83  		// make it valid json
    84  		payload = []byte("{}")
    85  	}
    86  
    87  	var request interface{}
    88  	if !bytes.Equal(payload, []byte(`{}`)) {
    89  		switch ct {
    90  		case "application/json", "":
    91  			m := json.RawMessage(payload)
    92  			request = &m
    93  		default:
    94  			request = &raw.Frame{Data: payload}
    95  		}
    96  	}
    97  
    98  	// we always need to set content type for message
    99  	if ct == "" {
   100  		ct = "application/json"
   101  	}
   102  	req := c.NewRequest(
   103  		service.Name,
   104  		service.Endpoint.Name,
   105  		request,
   106  		client.WithContentType(ct),
   107  		client.StreamingRequest(),
   108  	)
   109  
   110  	w.Header().Set("Content-Type", ct)
   111  
   112  	// create custom router
   113  	var nodes []string
   114  	for _, service := range service.Services {
   115  		for _, node := range service.Nodes {
   116  			nodes = append(nodes, node.Address)
   117  		}
   118  	}
   119  
   120  	callOpt := client.WithAddress(nodes...)
   121  
   122  	// create a new stream
   123  	stream, err := c.Stream(ctx, req, callOpt)
   124  	if err != nil {
   125  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   126  			logger.Error(err)
   127  		}
   128  		merr, ok := err.(*errors.Error)
   129  		if ok {
   130  			w.WriteHeader(int(merr.Code))
   131  			w.Write([]byte(merr.Error()))
   132  		}
   133  		return
   134  	}
   135  	defer stream.Close()
   136  
   137  	// send request even if nil because it triggers the call in case server expects no input
   138  	// without this, we establish a connection but don't kick off the stream of communication
   139  	if err = stream.Send(request); err != nil {
   140  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   141  			logger.Error(err)
   142  		}
   143  		merr, ok := err.(*errors.Error)
   144  		if ok {
   145  			w.WriteHeader(int(merr.Code))
   146  			w.Write([]byte(merr.Error()))
   147  		} else {
   148  			w.WriteHeader(500)
   149  			w.Write([]byte(err.Error()))
   150  		}
   151  		return
   152  	}
   153  
   154  	rsp := stream.Response()
   155  
   156  	// receive from stream and send to client
   157  	for {
   158  		select {
   159  		case <-ctx.Done():
   160  			return
   161  		case <-stream.Context().Done():
   162  			return
   163  		default:
   164  			// read backend response body
   165  			buf, err := rsp.Read()
   166  			if err != nil {
   167  				// clean exit
   168  				if err == io.EOF {
   169  					return
   170  				}
   171  				// wants to avoid import  grpc/status.Status
   172  				if strings.Contains(err.Error(), "context canceled") {
   173  					return
   174  				}
   175  				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   176  					logger.Error(err)
   177  				}
   178  				merr, ok := err.(*errors.Error)
   179  				if ok {
   180  					w.WriteHeader(int(merr.Code))
   181  					w.Write([]byte(merr.Error()))
   182  				}
   183  				return
   184  			}
   185  			var bufOut string
   186  			var apiRsp pbapi.Response
   187  			if err := json.Unmarshal(buf, &apiRsp); err == nil && apiRsp.StatusCode > 0 {
   188  				// bit of a hack. If the response is actually an api response we want to set the headers and status code
   189  				for _, v := range apiRsp.Header {
   190  					for _, s := range v.Values {
   191  						w.Header().Add(v.Key, s)
   192  					}
   193  				}
   194  				w.WriteHeader(int(apiRsp.StatusCode))
   195  				bufOut = apiRsp.Body
   196  			} else {
   197  				bufOut = string(buf)
   198  			}
   199  
   200  			// send the buffer
   201  			_, err = fmt.Fprint(w, bufOut)
   202  			if err != nil {
   203  				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   204  					logger.Error(err)
   205  				}
   206  			}
   207  
   208  			// flush it
   209  			flusher, ok := w.(http.Flusher)
   210  			if ok {
   211  				flusher.Flush()
   212  			}
   213  		}
   214  	}
   215  }
   216  
// stream bridges one websocket connection to one backend RPC stream. It is
// built by serveWebsocket and driven by processWSReadsAndWrites.
type stream struct {
	// message type requested (binary or text): websocket.TextMessage or
	// websocket.BinaryMessage. Every frame written to the client uses it, and
	// clientToServerLoop uses it to choose how client frames are encoded.
	messageType int
	// request context; bufToClientLoop stops once it is done
	ctx context.Context
	// the websocket connection to the client.
	conn *websocket.Conn
	// the downstream (backend) RPC stream.
	stream client.Stream
}
   227  
   228  func (s *stream) processWSReadsAndWrites() {
   229  	defer func() {
   230  		s.conn.Close()
   231  	}()
   232  
   233  	msgs := make(chan []byte)
   234  
   235  	stopCtx, cancel := context.WithCancel(context.Background())
   236  	wg := sync.WaitGroup{}
   237  	wg.Add(3)
   238  	go s.rspToBufLoop(cancel, &wg, stopCtx, msgs)
   239  	go s.bufToClientLoop(cancel, &wg, stopCtx, msgs)
   240  	go s.clientToServerLoop(cancel, &wg, stopCtx)
   241  	wg.Wait()
   242  }
   243  
   244  func (s *stream) clientToServerLoop(cancel context.CancelFunc, wg *sync.WaitGroup, stopCtx context.Context) {
   245  	defer func() {
   246  		s.stream.Close()
   247  		cancel()
   248  		wg.Done()
   249  	}()
   250  	s.conn.SetReadLimit(maxMessageSize)
   251  	s.conn.SetReadDeadline(time.Now().Add(pongWait))
   252  	s.conn.SetPongHandler(func(string) error { s.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
   253  
   254  	for {
   255  		select {
   256  		case <-stopCtx.Done():
   257  			return
   258  		default:
   259  		}
   260  
   261  		_, msg, err := s.conn.ReadMessage()
   262  		if err != nil {
   263  			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
   264  				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   265  					logger.Error(err)
   266  				}
   267  			}
   268  			return
   269  		}
   270  
   271  		var request interface{}
   272  		switch s.messageType {
   273  		case websocket.TextMessage:
   274  			m := json.RawMessage(msg)
   275  			request = &m
   276  		default:
   277  			request = &raw.Frame{Data: msg}
   278  		}
   279  
   280  		if err := s.stream.Send(request); err != nil {
   281  			if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   282  				logger.Error(err)
   283  			}
   284  			return
   285  		}
   286  	}
   287  
   288  }
   289  
   290  func (s *stream) rspToBufLoop(cancel context.CancelFunc, wg *sync.WaitGroup, stopCtx context.Context, msgs chan []byte) {
   291  	defer func() {
   292  		cancel()
   293  		wg.Done()
   294  	}()
   295  	rsp := s.stream.Response()
   296  	for {
   297  		select {
   298  		case <-stopCtx.Done():
   299  			return
   300  		default:
   301  		}
   302  		bytes, err := rsp.Read()
   303  		if err != nil {
   304  			if err == io.EOF {
   305  				// clean exit
   306  				return
   307  			}
   308  			// write error then close the connection
   309  			b, _ := json.Marshal(err)
   310  			s.conn.WriteMessage(s.messageType, b)
   311  			s.conn.WriteMessage(websocket.CloseAbnormalClosure, []byte{})
   312  			return
   313  		}
   314  		select {
   315  		case <-stopCtx.Done():
   316  			return
   317  		case msgs <- bytes:
   318  		}
   319  
   320  	}
   321  
   322  }
   323  
   324  func (s *stream) bufToClientLoop(cancel context.CancelFunc, wg *sync.WaitGroup, stopCtx context.Context, msgs chan []byte) {
   325  	defer func() {
   326  		s.conn.Close()
   327  		cancel()
   328  		wg.Done()
   329  
   330  	}()
   331  	ticker := time.NewTicker(pingPeriod)
   332  	defer ticker.Stop()
   333  	for {
   334  		select {
   335  		case <-stopCtx.Done():
   336  			return
   337  		case <-s.ctx.Done():
   338  			return
   339  		case <-s.stream.Context().Done():
   340  			s.conn.WriteMessage(websocket.CloseMessage, []byte{})
   341  			return
   342  		case <-ticker.C:
   343  			s.conn.SetWriteDeadline(time.Now().Add(writeWait))
   344  			if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
   345  				return
   346  			}
   347  		case msg := <-msgs:
   348  			// read response body
   349  			s.conn.SetWriteDeadline(time.Now().Add(writeWait))
   350  			w, err := s.conn.NextWriter(s.messageType)
   351  			if err != nil {
   352  				return
   353  			}
   354  			if _, err := w.Write(msg); err != nil {
   355  				return
   356  			}
   357  			if err := w.Close(); err != nil {
   358  				return
   359  			}
   360  		}
   361  	}
   362  
   363  }
   364  
   365  // serveWebsocket will stream rpc back over websockets assuming json
   366  func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, service *api.Service, c client.Client) {
   367  	var rspHdr http.Header
   368  	// we use Sec-Websocket-Protocol to pass auth headers so just accept anything here
   369  	if prots := r.Header.Values("Sec-WebSocket-Protocol"); len(prots) > 0 {
   370  		rspHdr = http.Header{}
   371  		for _, p := range prots {
   372  			rspHdr.Add("Sec-WebSocket-Protocol", p)
   373  		}
   374  	}
   375  
   376  	conn, err := upgrader.Upgrade(w, r, rspHdr)
   377  	if err != nil {
   378  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   379  			logger.Error(err)
   380  		}
   381  		return
   382  	}
   383  
   384  	// determine the content type
   385  	ct := r.Header.Get("Content-Type")
   386  	// strip charset from Content-Type (like `application/json; charset=UTF-8`)
   387  	if idx := strings.IndexRune(ct, ';'); idx >= 0 {
   388  		ct = ct[:idx]
   389  	}
   390  	if len(ct) == 0 {
   391  		ct = "application/json"
   392  	}
   393  
   394  	// create stream
   395  	req := c.NewRequest(service.Name, service.Endpoint.Name, nil, client.WithContentType(ct), client.StreamingRequest())
   396  	str, err := c.Stream(ctx, req, client.WithRouter(router.New(service.Services)))
   397  	if err != nil {
   398  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   399  			logger.Error(err)
   400  		}
   401  		return
   402  	}
   403  
   404  	// determine the message type
   405  	msgType := websocket.BinaryMessage
   406  	if ct == "application/json" {
   407  		msgType = websocket.TextMessage
   408  	}
   409  
   410  	s := stream{ctx: ctx, conn: conn, stream: str, messageType: msgType}
   411  	s.processWSReadsAndWrites()
   412  }
   413  
   414  func isStream(r *http.Request, srv *api.Service) bool {
   415  	// check if the endpoint supports streaming
   416  	for _, service := range srv.Services {
   417  		for _, ep := range service.Endpoints {
   418  			// skip if it doesn't match the name
   419  			if ep.Name != srv.Endpoint.Name {
   420  				continue
   421  			}
   422  			// matched if the name
   423  			if v := ep.Metadata["stream"]; v == "true" {
   424  				return true
   425  			}
   426  		}
   427  	}
   428  
   429  	return false
   430  }
   431  
   432  func isWebSocket(r *http.Request) bool {
   433  	contains := func(key, val string) bool {
   434  		vv := strings.Split(r.Header.Get(key), ",")
   435  		for _, v := range vv {
   436  			if val == strings.ToLower(strings.TrimSpace(v)) {
   437  				return true
   438  			}
   439  		}
   440  		return false
   441  	}
   442  
   443  	if contains("Connection", "upgrade") && contains("Upgrade", "websocket") {
   444  		return true
   445  	}
   446  
   447  	return false
   448  }