github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/websocket_test.go (about) 1 package larking 2 3 import ( 4 "context" 5 "net" 6 "net/http" 7 "testing" 8 "time" 9 10 "github.com/emcfarlane/larking/testpb" 11 "github.com/google/go-cmp/cmp" 12 "golang.org/x/sync/errgroup" 13 "google.golang.org/grpc" 14 "google.golang.org/protobuf/encoding/protojson" 15 "google.golang.org/protobuf/proto" 16 "google.golang.org/protobuf/testing/protocmp" 17 "nhooyr.io/websocket" 18 ) 19 20 func TestWebsocket(t *testing.T) { 21 // Create test server. 22 fs := &testpb.UnimplementedChatRoomServer{} 23 o := &overrides{} 24 25 var g errgroup.Group 26 defer func() { 27 if err := g.Wait(); err != nil { 28 t.Fatal(err) 29 } 30 }() 31 mux, err := NewMux( 32 UnaryServerInterceptorOption(o.unary()), 33 StreamServerInterceptorOption(o.stream()), 34 ) 35 if err != nil { 36 t.Fatal(err) 37 } 38 mux.RegisterService(&testpb.ChatRoom_ServiceDesc, fs) 39 40 s, err := NewServer(mux, InsecureServerOption()) 41 if err != nil { 42 t.Fatal(err) 43 } 44 45 lis, err := net.Listen("tcp", "localhost:0") 46 if err != nil { 47 t.Fatalf("failed to listen: %v", err) 48 } 49 defer lis.Close() 50 51 g.Go(func() (err error) { 52 if err := s.Serve(lis); err != nil && err != http.ErrServerClosed { 53 return err 54 } 55 return nil 56 }) 57 defer func() { 58 if err := s.Shutdown(context.Background()); err != nil { 59 t.Fatal(err) 60 } 61 }() 62 63 cmpOpts := cmp.Options{protocmp.Transform()} 64 var unaryStreamDesc = &grpc.StreamDesc{ 65 ClientStreams: false, 66 ServerStreams: false, 67 } 68 69 tests := []struct { 70 name string 71 desc *grpc.StreamDesc 72 path string 73 method string 74 client []interface{} 75 server []interface{} 76 }{{ 77 name: "unary", 78 desc: unaryStreamDesc, 79 path: "/v1/rooms/chat", 80 method: "/larking.testpb.ChatRoom/Chat", 81 client: []interface{}{ 82 in{ 83 msg: &testpb.ChatMessage{ 84 Text: "hello", 85 }, 86 }, 87 out{ 88 msg: &testpb.ChatMessage{ 89 Text: "world", 90 }, 91 }, 92 }, 93 server: []interface{}{ 94 in{ 95 msg: &testpb.ChatMessage{ 96 Name: "rooms/chat", // name added from URL path 97 Text: "hello", 98 }, 99 }, 100 out{ 101 msg: &testpb.ChatMessage{ 102 Text: "world", 103 }, 104 }, 105 }, 106 }} 107 108 for _, tt := range tests { 109 t.Run(tt.name, func(t *testing.T) { 110 o.reset(t, "http-test", tt.server) 111 112 ctx, cancel := context.WithTimeout(testContext(t), time.Minute) 113 defer cancel() 114 115 c, _, err := websocket.Dial(ctx, "ws://"+lis.Addr().String()+tt.path, &websocket.DialOptions{ 116 HTTPHeader: map[string][]string{ 117 "test": {tt.method}, 118 }, 119 }) 120 if err != nil { 121 t.Fatal(err) 122 } 123 defer c.Close(websocket.StatusNormalClosure, "the sky is falling") 124 125 for i := 0; i < len(tt.client); i++ { 126 switch typ := tt.client[i].(type) { 127 case in: 128 b, err := protojson.Marshal(typ.msg) 129 if err != nil { 130 t.Fatal(err) 131 } 132 if err := c.Write(ctx, websocket.MessageText, b); err != nil { 133 t.Fatal(err) 134 } 135 136 case out: 137 mt, b, err := c.Read(ctx) 138 if err != nil { 139 t.Fatal(mt, err) 140 } 141 t.Log("b", string(b)) 142 143 out := proto.Clone(typ.msg) 144 if err := protojson.Unmarshal(b, out); err != nil { 145 t.Fatal(err) 146 } 147 diff := cmp.Diff(out, typ.msg, cmpOpts...) 148 if diff != "" { 149 t.Fatal(diff) 150 } 151 } 152 } 153 c.Close(websocket.StatusNormalClosure, "normal") 154 }) 155 } 156 }