github.com/HaswinVidanage/gqlgen@v0.8.1-0.20220609041233-69528c1bf712/graphql/context_test.go (about) 1 package graphql 2 3 import ( 4 "context" 5 "errors" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 "github.com/stretchr/testify/require" 10 "github.com/vektah/gqlparser/ast" 11 ) 12 13 func TestRequestContext_GetErrors(t *testing.T) { 14 c := &RequestContext{ 15 ErrorPresenter: DefaultErrorPresenter, 16 } 17 18 ctx := context.Background() 19 20 root := &ResolverContext{ 21 Field: CollectedField{ 22 Field: &ast.Field{ 23 Alias: "foo", 24 }, 25 }, 26 } 27 ctx = WithResolverContext(ctx, root) 28 c.Error(ctx, errors.New("foo1")) 29 c.Error(ctx, errors.New("foo2")) 30 31 index := 1 32 child := &ResolverContext{ 33 Parent: root, 34 Index: &index, 35 } 36 ctx = WithResolverContext(ctx, child) 37 c.Error(ctx, errors.New("bar")) 38 39 specs := []struct { 40 Name string 41 RCtx *ResolverContext 42 Messages []string 43 }{ 44 { 45 Name: "with root ResolverContext", 46 RCtx: root, 47 Messages: []string{"foo1", "foo2"}, 48 }, 49 { 50 Name: "with child ResolverContext", 51 RCtx: child, 52 Messages: []string{"bar"}, 53 }, 54 } 55 56 for _, spec := range specs { 57 t.Run(spec.Name, func(t *testing.T) { 58 errList := c.GetErrors(spec.RCtx) 59 if assert.Equal(t, len(spec.Messages), len(errList)) { 60 for idx, err := range errList { 61 assert.Equal(t, spec.Messages[idx], err.Message) 62 } 63 } 64 }) 65 } 66 } 67 68 func TestGetRequestContext(t *testing.T) { 69 require.Nil(t, GetRequestContext(context.Background())) 70 71 rc := &RequestContext{} 72 require.Equal(t, rc, GetRequestContext(WithRequestContext(context.Background(), rc))) 73 } 74 75 func TestGetResolverContext(t *testing.T) { 76 require.Nil(t, GetResolverContext(context.Background())) 77 78 rc := &ResolverContext{} 79 require.Equal(t, rc, GetResolverContext(WithResolverContext(context.Background(), rc))) 80 } 81 82 func testContext(sel ast.SelectionSet) context.Context { 83 84 ctx := context.Background() 85 86 rqCtx := &RequestContext{} 87 ctx = WithRequestContext(ctx, rqCtx) 88 89 root := &ResolverContext{ 90 Field: CollectedField{ 91 Selections: sel, 92 }, 93 } 94 ctx = WithResolverContext(ctx, root) 95 96 return ctx 97 } 98 99 func TestCollectAllFields(t *testing.T) { 100 t.Run("collect fields", func(t *testing.T) { 101 ctx := testContext(ast.SelectionSet{ 102 &ast.Field{ 103 Name: "field", 104 }, 105 }) 106 s := CollectAllFields(ctx) 107 require.Equal(t, []string{"field"}, s) 108 }) 109 110 t.Run("unique field names", func(t *testing.T) { 111 ctx := testContext(ast.SelectionSet{ 112 &ast.Field{ 113 Name: "field", 114 }, 115 &ast.Field{ 116 Name: "field", 117 Alias: "field alias", 118 }, 119 }) 120 s := CollectAllFields(ctx) 121 require.Equal(t, []string{"field"}, s) 122 }) 123 124 t.Run("collect fragments", func(t *testing.T) { 125 ctx := testContext(ast.SelectionSet{ 126 &ast.Field{ 127 Name: "fieldA", 128 }, 129 &ast.InlineFragment{ 130 TypeCondition: "ExampleTypeA", 131 SelectionSet: ast.SelectionSet{ 132 &ast.Field{ 133 Name: "fieldA", 134 }, 135 }, 136 }, 137 &ast.InlineFragment{ 138 TypeCondition: "ExampleTypeB", 139 SelectionSet: ast.SelectionSet{ 140 &ast.Field{ 141 Name: "fieldB", 142 }, 143 }, 144 }, 145 }) 146 s := CollectAllFields(ctx) 147 require.Equal(t, []string{"fieldA", "fieldB"}, s) 148 }) 149 }