github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/websocket_test.go (about)

     1  package larking
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/http"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/emcfarlane/larking/testpb"
    11  	"github.com/google/go-cmp/cmp"
    12  	"golang.org/x/sync/errgroup"
    13  	"google.golang.org/grpc"
    14  	"google.golang.org/protobuf/encoding/protojson"
    15  	"google.golang.org/protobuf/proto"
    16  	"google.golang.org/protobuf/testing/protocmp"
    17  	"nhooyr.io/websocket"
    18  )
    19  
    20  func TestWebsocket(t *testing.T) {
    21  	// Create test server.
    22  	fs := &testpb.UnimplementedChatRoomServer{}
    23  	o := &overrides{}
    24  
    25  	var g errgroup.Group
    26  	defer func() {
    27  		if err := g.Wait(); err != nil {
    28  			t.Fatal(err)
    29  		}
    30  	}()
    31  	mux, err := NewMux(
    32  		UnaryServerInterceptorOption(o.unary()),
    33  		StreamServerInterceptorOption(o.stream()),
    34  	)
    35  	if err != nil {
    36  		t.Fatal(err)
    37  	}
    38  	mux.RegisterService(&testpb.ChatRoom_ServiceDesc, fs)
    39  
    40  	s, err := NewServer(mux, InsecureServerOption())
    41  	if err != nil {
    42  		t.Fatal(err)
    43  	}
    44  
    45  	lis, err := net.Listen("tcp", "localhost:0")
    46  	if err != nil {
    47  		t.Fatalf("failed to listen: %v", err)
    48  	}
    49  	defer lis.Close()
    50  
    51  	g.Go(func() (err error) {
    52  		if err := s.Serve(lis); err != nil && err != http.ErrServerClosed {
    53  			return err
    54  		}
    55  		return nil
    56  	})
    57  	defer func() {
    58  		if err := s.Shutdown(context.Background()); err != nil {
    59  			t.Fatal(err)
    60  		}
    61  	}()
    62  
    63  	cmpOpts := cmp.Options{protocmp.Transform()}
    64  	var unaryStreamDesc = &grpc.StreamDesc{
    65  		ClientStreams: false,
    66  		ServerStreams: false,
    67  	}
    68  
    69  	tests := []struct {
    70  		name   string
    71  		desc   *grpc.StreamDesc
    72  		path   string
    73  		method string
    74  		client []interface{}
    75  		server []interface{}
    76  	}{{
    77  		name:   "unary",
    78  		desc:   unaryStreamDesc,
    79  		path:   "/v1/rooms/chat",
    80  		method: "/larking.testpb.ChatRoom/Chat",
    81  		client: []interface{}{
    82  			in{
    83  				msg: &testpb.ChatMessage{
    84  					Text: "hello",
    85  				},
    86  			},
    87  			out{
    88  				msg: &testpb.ChatMessage{
    89  					Text: "world",
    90  				},
    91  			},
    92  		},
    93  		server: []interface{}{
    94  			in{
    95  				msg: &testpb.ChatMessage{
    96  					Name: "rooms/chat", // name added from URL path
    97  					Text: "hello",
    98  				},
    99  			},
   100  			out{
   101  				msg: &testpb.ChatMessage{
   102  					Text: "world",
   103  				},
   104  			},
   105  		},
   106  	}}
   107  
   108  	for _, tt := range tests {
   109  		t.Run(tt.name, func(t *testing.T) {
   110  			o.reset(t, "http-test", tt.server)
   111  
   112  			ctx, cancel := context.WithTimeout(testContext(t), time.Minute)
   113  			defer cancel()
   114  
   115  			c, _, err := websocket.Dial(ctx, "ws://"+lis.Addr().String()+tt.path, &websocket.DialOptions{
   116  				HTTPHeader: map[string][]string{
   117  					"test": {tt.method},
   118  				},
   119  			})
   120  			if err != nil {
   121  				t.Fatal(err)
   122  			}
   123  			defer c.Close(websocket.StatusNormalClosure, "the sky is falling")
   124  
   125  			for i := 0; i < len(tt.client); i++ {
   126  				switch typ := tt.client[i].(type) {
   127  				case in:
   128  					b, err := protojson.Marshal(typ.msg)
   129  					if err != nil {
   130  						t.Fatal(err)
   131  					}
   132  					if err := c.Write(ctx, websocket.MessageText, b); err != nil {
   133  						t.Fatal(err)
   134  					}
   135  
   136  				case out:
   137  					mt, b, err := c.Read(ctx)
   138  					if err != nil {
   139  						t.Fatal(mt, err)
   140  					}
   141  					t.Log("b", string(b))
   142  
   143  					out := proto.Clone(typ.msg)
   144  					if err := protojson.Unmarshal(b, out); err != nil {
   145  						t.Fatal(err)
   146  					}
   147  					diff := cmp.Diff(out, typ.msg, cmpOpts...)
   148  					if diff != "" {
   149  						t.Fatal(diff)
   150  					}
   151  				}
   152  			}
   153  			c.Close(websocket.StatusNormalClosure, "normal")
   154  		})
   155  	}
   156  }