github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/websocket.go (about) 1 package larking 2 3 import ( 4 "context" 5 "fmt" 6 7 "google.golang.org/grpc" 8 "google.golang.org/grpc/metadata" 9 "google.golang.org/protobuf/encoding/protojson" 10 "google.golang.org/protobuf/proto" 11 "nhooyr.io/websocket" 12 ) 13 14 const kindWebsocket = "WEBSOCKET" 15 16 type streamWS struct { 17 ctx context.Context 18 conn *websocket.Conn 19 20 method *method 21 params params 22 recvN int 23 sendN int 24 25 sentHeader bool 26 header metadata.MD 27 trailer metadata.MD 28 } 29 30 func (s *streamWS) SetHeader(md metadata.MD) error { 31 if !s.sentHeader { 32 s.header = metadata.Join(s.header, md) 33 } 34 return nil 35 36 } 37 func (s *streamWS) SendHeader(md metadata.MD) error { 38 if s.sentHeader { 39 return nil // already sent? 40 } 41 // TODO: headers? 42 s.sentHeader = true 43 return nil 44 } 45 46 func (s *streamWS) SetTrailer(md metadata.MD) { 47 s.sentHeader = true 48 s.trailer = metadata.Join(s.trailer, md) 49 } 50 51 func (s *streamWS) Context() context.Context { 52 sts := &serverTransportStream{s, s.method.name} 53 return grpc.NewContextWithServerTransportStream(s.ctx, sts) 54 } 55 56 func (s *streamWS) SendMsg(v interface{}) error { 57 s.sendN += 1 58 reply := v.(proto.Message) 59 ctx := s.ctx 60 61 cur := reply.ProtoReflect() 62 for _, fd := range s.method.resp { 63 cur = cur.Mutable(fd).Message() 64 } 65 msg := cur.Interface() 66 67 // TODO: contentType check? 68 b, err := protojson.Marshal(msg) 69 if err != nil { 70 return err 71 } 72 73 return s.conn.Write(ctx, websocket.MessageText, b) 74 } 75 76 func (s *streamWS) RecvMsg(m interface{}) error { 77 s.recvN += 1 78 args := m.(proto.Message) 79 80 if s.method.hasBody { 81 cur := args.ProtoReflect() 82 for _, fd := range s.method.body { 83 cur = cur.Mutable(fd).Message() 84 } 85 86 msg := cur.Interface() 87 88 mt, b, err := s.conn.Read(s.ctx) 89 if err != nil { 90 return err 91 } 92 if mt != websocket.MessageText { 93 return fmt.Errorf("invalid message type: %v", mt) 94 } 95 96 // TODO: contentType check? 97 // What marshalling options should we support? 98 if err := protojson.Unmarshal(b, msg); err != nil { 99 return err 100 } 101 } 102 103 if s.recvN == 1 { 104 if err := s.params.set(args); err != nil { 105 return err 106 } 107 } 108 return nil 109 }