github.com/99designs/gqlgen@v0.17.45/codegen/testserver/singlefile/middleware_test.go (about)

     1  package singlefile
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/99designs/gqlgen/client"
    12  	"github.com/99designs/gqlgen/graphql"
    13  	"github.com/99designs/gqlgen/graphql/handler"
    14  )
    15  
    16  func TestMiddleware(t *testing.T) {
    17  	resolvers := &Stub{}
    18  	resolvers.QueryResolver.ErrorBubble = func(ctx context.Context) (i *Error, e error) {
    19  		return &Error{ID: "E1234"}, nil
    20  	}
    21  
    22  	resolvers.QueryResolver.User = func(ctx context.Context, id int) (user *User, e error) {
    23  		return &User{ID: 1}, nil
    24  	}
    25  
    26  	resolvers.UserResolver.Friends = func(ctx context.Context, obj *User) (users []*User, e error) {
    27  		return []*User{{ID: 1}}, nil
    28  	}
    29  
    30  	resolvers.UserResolver.Pets = func(ctx context.Context, obj *User, limit *int) ([]*Pet, error) {
    31  		return []*Pet{{ID: 10}}, nil
    32  	}
    33  
    34  	resolvers.PetResolver.Friends = func(ctx context.Context, obj *Pet, limit *int) ([]*Pet, error) {
    35  		return []*Pet{}, nil
    36  	}
    37  
    38  	resolvers.QueryResolver.ModelMethods = func(ctx context.Context) (methods *ModelMethods, e error) {
    39  		return &ModelMethods{}, nil
    40  	}
    41  
    42  	var mu sync.Mutex
    43  	areMethods := map[string]bool{}
    44  	areResolvers := map[string]bool{}
    45  	srv := handler.NewDefaultServer(
    46  		NewExecutableSchema(Config{Resolvers: resolvers}),
    47  	)
    48  	srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    49  		path, _ := ctx.Value(ckey("path")).([]int)
    50  		return next(context.WithValue(ctx, ckey("path"), append(path, 1)))
    51  	})
    52  
    53  	srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    54  		path, _ := ctx.Value(ckey("path")).([]int)
    55  		return next(context.WithValue(ctx, ckey("path"), append(path, 2)))
    56  	})
    57  
    58  	srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    59  		fc := graphql.GetFieldContext(ctx)
    60  		mu.Lock()
    61  		areMethods[fc.Field.Name] = fc.IsMethod
    62  		areResolvers[fc.Field.Name] = fc.IsResolver
    63  		mu.Unlock()
    64  		return next(ctx)
    65  	})
    66  
    67  	srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    68  		fc := graphql.GetFieldContext(ctx)
    69  		if fc.Field.Name != "user" {
    70  			return next(ctx)
    71  		}
    72  		opCtx := graphql.GetOperationContext(ctx)
    73  		collected := graphql.CollectFields(opCtx, fc.Field.Selections, []string{"User"})
    74  		require.Len(t, collected, 3)
    75  		require.Equal(t, "pets", collected[2].Name)
    76  		child, err := fc.Child(ctx, collected[2])
    77  		require.NoError(t, err)
    78  		require.Equal(t, 2, *child.Args["limit"].(*int))
    79  		collected = graphql.CollectFields(opCtx, child.Field.Selections, []string{"Pet"})
    80  		require.Len(t, collected, 2)
    81  		require.Equal(t, "friends", collected[1].Name)
    82  		child, err = child.Child(ctx, collected[1])
    83  		require.NoError(t, err)
    84  		require.Equal(t, 10, *child.Args["limit"].(*int))
    85  		return next(ctx)
    86  	})
    87  
    88  	c := client.New(srv)
    89  
    90  	var resp struct {
    91  		User struct {
    92  			ID      int
    93  			Friends []struct {
    94  				ID int
    95  			}
    96  			Pets []struct {
    97  				ID      int
    98  				Friends []struct {
    99  					ID int
   100  				}
   101  			}
   102  		}
   103  		ModelMethods struct {
   104  			NoContext bool
   105  		}
   106  	}
   107  
   108  	called := false
   109  	resolvers.UserResolver.Friends = func(ctx context.Context, obj *User) ([]*User, error) {
   110  		assert.Equal(t, []int{1, 2, 1, 2}, ctx.Value(ckey("path")))
   111  		called = true
   112  		return []*User{}, nil
   113  	}
   114  
   115  	err := c.Post(`query {
   116  		user(id: 1) {
   117  			id,
   118  			friends {
   119  				id
   120  			}
   121  			pets (limit: 2) {
   122  				id
   123  				friends(limit: 10) {
   124  					id
   125  				}
   126  			}
   127  		}
   128  		modelMethods {
   129  			noContext
   130  		}
   131  	}`, &resp)
   132  
   133  	assert.Equal(t, map[string]bool{
   134  		"user":         true,
   135  		"pets":         true,
   136  		"id":           false,
   137  		"friends":      true,
   138  		"modelMethods": true,
   139  		"noContext":    true,
   140  	}, areMethods)
   141  	assert.Equal(t, map[string]bool{
   142  		"user":         true,
   143  		"pets":         true,
   144  		"id":           false,
   145  		"friends":      true,
   146  		"modelMethods": true,
   147  		"noContext":    false,
   148  	}, areResolvers)
   149  
   150  	require.NoError(t, err)
   151  	require.True(t, called)
   152  }