github.com/niko0xdev/gqlgen@v0.17.55-0.20240120102243-2ecff98c3e37/plugin/federation/federation.go (about)

     1  package federation
     2  
     3  import (
     4  	_ "embed"
     5  	"fmt"
     6  	"sort"
     7  	"strings"
     8  
     9  	"github.com/vektah/gqlparser/v2/ast"
    10  
    11  	"github.com/niko0xdev/gqlgen/codegen"
    12  	"github.com/niko0xdev/gqlgen/codegen/config"
    13  	"github.com/niko0xdev/gqlgen/codegen/templates"
    14  	"github.com/niko0xdev/gqlgen/plugin"
    15  	"github.com/niko0xdev/gqlgen/plugin/federation/fieldset"
    16  )
    17  
// federationTemplate holds the contents of federation.gotpl, embedded at build
// time. GenerateCode renders it with the federation plugin value as the
// template data.
//
//go:embed federation.gotpl
var federationTemplate string
    20  
// federation is the gqlgen plugin that adds Apollo Federation support.
//
// Entities is populated by InjectSourceLate (via setEntities). It is then
// read by GenerateCode and passed, as part of the plugin value, to the
// federation template. Version selects the federation directive set; the
// code in this file handles 1 and 2, and New maps 0 to 1.
type federation struct {
	Entities []*Entity
	Version  int
}
    25  
    26  // New returns a federation plugin that injects
    27  // federated directives and types into the schema
    28  func New(version int) plugin.Plugin {
    29  	if version == 0 {
    30  		version = 1
    31  	}
    32  
    33  	return &federation{Version: version}
    34  }
    35  
    36  // Name returns the plugin name
    37  func (f *federation) Name() string {
    38  	return "federation"
    39  }
    40  
    41  // MutateConfig mutates the configuration
    42  func (f *federation) MutateConfig(cfg *config.Config) error {
    43  	builtins := config.TypeMap{
    44  		"_Service": {
    45  			Model: config.StringList{
    46  				"github.com/niko0xdev/gqlgen/plugin/federation/fedruntime.Service",
    47  			},
    48  		},
    49  		"_Entity": {
    50  			Model: config.StringList{
    51  				"github.com/niko0xdev/gqlgen/plugin/federation/fedruntime.Entity",
    52  			},
    53  		},
    54  		"Entity": {
    55  			Model: config.StringList{
    56  				"github.com/niko0xdev/gqlgen/plugin/federation/fedruntime.Entity",
    57  			},
    58  		},
    59  		"_Any": {
    60  			Model: config.StringList{"github.com/niko0xdev/gqlgen/graphql.Map"},
    61  		},
    62  		"federation__Scope": {
    63  			Model: config.StringList{"github.com/niko0xdev/gqlgen/graphql.String"},
    64  		},
    65  	}
    66  
    67  	for typeName, entry := range builtins {
    68  		if cfg.Models.Exists(typeName) {
    69  			return fmt.Errorf("%v already exists which must be reserved when Federation is enabled", typeName)
    70  		}
    71  		cfg.Models[typeName] = entry
    72  	}
    73  	cfg.Directives["external"] = config.DirectiveConfig{SkipRuntime: true}
    74  	cfg.Directives["requires"] = config.DirectiveConfig{SkipRuntime: true}
    75  	cfg.Directives["provides"] = config.DirectiveConfig{SkipRuntime: true}
    76  	cfg.Directives["key"] = config.DirectiveConfig{SkipRuntime: true}
    77  	cfg.Directives["extends"] = config.DirectiveConfig{SkipRuntime: true}
    78  
    79  	// Federation 2 specific directives
    80  	if f.Version == 2 {
    81  		cfg.Directives["shareable"] = config.DirectiveConfig{SkipRuntime: true}
    82  		cfg.Directives["link"] = config.DirectiveConfig{SkipRuntime: true}
    83  		cfg.Directives["tag"] = config.DirectiveConfig{SkipRuntime: true}
    84  		cfg.Directives["override"] = config.DirectiveConfig{SkipRuntime: true}
    85  		cfg.Directives["inaccessible"] = config.DirectiveConfig{SkipRuntime: true}
    86  		cfg.Directives["authenticated"] = config.DirectiveConfig{SkipRuntime: true}
    87  		cfg.Directives["requiresScopes"] = config.DirectiveConfig{SkipRuntime: true}
    88  		cfg.Directives["interfaceObject"] = config.DirectiveConfig{SkipRuntime: true}
    89  		cfg.Directives["composeDirective"] = config.DirectiveConfig{SkipRuntime: true}
    90  	}
    91  
    92  	return nil
    93  }
    94  
    95  func (f *federation) InjectSourceEarly() *ast.Source {
    96  	input := ``
    97  
    98  	// add version-specific changes on key directive, as well as adding the new directives for federation 2
    99  	if f.Version == 1 {
   100  		input += `
   101  	directive @key(fields: _FieldSet!) repeatable on OBJECT | INTERFACE
   102  	directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
   103  	directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
   104  	directive @extends on OBJECT | INTERFACE
   105  	directive @external on FIELD_DEFINITION
   106  	scalar _Any
   107  	scalar _FieldSet
   108  `
   109  	} else if f.Version == 2 {
   110  		input += `
   111  	directive @authenticated on FIELD_DEFINITION | OBJECT | INTERFACE | SCALAR | ENUM
   112  	directive @composeDirective(name: String!) repeatable on SCHEMA
   113  	directive @extends on OBJECT | INTERFACE
   114  	directive @external on OBJECT | FIELD_DEFINITION
   115  	directive @key(fields: FieldSet!, resolvable: Boolean = true) repeatable on OBJECT | INTERFACE
   116  	directive @inaccessible on
   117  	  | ARGUMENT_DEFINITION
   118  	  | ENUM
   119  	  | ENUM_VALUE
   120  	  | FIELD_DEFINITION
   121  	  | INPUT_FIELD_DEFINITION
   122  	  | INPUT_OBJECT
   123  	  | INTERFACE
   124  	  | OBJECT
   125  	  | SCALAR
   126  	  | UNION
   127  	directive @interfaceObject on OBJECT
   128  	directive @link(import: [String!], url: String!) repeatable on SCHEMA
   129  	directive @override(from: String!) on FIELD_DEFINITION
   130  	directive @provides(fields: FieldSet!) on FIELD_DEFINITION
   131  	directive @requires(fields: FieldSet!) on FIELD_DEFINITION
   132  	directive @requiresScopes(scopes: [[federation__Scope!]!]!) on 
   133  	  | FIELD_DEFINITION
   134  	  | OBJECT
   135  	  | INTERFACE
   136  	  | SCALAR
   137  	  | ENUM
   138  	directive @shareable repeatable on FIELD_DEFINITION | OBJECT
   139  	directive @tag(name: String!) repeatable on
   140  	  | ARGUMENT_DEFINITION
   141  	  | ENUM
   142  	  | ENUM_VALUE
   143  	  | FIELD_DEFINITION
   144  	  | INPUT_FIELD_DEFINITION
   145  	  | INPUT_OBJECT
   146  	  | INTERFACE
   147  	  | OBJECT
   148  	  | SCALAR
   149  	  | UNION
   150  	scalar _Any
   151  	scalar FieldSet
   152  	scalar federation__Scope
   153  `
   154  	}
   155  	return &ast.Source{
   156  		Name:    "federation/directives.graphql",
   157  		Input:   input,
   158  		BuiltIn: true,
   159  	}
   160  }
   161  
   162  // InjectSourceLate creates a GraphQL Entity type with all
   163  // the fields that had the @key directive
   164  func (f *federation) InjectSourceLate(schema *ast.Schema) *ast.Source {
   165  	f.setEntities(schema)
   166  
   167  	var entities, resolvers, entityResolverInputDefinitions string
   168  	for _, e := range f.Entities {
   169  
   170  		if e.Def.Kind != ast.Interface {
   171  			if entities != "" {
   172  				entities += " | "
   173  			}
   174  			entities += e.Name
   175  		} else if len(schema.GetPossibleTypes(e.Def)) == 0 {
   176  			fmt.Println(
   177  				"skipping @key field on interface " + e.Def.Name + " as no types implement it",
   178  			)
   179  		}
   180  
   181  		for _, r := range e.Resolvers {
   182  			if e.Multi {
   183  				if entityResolverInputDefinitions != "" {
   184  					entityResolverInputDefinitions += "\n\n"
   185  				}
   186  				entityResolverInputDefinitions += "input " + r.InputTypeName + " {\n"
   187  				for _, keyField := range r.KeyFields {
   188  					entityResolverInputDefinitions += fmt.Sprintf(
   189  						"\t%s: %s\n",
   190  						keyField.Field.ToGo(),
   191  						keyField.Definition.Type.String(),
   192  					)
   193  				}
   194  				entityResolverInputDefinitions += "}"
   195  				resolvers += fmt.Sprintf("\t%s(reps: [%s]!): [%s]\n", r.ResolverName, r.InputTypeName, e.Name)
   196  			} else {
   197  				resolverArgs := ""
   198  				for _, keyField := range r.KeyFields {
   199  					resolverArgs += fmt.Sprintf("%s: %s,", keyField.Field.ToGoPrivate(), keyField.Definition.Type.String())
   200  				}
   201  				resolvers += fmt.Sprintf("\t%s(%s): %s!\n", r.ResolverName, resolverArgs, e.Name)
   202  			}
   203  		}
   204  	}
   205  
   206  	var blocks []string
   207  	if entities != "" {
   208  		entities = `# a union of all types that use the @key directive
   209  union _Entity = ` + entities
   210  		blocks = append(blocks, entities)
   211  	}
   212  
   213  	// resolvers can be empty if a service defines only "empty
   214  	// extend" types.  This should be rare.
   215  	if resolvers != "" {
   216  		if entityResolverInputDefinitions != "" {
   217  			blocks = append(blocks, entityResolverInputDefinitions)
   218  		}
   219  		resolvers = `# fake type to build resolver interfaces for users to implement
   220  type Entity {
   221  	` + resolvers + `
   222  }`
   223  		blocks = append(blocks, resolvers)
   224  	}
   225  
   226  	_serviceTypeDef := `type _Service {
   227    sdl: String
   228  }`
   229  	blocks = append(blocks, _serviceTypeDef)
   230  
   231  	var additionalQueryFields string
   232  	// Quote from the Apollo Federation subgraph specification:
   233  	// If no types are annotated with the key directive, then the
   234  	// _Entity union and _entities field should be removed from the schema
   235  	if len(f.Entities) > 0 {
   236  		additionalQueryFields += `  _entities(representations: [_Any!]!): [_Entity]!
   237  `
   238  	}
   239  	// _service field is required in any case
   240  	additionalQueryFields += `  _service: _Service!`
   241  
   242  	extendTypeQueryDef := `extend type ` + schema.Query.Name + ` {
   243  ` + additionalQueryFields + `
   244  }`
   245  	blocks = append(blocks, extendTypeQueryDef)
   246  
   247  	return &ast.Source{
   248  		Name:    "federation/entity.graphql",
   249  		BuiltIn: true,
   250  		Input:   "\n" + strings.Join(blocks, "\n\n") + "\n",
   251  	}
   252  }
   253  
   254  func (f *federation) GenerateCode(data *codegen.Data) error {
   255  	if len(f.Entities) > 0 {
   256  		if data.Objects.ByName("Entity") != nil {
   257  			data.Objects.ByName("Entity").Root = true
   258  		}
   259  		for _, e := range f.Entities {
   260  			obj := data.Objects.ByName(e.Def.Name)
   261  
   262  			if e.Def.Kind == ast.Interface {
   263  				if len(data.Interfaces[e.Def.Name].Implementors) == 0 {
   264  					fmt.Println(
   265  						"skipping @key field on interface " + e.Def.Name + " as no types implement it",
   266  					)
   267  					continue
   268  				}
   269  				obj = data.Objects.ByName(data.Interfaces[e.Def.Name].Implementors[0].Name)
   270  			}
   271  
   272  			for _, r := range e.Resolvers {
   273  				// fill in types for key fields
   274  				//
   275  				for _, keyField := range r.KeyFields {
   276  					if len(keyField.Field) == 0 {
   277  						fmt.Println(
   278  							"skipping @key field " + keyField.Definition.Name + " in " + r.ResolverName + " in " + e.Def.Name,
   279  						)
   280  						continue
   281  					}
   282  					cgField := keyField.Field.TypeReference(obj, data.Objects)
   283  					keyField.Type = cgField.TypeReference
   284  				}
   285  			}
   286  
   287  			// fill in types for requires fields
   288  			//
   289  			for _, reqField := range e.Requires {
   290  				if len(reqField.Field) == 0 {
   291  					fmt.Println("skipping @requires field " + reqField.Name + " in " + e.Def.Name)
   292  					continue
   293  				}
   294  				cgField := reqField.Field.TypeReference(obj, data.Objects)
   295  				reqField.Type = cgField.TypeReference
   296  			}
   297  		}
   298  	}
   299  
   300  	// fill in types for resolver inputs
   301  	//
   302  	for _, entity := range f.Entities {
   303  		if !entity.Multi {
   304  			continue
   305  		}
   306  
   307  		for _, resolver := range entity.Resolvers {
   308  			obj := data.Inputs.ByName(resolver.InputTypeName)
   309  			if obj == nil {
   310  				return fmt.Errorf("input object %s not found", resolver.InputTypeName)
   311  			}
   312  
   313  			resolver.InputType = obj.Type
   314  		}
   315  	}
   316  
   317  	return templates.Render(templates.Options{
   318  		PackageName:     data.Config.Federation.Package,
   319  		Filename:        data.Config.Federation.Filename,
   320  		Data:            f,
   321  		GeneratedHeader: true,
   322  		Packages:        data.Config.Packages,
   323  		Template:        federationTemplate,
   324  	})
   325  }
   326  
   327  func (f *federation) setEntities(schema *ast.Schema) {
   328  	for _, schemaType := range schema.Types {
   329  		keys, ok := isFederatedEntity(schemaType)
   330  		if !ok {
   331  			continue
   332  		}
   333  
   334  		if (schemaType.Kind == ast.Interface) && (len(schema.GetPossibleTypes(schemaType)) == 0) {
   335  			fmt.Printf("@key directive found on unused \"interface %s\". Will be ignored.\n", schemaType.Name)
   336  			continue
   337  		}
   338  
   339  		e := &Entity{
   340  			Name:      schemaType.Name,
   341  			Def:       schemaType,
   342  			Resolvers: nil,
   343  			Requires:  nil,
   344  		}
   345  
   346  		// Let's process custom entity resolver settings.
   347  		dir := schemaType.Directives.ForName("entityResolver")
   348  		if dir != nil {
   349  			if dirArg := dir.Arguments.ForName("multi"); dirArg != nil {
   350  				if dirVal, err := dirArg.Value.Value(nil); err == nil {
   351  					e.Multi = dirVal.(bool)
   352  				}
   353  			}
   354  		}
   355  
   356  		// If our schema has a field with a type defined in
   357  		// another service, then we need to define an "empty
   358  		// extend" of that type in this service, so this service
   359  		// knows what the type is like.  But the graphql-server
   360  		// will never ask us to actually resolve this "empty
   361  		// extend", so we don't require a resolver function for
   362  		// it.  (Well, it will never ask in practice; it's
   363  		// unclear whether the spec guarantees this.  See
   364  		// https://github.com/apollographql/apollo-server/issues/3852
   365  		// ).  Example:
   366  		//    type MyType {
   367  		//       myvar: TypeDefinedInOtherService
   368  		//    }
   369  		//    // Federation needs this type, but
   370  		//    // it doesn't need a resolver for it!
   371  		//    extend TypeDefinedInOtherService @key(fields: "id") {
   372  		//       id: ID @external
   373  		//    }
   374  		if !e.allFieldsAreExternal(f.Version) {
   375  			for _, dir := range keys {
   376  				if len(dir.Arguments) > 2 {
   377  					panic("More than two arguments provided for @key declaration.")
   378  				}
   379  				var arg *ast.Argument
   380  
   381  				// since keys are able to now have multiple arguments, we need to check both possible for a possible @key(fields="" fields="")
   382  				for _, a := range dir.Arguments {
   383  					if a.Name == "fields" {
   384  						if arg != nil {
   385  							panic("More than one `fields` provided for @key declaration.")
   386  						}
   387  						arg = a
   388  					}
   389  				}
   390  
   391  				keyFieldSet := fieldset.New(arg.Value.Raw, nil)
   392  
   393  				keyFields := make([]*KeyField, len(keyFieldSet))
   394  				resolverFields := []string{}
   395  				for i, field := range keyFieldSet {
   396  					def := field.FieldDefinition(schemaType, schema)
   397  
   398  					if def == nil {
   399  						panic(fmt.Sprintf("no field for %v", field))
   400  					}
   401  
   402  					keyFields[i] = &KeyField{Definition: def, Field: field}
   403  					resolverFields = append(resolverFields, keyFields[i].Field.ToGo())
   404  				}
   405  
   406  				resolverFieldsToGo := schemaType.Name + "By" + strings.Join(resolverFields, "And")
   407  				var resolverName string
   408  				if e.Multi {
   409  					resolverFieldsToGo += "s" // Pluralize for better API readability
   410  					resolverName = fmt.Sprintf("findMany%s", resolverFieldsToGo)
   411  				} else {
   412  					resolverName = fmt.Sprintf("find%s", resolverFieldsToGo)
   413  				}
   414  
   415  				e.Resolvers = append(e.Resolvers, &EntityResolver{
   416  					ResolverName:  resolverName,
   417  					KeyFields:     keyFields,
   418  					InputTypeName: resolverFieldsToGo + "Input",
   419  				})
   420  			}
   421  
   422  			e.Requires = []*Requires{}
   423  			for _, f := range schemaType.Fields {
   424  				dir := f.Directives.ForName("requires")
   425  				if dir == nil {
   426  					continue
   427  				}
   428  				if len(dir.Arguments) != 1 || dir.Arguments[0].Name != "fields" {
   429  					panic("Exactly one `fields` argument needed for @requires declaration.")
   430  				}
   431  				requiresFieldSet := fieldset.New(dir.Arguments[0].Value.Raw, nil)
   432  				for _, field := range requiresFieldSet {
   433  					e.Requires = append(e.Requires, &Requires{
   434  						Name:  field.ToGoPrivate(),
   435  						Field: field,
   436  					})
   437  				}
   438  			}
   439  		}
   440  		f.Entities = append(f.Entities, e)
   441  	}
   442  
   443  	// make sure order remains stable across multiple builds
   444  	sort.Slice(f.Entities, func(i, j int) bool {
   445  		return f.Entities[i].Name < f.Entities[j].Name
   446  	})
   447  }
   448  
   449  func isFederatedEntity(schemaType *ast.Definition) ([]*ast.Directive, bool) {
   450  	switch schemaType.Kind {
   451  	case ast.Object:
   452  		keys := schemaType.Directives.ForNames("key")
   453  		if len(keys) > 0 {
   454  			return keys, true
   455  		}
   456  	case ast.Interface:
   457  		keys := schemaType.Directives.ForNames("key")
   458  		if len(keys) > 0 {
   459  			return keys, true
   460  		}
   461  
   462  		// TODO: support @extends for interfaces
   463  		if dir := schemaType.Directives.ForName("extends"); dir != nil {
   464  			panic(
   465  				fmt.Sprintf(
   466  					"@extends directive is not currently supported for interfaces, use \"extend interface %s\" instead.",
   467  					schemaType.Name,
   468  				))
   469  		}
   470  	default:
   471  		// ignore
   472  	}
   473  	return nil, false
   474  }