github.com/luciferinlove/gqlgen@v0.17.16-bzc.1/codegen/testserver/followschema/subscription_test.go (about) 1 package followschema 2 3 import ( 4 "context" 5 "fmt" 6 "runtime" 7 "sort" 8 "testing" 9 "time" 10 11 "github.com/luciferinlove/gqlgen/graphql/handler/transport" 12 13 "github.com/luciferinlove/gqlgen/client" 14 "github.com/luciferinlove/gqlgen/graphql" 15 "github.com/luciferinlove/gqlgen/graphql/handler" 16 "github.com/stretchr/testify/require" 17 ) 18 19 func TestSubscriptions(t *testing.T) { 20 tick := make(chan string, 1) 21 22 resolvers := &Stub{} 23 24 resolvers.SubscriptionResolver.InitPayload = func(ctx context.Context) (strings <-chan string, e error) { 25 payload := transport.GetInitPayload(ctx) 26 channel := make(chan string, len(payload)+1) 27 28 go func() { 29 <-ctx.Done() 30 close(channel) 31 }() 32 33 // Test the helper function separately 34 auth := payload.Authorization() 35 if auth != "" { 36 channel <- "AUTH:" + auth 37 } else { 38 channel <- "AUTH:NONE" 39 } 40 41 // Send them over the channel in alphabetic order 42 keys := make([]string, 0, len(payload)) 43 for key := range payload { 44 keys = append(keys, key) 45 } 46 sort.Strings(keys) 47 for _, key := range keys { 48 channel <- fmt.Sprintf("%s = %#+v", key, payload[key]) 49 } 50 51 return channel, nil 52 } 53 54 errorTick := make(chan *Error, 1) 55 resolvers.SubscriptionResolver.ErrorRequired = func(ctx context.Context) (<-chan *Error, error) { 56 res := make(chan *Error, 1) 57 58 go func() { 59 for { 60 select { 61 case t := <-errorTick: 62 res <- t 63 case <-ctx.Done(): 64 close(res) 65 return 66 } 67 } 68 }() 69 return res, nil 70 } 71 72 resolvers.SubscriptionResolver.Updated = func(ctx context.Context) (<-chan string, error) { 73 res := make(chan string, 1) 74 75 go func() { 76 for { 77 select { 78 case t := <-tick: 79 res <- t 80 case <-ctx.Done(): 81 close(res) 82 return 83 } 84 } 85 }() 86 return res, nil 87 } 88 89 srv := handler.NewDefaultServer( 90 NewExecutableSchema(Config{Resolvers: resolvers}), 91 ) 92 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 93 path, _ := ctx.Value(ckey("path")).([]int) 94 return next(context.WithValue(ctx, ckey("path"), append(path, 1))) 95 }) 96 97 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { 98 path, _ := ctx.Value(ckey("path")).([]int) 99 return next(context.WithValue(ctx, ckey("path"), append(path, 2))) 100 }) 101 102 c := client.New(srv) 103 104 t.Run("wont leak goroutines", func(t *testing.T) { 105 runtime.GC() // ensure no go-routines left from preceding tests 106 initialGoroutineCount := runtime.NumGoroutine() 107 108 sub := c.Websocket(`subscription { updated }`) 109 110 tick <- "message" 111 112 var msg struct { 113 resp struct { 114 Updated string 115 } 116 } 117 118 err := sub.Next(&msg.resp) 119 require.NoError(t, err) 120 require.Equal(t, "message", msg.resp.Updated) 121 sub.Close() 122 123 // need a little bit of time for goroutines to settle 124 start := time.Now() 125 for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() { 126 time.Sleep(5 * time.Millisecond) 127 } 128 129 require.Equal(t, initialGoroutineCount, runtime.NumGoroutine()) 130 }) 131 132 t.Run("will parse init payload", func(t *testing.T) { 133 sub := c.WebsocketWithPayload(`subscription { initPayload }`, map[string]interface{}{ 134 "Authorization": "Bearer of the curse", 135 "number": 32, 136 "strings": []string{"hello", "world"}, 137 }) 138 139 var msg struct { 140 resp struct { 141 InitPayload string 142 } 143 } 144 145 err := sub.Next(&msg.resp) 146 require.NoError(t, err) 147 require.Equal(t, "AUTH:Bearer of the curse", msg.resp.InitPayload) 148 err = sub.Next(&msg.resp) 149 require.NoError(t, err) 150 require.Equal(t, "Authorization = \"Bearer of the curse\"", msg.resp.InitPayload) 151 err = sub.Next(&msg.resp) 152 require.NoError(t, err) 153 require.Equal(t, "number = 32", msg.resp.InitPayload) 154 err = sub.Next(&msg.resp) 155 require.NoError(t, err) 156 require.Equal(t, "strings = []interface {}{\"hello\", \"world\"}", msg.resp.InitPayload) 157 sub.Close() 158 }) 159 160 t.Run("websocket gets errors", func(t *testing.T) { 161 runtime.GC() // ensure no go-routines left from preceding tests 162 initialGoroutineCount := runtime.NumGoroutine() 163 164 sub := c.Websocket(`subscription { errorRequired { id } }`) 165 166 errorTick <- &Error{ID: "ID1234"} 167 168 var msg struct { 169 resp struct { 170 ErrorRequired *struct { 171 Id string 172 } 173 } 174 } 175 176 err := sub.Next(&msg.resp) 177 require.NoError(t, err) 178 require.Equal(t, "ID1234", msg.resp.ErrorRequired.Id) 179 180 errorTick <- nil 181 err = sub.Next(&msg.resp) 182 require.Error(t, err) 183 184 sub.Close() 185 186 // need a little bit of time for goroutines to settle 187 start := time.Now() 188 for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() { 189 time.Sleep(5 * time.Millisecond) 190 } 191 192 require.Equal(t, initialGoroutineCount, runtime.NumGoroutine()) 193 }) 194 }