github.com/99designs/gqlgen@v0.17.45/codegen/testserver/singlefile/subscription_test.go (about) 1 package singlefile 2 3 import ( 4 "context" 5 "fmt" 6 "runtime" 7 "sort" 8 "testing" 9 "time" 10 11 "github.com/stretchr/testify/require" 12 13 "github.com/99designs/gqlgen/client" 14 "github.com/99designs/gqlgen/graphql" 15 "github.com/99designs/gqlgen/graphql/handler" 16 "github.com/99designs/gqlgen/graphql/handler/transport" 17 ) 18 19 func TestSubscriptions(t *testing.T) { 20 tick := make(chan string, 1) 21 22 resolvers := &Stub{} 23 24 resolvers.SubscriptionResolver.InitPayload = func(ctx context.Context) (strings <-chan string, e error) { 25 payload := transport.GetInitPayload(ctx) 26 channel := make(chan string, len(payload)+1) 27 28 go func() { 29 <-ctx.Done() 30 close(channel) 31 }() 32 33 // Test the helper function separately 34 auth := payload.Authorization() 35 if auth != "" { 36 channel <- "AUTH:" + auth 37 } else { 38 channel <- "AUTH:NONE" 39 } 40 41 // Send them over the channel in alphabetic order 42 keys := make([]string, 0, len(payload)) 43 for key := range payload { 44 keys = append(keys, key) 45 } 46 sort.Strings(keys) 47 for _, key := range keys { 48 channel <- fmt.Sprintf("%s = %#+v", key, payload[key]) 49 } 50 51 return channel, nil 52 } 53 54 errorTick := make(chan *Error, 1) 55 resolvers.SubscriptionResolver.ErrorRequired = func(ctx context.Context) (<-chan *Error, error) { 56 res := make(chan *Error, 1) 57 58 go func() { 59 for { 60 select { 61 case t := <-errorTick: 62 res <- t 63 case <-ctx.Done(): 64 close(res) 65 return 66 } 67 } 68 }() 69 return res, nil 70 } 71 72 resolvers.SubscriptionResolver.Updated = func(ctx context.Context) (<-chan string, error) { 73 res := make(chan string, 1) 74 75 go func() { 76 for { 77 select { 78 case t := <-tick: 79 res <- t 80 case <-ctx.Done(): 81 close(res) 82 return 83 } 84 } 85 }() 86 return res, nil 87 } 88 89 srv := handler.NewDefaultServer( 90 NewExecutableSchema(Config{Resolvers: resolvers}), 91 ) 92 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 93 path, _ := ctx.Value(ckey("path")).([]int) 94 return next(context.WithValue(ctx, ckey("path"), append(path, 1))) 95 }) 96 97 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 98 path, _ := ctx.Value(ckey("path")).([]int) 99 return next(context.WithValue(ctx, ckey("path"), append(path, 2))) 100 }) 101 102 c := client.New(srv) 103 104 t.Run("wont leak goroutines", func(t *testing.T) { 105 runtime.GC() // ensure no go-routines left from preceding tests 106 initialGoroutineCount := runtime.NumGoroutine() 107 108 sub := c.Websocket(`subscription { updated }`) 109 110 tick <- "message" 111 112 var msg struct { 113 resp struct { 114 Updated string 115 } 116 } 117 118 err := sub.Next(&msg.resp) 119 require.NoError(t, err) 120 require.Equal(t, "message", msg.resp.Updated) 121 sub.Close() 122 123 // need a little bit of time for goroutines to settle 124 start := time.Now() 125 for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() { 126 time.Sleep(5 * time.Millisecond) 127 } 128 129 require.Equal(t, initialGoroutineCount, runtime.NumGoroutine()) 130 }) 131 132 t.Run("will parse init payload", func(t *testing.T) { 133 runtime.GC() // ensure no go-routines left from preceding tests 134 initialGoroutineCount := runtime.NumGoroutine() 135 136 sub := c.WebsocketWithPayload(`subscription { initPayload }`, map[string]interface{}{ 137 "Authorization": "Bearer of the curse", 138 "number": 32, 139 "strings": []string{"hello", "world"}, 140 }) 141 142 var msg struct { 143 resp struct { 144 InitPayload string 145 } 146 } 147 148 err := sub.Next(&msg.resp) 149 require.NoError(t, err) 150 require.Equal(t, "AUTH:Bearer of the curse", msg.resp.InitPayload) 151 err = sub.Next(&msg.resp) 152 require.NoError(t, err) 153 require.Equal(t, "Authorization = \"Bearer of the curse\"", msg.resp.InitPayload) 154 err = sub.Next(&msg.resp) 155 require.NoError(t, err) 156 require.Equal(t, "number = 32", msg.resp.InitPayload) 157 err = sub.Next(&msg.resp) 158 require.NoError(t, err) 159 require.Equal(t, "strings = []interface {}{\"hello\", \"world\"}", msg.resp.InitPayload) 160 sub.Close() 161 162 // need a little bit of time for goroutines to settle 163 start := time.Now() 164 for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() { 165 time.Sleep(5 * time.Millisecond) 166 } 167 168 require.Equal(t, initialGoroutineCount, runtime.NumGoroutine()) 169 }) 170 171 t.Run("websocket gets errors", func(t *testing.T) { 172 runtime.GC() // ensure no go-routines left from preceding tests 173 initialGoroutineCount := runtime.NumGoroutine() 174 175 sub := c.Websocket(`subscription { errorRequired { id } }`) 176 177 errorTick <- &Error{ID: "ID1234"} 178 179 var msg struct { 180 resp struct { 181 ErrorRequired *struct { 182 Id string 183 } 184 } 185 } 186 187 err := sub.Next(&msg.resp) 188 require.NoError(t, err) 189 require.Equal(t, "ID1234", msg.resp.ErrorRequired.Id) 190 191 errorTick <- nil 192 err = sub.Next(&msg.resp) 193 require.Error(t, err) 194 195 sub.Close() 196 197 // need a little bit of time for goroutines to settle 198 start := time.Now() 199 for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() { 200 time.Sleep(5 * time.Millisecond) 201 } 202 203 require.Equal(t, initialGoroutineCount, runtime.NumGoroutine()) 204 }) 205 }