github.com/HaswinVidanage/gqlgen@v0.8.1-0.20220609041233-69528c1bf712/graphql/context_test.go (about)

     1  package graphql
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  	"github.com/vektah/gqlparser/ast"
    11  )
    12  
    13  func TestRequestContext_GetErrors(t *testing.T) {
    14  	c := &RequestContext{
    15  		ErrorPresenter: DefaultErrorPresenter,
    16  	}
    17  
    18  	ctx := context.Background()
    19  
    20  	root := &ResolverContext{
    21  		Field: CollectedField{
    22  			Field: &ast.Field{
    23  				Alias: "foo",
    24  			},
    25  		},
    26  	}
    27  	ctx = WithResolverContext(ctx, root)
    28  	c.Error(ctx, errors.New("foo1"))
    29  	c.Error(ctx, errors.New("foo2"))
    30  
    31  	index := 1
    32  	child := &ResolverContext{
    33  		Parent: root,
    34  		Index:  &index,
    35  	}
    36  	ctx = WithResolverContext(ctx, child)
    37  	c.Error(ctx, errors.New("bar"))
    38  
    39  	specs := []struct {
    40  		Name     string
    41  		RCtx     *ResolverContext
    42  		Messages []string
    43  	}{
    44  		{
    45  			Name:     "with root ResolverContext",
    46  			RCtx:     root,
    47  			Messages: []string{"foo1", "foo2"},
    48  		},
    49  		{
    50  			Name:     "with child ResolverContext",
    51  			RCtx:     child,
    52  			Messages: []string{"bar"},
    53  		},
    54  	}
    55  
    56  	for _, spec := range specs {
    57  		t.Run(spec.Name, func(t *testing.T) {
    58  			errList := c.GetErrors(spec.RCtx)
    59  			if assert.Equal(t, len(spec.Messages), len(errList)) {
    60  				for idx, err := range errList {
    61  					assert.Equal(t, spec.Messages[idx], err.Message)
    62  				}
    63  			}
    64  		})
    65  	}
    66  }
    67  
    68  func TestGetRequestContext(t *testing.T) {
    69  	require.Nil(t, GetRequestContext(context.Background()))
    70  
    71  	rc := &RequestContext{}
    72  	require.Equal(t, rc, GetRequestContext(WithRequestContext(context.Background(), rc)))
    73  }
    74  
    75  func TestGetResolverContext(t *testing.T) {
    76  	require.Nil(t, GetResolverContext(context.Background()))
    77  
    78  	rc := &ResolverContext{}
    79  	require.Equal(t, rc, GetResolverContext(WithResolverContext(context.Background(), rc)))
    80  }
    81  
    82  func testContext(sel ast.SelectionSet) context.Context {
    83  
    84  	ctx := context.Background()
    85  
    86  	rqCtx := &RequestContext{}
    87  	ctx = WithRequestContext(ctx, rqCtx)
    88  
    89  	root := &ResolverContext{
    90  		Field: CollectedField{
    91  			Selections: sel,
    92  		},
    93  	}
    94  	ctx = WithResolverContext(ctx, root)
    95  
    96  	return ctx
    97  }
    98  
    99  func TestCollectAllFields(t *testing.T) {
   100  	t.Run("collect fields", func(t *testing.T) {
   101  		ctx := testContext(ast.SelectionSet{
   102  			&ast.Field{
   103  				Name: "field",
   104  			},
   105  		})
   106  		s := CollectAllFields(ctx)
   107  		require.Equal(t, []string{"field"}, s)
   108  	})
   109  
   110  	t.Run("unique field names", func(t *testing.T) {
   111  		ctx := testContext(ast.SelectionSet{
   112  			&ast.Field{
   113  				Name: "field",
   114  			},
   115  			&ast.Field{
   116  				Name:  "field",
   117  				Alias: "field alias",
   118  			},
   119  		})
   120  		s := CollectAllFields(ctx)
   121  		require.Equal(t, []string{"field"}, s)
   122  	})
   123  
   124  	t.Run("collect fragments", func(t *testing.T) {
   125  		ctx := testContext(ast.SelectionSet{
   126  			&ast.Field{
   127  				Name: "fieldA",
   128  			},
   129  			&ast.InlineFragment{
   130  				TypeCondition: "ExampleTypeA",
   131  				SelectionSet: ast.SelectionSet{
   132  					&ast.Field{
   133  						Name: "fieldA",
   134  					},
   135  				},
   136  			},
   137  			&ast.InlineFragment{
   138  				TypeCondition: "ExampleTypeB",
   139  				SelectionSet: ast.SelectionSet{
   140  					&ast.Field{
   141  						Name: "fieldB",
   142  					},
   143  				},
   144  			},
   145  		})
   146  		s := CollectAllFields(ctx)
   147  		require.Equal(t, []string{"fieldA", "fieldB"}, s)
   148  	})
   149  }