github.com/geneva/gqlgen@v0.17.7-0.20230801155730-7b9317164836/plugin/federation/federation.go (about)

     1  package federation
     2  
     3  import (
     4  	_ "embed"
     5  	"fmt"
     6  	"sort"
     7  	"strings"
     8  
     9  	"github.com/vektah/gqlparser/v2/ast"
    10  
    11  	"github.com/geneva/gqlgen/codegen"
    12  	"github.com/geneva/gqlgen/codegen/config"
    13  	"github.com/geneva/gqlgen/codegen/templates"
    14  	"github.com/geneva/gqlgen/plugin"
    15  	"github.com/geneva/gqlgen/plugin/federation/fieldset"
    16  )
    17  
// federationTemplate holds the contents of federation.gotpl, embedded into
// the binary at build time. GenerateCode passes it to templates.Render as the
// template for the generated federation file.
//
//go:embed federation.gotpl
var federationTemplate string
    20  
// federation is the plugin state. It is also the data object handed to the
// code generation template in GenerateCode.
type federation struct {
	// Entities is the list of federated entities discovered in the schema.
	// It is populated by setEntities and kept sorted by entity name.
	Entities []*Entity
	// Version is the federation version this plugin targets. New defaults
	// it to 1; MutateConfig and InjectSourceEarly add extra directives
	// when it is 2.
	Version  int
}
    25  
    26  // New returns a federation plugin that injects
    27  // federated directives and types into the schema
    28  func New(version int) plugin.Plugin {
    29  	if version == 0 {
    30  		version = 1
    31  	}
    32  
    33  	return &federation{Version: version}
    34  }
    35  
    36  // Name returns the plugin name
    37  func (f *federation) Name() string {
    38  	return "federation"
    39  }
    40  
    41  // MutateConfig mutates the configuration
    42  func (f *federation) MutateConfig(cfg *config.Config) error {
    43  	builtins := config.TypeMap{
    44  		"_Service": {
    45  			Model: config.StringList{
    46  				"github.com/geneva/gqlgen/plugin/federation/fedruntime.Service",
    47  			},
    48  		},
    49  		"_Entity": {
    50  			Model: config.StringList{
    51  				"github.com/geneva/gqlgen/plugin/federation/fedruntime.Entity",
    52  			},
    53  		},
    54  		"Entity": {
    55  			Model: config.StringList{
    56  				"github.com/geneva/gqlgen/plugin/federation/fedruntime.Entity",
    57  			},
    58  		},
    59  		"_Any": {
    60  			Model: config.StringList{"github.com/geneva/gqlgen/graphql.Map"},
    61  		},
    62  	}
    63  
    64  	for typeName, entry := range builtins {
    65  		if cfg.Models.Exists(typeName) {
    66  			return fmt.Errorf("%v already exists which must be reserved when Federation is enabled", typeName)
    67  		}
    68  		cfg.Models[typeName] = entry
    69  	}
    70  	cfg.Directives["external"] = config.DirectiveConfig{SkipRuntime: true}
    71  	cfg.Directives["requires"] = config.DirectiveConfig{SkipRuntime: true}
    72  	cfg.Directives["provides"] = config.DirectiveConfig{SkipRuntime: true}
    73  	cfg.Directives["key"] = config.DirectiveConfig{SkipRuntime: true}
    74  	cfg.Directives["extends"] = config.DirectiveConfig{SkipRuntime: true}
    75  
    76  	// Federation 2 specific directives
    77  	if f.Version == 2 {
    78  		cfg.Directives["shareable"] = config.DirectiveConfig{SkipRuntime: true}
    79  		cfg.Directives["link"] = config.DirectiveConfig{SkipRuntime: true}
    80  		cfg.Directives["tag"] = config.DirectiveConfig{SkipRuntime: true}
    81  		cfg.Directives["override"] = config.DirectiveConfig{SkipRuntime: true}
    82  		cfg.Directives["inaccessible"] = config.DirectiveConfig{SkipRuntime: true}
    83  		cfg.Directives["entityResolver"] = config.DirectiveConfig{SkipRuntime: true}
    84  		cfg.Directives["interfaceObject"] = config.DirectiveConfig{SkipRuntime: true}
    85  		cfg.Directives["composeDirective"] = config.DirectiveConfig{SkipRuntime: true}
    86  	}
    87  
    88  	return nil
    89  }
    90  
    91  func (f *federation) InjectSourceEarly() *ast.Source {
    92  	input := ``
    93  
    94  	// add version-specific changes on key directive, as well as adding the new directives for federation 2
    95  	if f.Version == 1 {
    96  		input += `
    97  	directive @key(fields: _FieldSet!) repeatable on OBJECT | INTERFACE
    98  	directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
    99  	directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
   100  	directive @extends on OBJECT | INTERFACE
   101  	directive @external on FIELD_DEFINITION
   102  	scalar _Any
   103  	scalar _FieldSet
   104  `
   105  	} else if f.Version == 2 {
   106  		input += `
   107  	directive @composeDirective(name: String!) repeatable on SCHEMA
   108  	directive @extends on OBJECT | INTERFACE
   109  	directive @external on OBJECT | FIELD_DEFINITION
   110  	directive @key(fields: FieldSet!, resolvable: Boolean = true) repeatable on OBJECT | INTERFACE
   111  	directive @inaccessible on
   112  	  | ARGUMENT_DEFINITION
   113  	  | ENUM
   114  	  | ENUM_VALUE
   115  	  | FIELD_DEFINITION
   116  	  | INPUT_FIELD_DEFINITION
   117  	  | INPUT_OBJECT
   118  	  | INTERFACE
   119  	  | OBJECT
   120  	  | SCALAR
   121  	  | UNION
   122  	directive @interfaceObject on OBJECT
   123  	directive @link(import: [String!], url: String!) repeatable on SCHEMA
   124  	directive @override(from: String!) on FIELD_DEFINITION
   125  	directive @provides(fields: FieldSet!) on FIELD_DEFINITION
   126  	directive @requires(fields: FieldSet!) on FIELD_DEFINITION
   127  	directive @shareable repeatable on FIELD_DEFINITION | OBJECT
   128  	directive @tag(name: String!) repeatable on
   129  	  | ARGUMENT_DEFINITION
   130  	  | ENUM
   131  	  | ENUM_VALUE
   132  	  | FIELD_DEFINITION
   133  	  | INPUT_FIELD_DEFINITION
   134  	  | INPUT_OBJECT
   135  	  | INTERFACE
   136  	  | OBJECT
   137  	  | SCALAR
   138  	  | UNION
   139  	scalar _Any
   140  	scalar FieldSet
   141  `
   142  	}
   143  	return &ast.Source{
   144  		Name:    "federation/directives.graphql",
   145  		Input:   input,
   146  		BuiltIn: true,
   147  	}
   148  }
   149  
   150  // InjectSourceLate creates a GraphQL Entity type with all
   151  // the fields that had the @key directive
   152  func (f *federation) InjectSourceLate(schema *ast.Schema) *ast.Source {
   153  	f.setEntities(schema)
   154  
   155  	var entities, resolvers, entityResolverInputDefinitions string
   156  	for _, e := range f.Entities {
   157  
   158  		if e.Def.Kind != ast.Interface {
   159  			if entities != "" {
   160  				entities += " | "
   161  			}
   162  			entities += e.Name
   163  		} else if len(schema.GetPossibleTypes(e.Def)) == 0 {
   164  			fmt.Println(
   165  				"skipping @key field on interface " + e.Def.Name + " as no types implement it",
   166  			)
   167  		}
   168  
   169  		for _, r := range e.Resolvers {
   170  			if e.Multi {
   171  				if entityResolverInputDefinitions != "" {
   172  					entityResolverInputDefinitions += "\n\n"
   173  				}
   174  				entityResolverInputDefinitions += "input " + r.InputTypeName + " {\n"
   175  				for _, keyField := range r.KeyFields {
   176  					entityResolverInputDefinitions += fmt.Sprintf("\t%s: %s\n", keyField.Field.ToGo(), keyField.Definition.Type.String())
   177  				}
   178  				entityResolverInputDefinitions += "}"
   179  				resolvers += fmt.Sprintf("\t%s(reps: [%s!]!): [%s]\n", r.ResolverName, r.InputTypeName, e.Name)
   180  			} else {
   181  				resolverArgs := ""
   182  				for _, keyField := range r.KeyFields {
   183  					resolverArgs += fmt.Sprintf("%s: %s,", keyField.Field.ToGoPrivate(), keyField.Definition.Type.String())
   184  				}
   185  				resolvers += fmt.Sprintf("\t%s(%s): %s!\n", r.ResolverName, resolverArgs, e.Name)
   186  			}
   187  		}
   188  	}
   189  
   190  	var blocks []string
   191  	if entities != "" {
   192  		entities = `# a union of all types that use the @key directive
   193  union _Entity = ` + entities
   194  		blocks = append(blocks, entities)
   195  	}
   196  
   197  	// resolvers can be empty if a service defines only "empty
   198  	// extend" types.  This should be rare.
   199  	if resolvers != "" {
   200  		if entityResolverInputDefinitions != "" {
   201  			blocks = append(blocks, entityResolverInputDefinitions)
   202  		}
   203  		resolvers = `# fake type to build resolver interfaces for users to implement
   204  type Entity {
   205  	` + resolvers + `
   206  }`
   207  		blocks = append(blocks, resolvers)
   208  	}
   209  
   210  	_serviceTypeDef := `type _Service {
   211    sdl: String
   212  }`
   213  	blocks = append(blocks, _serviceTypeDef)
   214  
   215  	var additionalQueryFields string
   216  	// Quote from the Apollo Federation subgraph specification:
   217  	// If no types are annotated with the key directive, then the
   218  	// _Entity union and _entities field should be removed from the schema
   219  	if len(f.Entities) > 0 {
   220  		additionalQueryFields += `  _entities(representations: [_Any!]!): [_Entity]!
   221  `
   222  	}
   223  	// _service field is required in any case
   224  	additionalQueryFields += `  _service: _Service!`
   225  
   226  	extendTypeQueryDef := `extend type ` + schema.Query.Name + ` {
   227  ` + additionalQueryFields + `
   228  }`
   229  	blocks = append(blocks, extendTypeQueryDef)
   230  
   231  	return &ast.Source{
   232  		Name:    "federation/entity.graphql",
   233  		BuiltIn: true,
   234  		Input:   "\n" + strings.Join(blocks, "\n\n") + "\n",
   235  	}
   236  }
   237  
   238  func (f *federation) GenerateCode(data *codegen.Data) error {
   239  	if len(f.Entities) > 0 {
   240  		if data.Objects.ByName("Entity") != nil {
   241  			data.Objects.ByName("Entity").Root = true
   242  		}
   243  		for _, e := range f.Entities {
   244  			obj := data.Objects.ByName(e.Def.Name)
   245  
   246  			if e.Def.Kind == ast.Interface {
   247  				if len(data.Interfaces[e.Def.Name].Implementors) == 0 {
   248  					fmt.Println(
   249  						"skipping @key field on interface " + e.Def.Name + " as no types implement it",
   250  					)
   251  					continue
   252  				}
   253  				obj = data.Objects.ByName(data.Interfaces[e.Def.Name].Implementors[0].Name)
   254  			}
   255  
   256  			for _, r := range e.Resolvers {
   257  				// fill in types for key fields
   258  				//
   259  				for _, keyField := range r.KeyFields {
   260  					if len(keyField.Field) == 0 {
   261  						fmt.Println(
   262  							"skipping @key field " + keyField.Definition.Name + " in " + r.ResolverName + " in " + e.Def.Name,
   263  						)
   264  						continue
   265  					}
   266  					cgField := keyField.Field.TypeReference(obj, data.Objects)
   267  					keyField.Type = cgField.TypeReference
   268  				}
   269  			}
   270  
   271  			// fill in types for requires fields
   272  			//
   273  			for _, reqField := range e.Requires {
   274  				if len(reqField.Field) == 0 {
   275  					fmt.Println("skipping @requires field " + reqField.Name + " in " + e.Def.Name)
   276  					continue
   277  				}
   278  				cgField := reqField.Field.TypeReference(obj, data.Objects)
   279  				reqField.Type = cgField.TypeReference
   280  			}
   281  		}
   282  	}
   283  
   284  	// fill in types for resolver inputs
   285  	//
   286  	for _, entity := range f.Entities {
   287  		if !entity.Multi {
   288  			continue
   289  		}
   290  
   291  		for _, resolver := range entity.Resolvers {
   292  			obj := data.Inputs.ByName(resolver.InputTypeName)
   293  			if obj == nil {
   294  				return fmt.Errorf("input object %s not found", resolver.InputTypeName)
   295  			}
   296  
   297  			resolver.InputType = obj.Type
   298  		}
   299  	}
   300  
   301  	return templates.Render(templates.Options{
   302  		PackageName:     data.Config.Federation.Package,
   303  		Filename:        data.Config.Federation.Filename,
   304  		Data:            f,
   305  		GeneratedHeader: true,
   306  		Packages:        data.Config.Packages,
   307  		Template:        federationTemplate,
   308  	})
   309  }
   310  
   311  func (f *federation) setEntities(schema *ast.Schema) {
   312  	for _, schemaType := range schema.Types {
   313  		keys, ok := isFederatedEntity(schemaType)
   314  		if !ok {
   315  			continue
   316  		}
   317  
   318  		if (schemaType.Kind == ast.Interface) && (len(schema.GetPossibleTypes(schemaType)) == 0) {
   319  			fmt.Printf("@key directive found on unused \"interface %s\". Will be ignored.\n", schemaType.Name)
   320  			continue
   321  		}
   322  
   323  		e := &Entity{
   324  			Name:      schemaType.Name,
   325  			Def:       schemaType,
   326  			Resolvers: nil,
   327  			Requires:  nil,
   328  		}
   329  
   330  		// Let's process custom entity resolver settings.
   331  		dir := schemaType.Directives.ForName("entityResolver")
   332  		if dir != nil {
   333  			if dirArg := dir.Arguments.ForName("multi"); dirArg != nil {
   334  				if dirVal, err := dirArg.Value.Value(nil); err == nil {
   335  					e.Multi = dirVal.(bool)
   336  				}
   337  			}
   338  		}
   339  
   340  		// If our schema has a field with a type defined in
   341  		// another service, then we need to define an "empty
   342  		// extend" of that type in this service, so this service
   343  		// knows what the type is like.  But the graphql-server
   344  		// will never ask us to actually resolve this "empty
   345  		// extend", so we don't require a resolver function for
   346  		// it.  (Well, it will never ask in practice; it's
   347  		// unclear whether the spec guarantees this.  See
   348  		// https://github.com/apollographql/apollo-server/issues/3852
   349  		// ).  Example:
   350  		//    type MyType {
   351  		//       myvar: TypeDefinedInOtherService
   352  		//    }
   353  		//    // Federation needs this type, but
   354  		//    // it doesn't need a resolver for it!
   355  		//    extend TypeDefinedInOtherService @key(fields: "id") {
   356  		//       id: ID @external
   357  		//    }
   358  		if !e.allFieldsAreExternal(f.Version) {
   359  			for _, dir := range keys {
   360  				if len(dir.Arguments) > 2 {
   361  					panic("More than two arguments provided for @key declaration.")
   362  				}
   363  				var arg *ast.Argument
   364  
   365  				// since keys are able to now have multiple arguments, we need to check both possible for a possible @key(fields="" fields="")
   366  				for _, a := range dir.Arguments {
   367  					if a.Name == "fields" {
   368  						if arg != nil {
   369  							panic("More than one `fields` provided for @key declaration.")
   370  						}
   371  						arg = a
   372  					}
   373  				}
   374  
   375  				keyFieldSet := fieldset.New(arg.Value.Raw, nil)
   376  
   377  				keyFields := make([]*KeyField, len(keyFieldSet))
   378  				resolverFields := []string{}
   379  				for i, field := range keyFieldSet {
   380  					def := field.FieldDefinition(schemaType, schema)
   381  
   382  					if def == nil {
   383  						panic(fmt.Sprintf("no field for %v", field))
   384  					}
   385  
   386  					keyFields[i] = &KeyField{Definition: def, Field: field}
   387  					resolverFields = append(resolverFields, keyFields[i].Field.ToGo())
   388  				}
   389  
   390  				resolverFieldsToGo := schemaType.Name + "By" + strings.Join(resolverFields, "And")
   391  				var resolverName string
   392  				if e.Multi {
   393  					resolverFieldsToGo += "s" // Pluralize for better API readability
   394  					resolverName = fmt.Sprintf("findMany%s", resolverFieldsToGo)
   395  				} else {
   396  					resolverName = fmt.Sprintf("find%s", resolverFieldsToGo)
   397  				}
   398  
   399  				e.Resolvers = append(e.Resolvers, &EntityResolver{
   400  					ResolverName:  resolverName,
   401  					KeyFields:     keyFields,
   402  					InputTypeName: resolverFieldsToGo + "Input",
   403  				})
   404  			}
   405  
   406  			e.Requires = []*Requires{}
   407  			for _, f := range schemaType.Fields {
   408  				dir := f.Directives.ForName("requires")
   409  				if dir == nil {
   410  					continue
   411  				}
   412  				if len(dir.Arguments) != 1 || dir.Arguments[0].Name != "fields" {
   413  					panic("Exactly one `fields` argument needed for @requires declaration.")
   414  				}
   415  				requiresFieldSet := fieldset.New(dir.Arguments[0].Value.Raw, nil)
   416  				for _, field := range requiresFieldSet {
   417  					e.Requires = append(e.Requires, &Requires{
   418  						Name:  field.ToGoPrivate(),
   419  						Field: field,
   420  					})
   421  				}
   422  			}
   423  		}
   424  		f.Entities = append(f.Entities, e)
   425  	}
   426  
   427  	// make sure order remains stable across multiple builds
   428  	sort.Slice(f.Entities, func(i, j int) bool {
   429  		return f.Entities[i].Name < f.Entities[j].Name
   430  	})
   431  }
   432  
   433  func isFederatedEntity(schemaType *ast.Definition) ([]*ast.Directive, bool) {
   434  	switch schemaType.Kind {
   435  	case ast.Object:
   436  		keys := schemaType.Directives.ForNames("key")
   437  		if len(keys) > 0 {
   438  			return keys, true
   439  		}
   440  	case ast.Interface:
   441  		keys := schemaType.Directives.ForNames("key")
   442  		if len(keys) > 0 {
   443  			return keys, true
   444  		}
   445  
   446  		// TODO: support @extends for interfaces
   447  		if dir := schemaType.Directives.ForName("extends"); dir != nil {
   448  			panic(
   449  				fmt.Sprintf(
   450  					"@extends directive is not currently supported for interfaces, use \"extend interface %s\" instead.",
   451  					schemaType.Name,
   452  				))
   453  		}
   454  	default:
   455  		// ignore
   456  	}
   457  	return nil, false
   458  }