github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/websocket.go (about)

     1  package larking
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"google.golang.org/grpc"
     8  	"google.golang.org/grpc/metadata"
     9  	"google.golang.org/protobuf/encoding/protojson"
    10  	"google.golang.org/protobuf/proto"
    11  	"nhooyr.io/websocket"
    12  )
    13  
    14  const kindWebsocket = "WEBSOCKET"
    15  
    16  type streamWS struct {
    17  	ctx  context.Context
    18  	conn *websocket.Conn
    19  
    20  	method *method
    21  	params params
    22  	recvN  int
    23  	sendN  int
    24  
    25  	sentHeader bool
    26  	header     metadata.MD
    27  	trailer    metadata.MD
    28  }
    29  
    30  func (s *streamWS) SetHeader(md metadata.MD) error {
    31  	if !s.sentHeader {
    32  		s.header = metadata.Join(s.header, md)
    33  	}
    34  	return nil
    35  
    36  }
    37  func (s *streamWS) SendHeader(md metadata.MD) error {
    38  	if s.sentHeader {
    39  		return nil // already sent?
    40  	}
    41  	// TODO: headers?
    42  	s.sentHeader = true
    43  	return nil
    44  }
    45  
    46  func (s *streamWS) SetTrailer(md metadata.MD) {
    47  	s.sentHeader = true
    48  	s.trailer = metadata.Join(s.trailer, md)
    49  }
    50  
    51  func (s *streamWS) Context() context.Context {
    52  	sts := &serverTransportStream{s, s.method.name}
    53  	return grpc.NewContextWithServerTransportStream(s.ctx, sts)
    54  }
    55  
    56  func (s *streamWS) SendMsg(v interface{}) error {
    57  	s.sendN += 1
    58  	reply := v.(proto.Message)
    59  	ctx := s.ctx
    60  
    61  	cur := reply.ProtoReflect()
    62  	for _, fd := range s.method.resp {
    63  		cur = cur.Mutable(fd).Message()
    64  	}
    65  	msg := cur.Interface()
    66  
    67  	// TODO: contentType check?
    68  	b, err := protojson.Marshal(msg)
    69  	if err != nil {
    70  		return err
    71  	}
    72  
    73  	return s.conn.Write(ctx, websocket.MessageText, b)
    74  }
    75  
    76  func (s *streamWS) RecvMsg(m interface{}) error {
    77  	s.recvN += 1
    78  	args := m.(proto.Message)
    79  
    80  	if s.method.hasBody {
    81  		cur := args.ProtoReflect()
    82  		for _, fd := range s.method.body {
    83  			cur = cur.Mutable(fd).Message()
    84  		}
    85  
    86  		msg := cur.Interface()
    87  
    88  		mt, b, err := s.conn.Read(s.ctx)
    89  		if err != nil {
    90  			return err
    91  		}
    92  		if mt != websocket.MessageText {
    93  			return fmt.Errorf("invalid message type: %v", mt)
    94  		}
    95  
    96  		// TODO: contentType check?
    97  		// What marshalling options should we support?
    98  		if err := protojson.Unmarshal(b, msg); err != nil {
    99  			return err
   100  		}
   101  	}
   102  
   103  	if s.recvN == 1 {
   104  		if err := s.params.set(args); err != nil {
   105  			return err
   106  		}
   107  	}
   108  	return nil
   109  }