git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/plugin/federation/federation.go (about)

     1  package federation
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  	"strings"
     7  
     8  	"github.com/vektah/gqlparser/v2/ast"
     9  
    10  	"git.sr.ht/~sircmpwn/gqlgen/codegen"
    11  	"git.sr.ht/~sircmpwn/gqlgen/codegen/config"
    12  	"git.sr.ht/~sircmpwn/gqlgen/codegen/templates"
    13  	"git.sr.ht/~sircmpwn/gqlgen/plugin"
    14  )
    15  
    16  type federation struct {
    17  	Entities []*Entity
    18  }
    19  
    20  // New returns a federation plugin that injects
    21  // federated directives and types into the schema
    22  func New() plugin.Plugin {
    23  	return &federation{}
    24  }
    25  
    26  // Name returns the plugin name
    27  func (f *federation) Name() string {
    28  	return "federation"
    29  }
    30  
    31  // MutateConfig mutates the configuration
    32  func (f *federation) MutateConfig(cfg *config.Config) error {
    33  	builtins := config.TypeMap{
    34  		"_Service": {
    35  			Model: config.StringList{
    36  				"git.sr.ht/~sircmpwn/gqlgen/plugin/federation/fedruntime.Service",
    37  			},
    38  		},
    39  		"_Entity": {
    40  			Model: config.StringList{
    41  				"git.sr.ht/~sircmpwn/gqlgen/plugin/federation/fedruntime.Entity",
    42  			},
    43  		},
    44  		"Entity": {
    45  			Model: config.StringList{
    46  				"git.sr.ht/~sircmpwn/gqlgen/plugin/federation/fedruntime.Entity",
    47  			},
    48  		},
    49  		"_Any": {
    50  			Model: config.StringList{"git.sr.ht/~sircmpwn/gqlgen/graphql.Map"},
    51  		},
    52  	}
    53  	for typeName, entry := range builtins {
    54  		if cfg.Models.Exists(typeName) {
    55  			return fmt.Errorf("%v already exists which must be reserved when Federation is enabled", typeName)
    56  		}
    57  		cfg.Models[typeName] = entry
    58  	}
    59  	cfg.Directives["external"] = config.DirectiveConfig{SkipRuntime: true}
    60  	cfg.Directives["requires"] = config.DirectiveConfig{SkipRuntime: true}
    61  	cfg.Directives["provides"] = config.DirectiveConfig{SkipRuntime: true}
    62  	cfg.Directives["key"] = config.DirectiveConfig{SkipRuntime: true}
    63  	cfg.Directives["extends"] = config.DirectiveConfig{SkipRuntime: true}
    64  
    65  	return nil
    66  }
    67  
    68  func (f *federation) InjectSourceEarly() *ast.Source {
    69  	return &ast.Source{
    70  		Name: "federation/directives.graphql",
    71  		Input: `
    72  scalar _Any
    73  scalar _FieldSet
    74  
    75  directive @external on FIELD_DEFINITION
    76  directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
    77  directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
    78  directive @key(fields: _FieldSet!) on OBJECT | INTERFACE
    79  directive @extends on OBJECT
    80  `,
    81  		BuiltIn: true,
    82  	}
    83  }
    84  
    85  // InjectSources creates a GraphQL Entity type with all
    86  // the fields that had the @key directive
    87  func (f *federation) InjectSourceLate(schema *ast.Schema) *ast.Source {
    88  	f.setEntities(schema)
    89  
    90  	entities := ""
    91  	resolvers := ""
    92  	for i, e := range f.Entities {
    93  		if i != 0 {
    94  			entities += " | "
    95  		}
    96  		entities += e.Name
    97  
    98  		resolverArgs := ""
    99  		for _, field := range e.KeyFields {
   100  			resolverArgs += fmt.Sprintf("%s: %s,", field.Field.Name, field.Field.Type.String())
   101  		}
   102  		resolvers += fmt.Sprintf("\t%s(%s): %s!\n", e.ResolverName, resolverArgs, e.Def.Name)
   103  
   104  	}
   105  
   106  	if len(f.Entities) == 0 {
   107  		// It's unusual for a service not to have any entities, but
   108  		// possible if it only exports top-level queries and mutations.
   109  		return nil
   110  	}
   111  
   112  	return &ast.Source{
   113  		Name:    "federation/entity.graphql",
   114  		BuiltIn: true,
   115  		Input: `
   116  # a union of all types that use the @key directive
   117  union _Entity = ` + entities + `
   118  
   119  # fake type to build resolver interfaces for users to implement
   120  type Entity {
   121  	` + resolvers + `
   122  }
   123  
   124  type _Service {
   125    sdl: String
   126  }
   127  
   128  extend type Query {
   129    _entities(representations: [_Any!]!): [_Entity]!
   130    _service: _Service!
   131  }
   132  `,
   133  	}
   134  }
   135  
   136  // Entity represents a federated type
   137  // that was declared in the GQL schema.
   138  type Entity struct {
   139  	Name         string      // The same name as the type declaration
   140  	KeyFields    []*KeyField // The fields declared in @key.
   141  	ResolverName string      // The resolver name, such as FindUserByID
   142  	Def          *ast.Definition
   143  	Requires     []*Requires
   144  }
   145  
   146  type KeyField struct {
   147  	Field         *ast.FieldDefinition
   148  	TypeReference *config.TypeReference // The Go representation of that field type
   149  }
   150  
   151  // Requires represents an @requires clause
   152  type Requires struct {
   153  	Name   string          // the name of the field
   154  	Fields []*RequireField // the name of the sibling fields
   155  }
   156  
   157  // RequireField is similar to an entity but it is a field not
   158  // an object
   159  type RequireField struct {
   160  	Name          string                // The same name as the type declaration
   161  	NameGo        string                // The Go struct field name
   162  	TypeReference *config.TypeReference // The Go representation of that field type
   163  }
   164  
   165  func (f *federation) GenerateCode(data *codegen.Data) error {
   166  	if len(f.Entities) > 0 {
   167  		data.Objects.ByName("Entity").Root = true
   168  		for _, e := range f.Entities {
   169  			obj := data.Objects.ByName(e.Def.Name)
   170  			for _, field := range obj.Fields {
   171  				// Storing key fields in a slice rather than a map
   172  				// to preserve insertion order at the tradeoff of higher
   173  				// lookup complexity.
   174  				keyField := f.getKeyField(e.KeyFields, field.Name)
   175  				if keyField != nil {
   176  					keyField.TypeReference = field.TypeReference
   177  				}
   178  				for _, r := range e.Requires {
   179  					for _, rf := range r.Fields {
   180  						if rf.Name == field.Name {
   181  							rf.TypeReference = field.TypeReference
   182  							rf.NameGo = field.GoFieldName
   183  						}
   184  					}
   185  				}
   186  			}
   187  		}
   188  	}
   189  
   190  	return templates.Render(templates.Options{
   191  		PackageName:     data.Config.Federation.Package,
   192  		Filename:        data.Config.Federation.Filename,
   193  		Data:            f,
   194  		GeneratedHeader: true,
   195  		Packages:        data.Config.Packages,
   196  	})
   197  }
   198  
   199  func (f *federation) getKeyField(keyFields []*KeyField, fieldName string) *KeyField {
   200  	for _, field := range keyFields {
   201  		if field.Field.Name == fieldName {
   202  			return field
   203  		}
   204  	}
   205  	return nil
   206  }
   207  
   208  func (f *federation) setEntities(schema *ast.Schema) {
   209  	for _, schemaType := range schema.Types {
   210  		if schemaType.Kind == ast.Object {
   211  			dir := schemaType.Directives.ForName("key") // TODO: interfaces
   212  			if dir != nil {
   213  				if len(dir.Arguments) > 1 {
   214  					panic("Multiple arguments are not currently supported in @key declaration.")
   215  				}
   216  				fieldName := dir.Arguments[0].Value.Raw // TODO: multiple arguments
   217  				if strings.Contains(fieldName, "{") {
   218  					panic("Nested fields are not currently supported in @key declaration.")
   219  				}
   220  
   221  				requires := []*Requires{}
   222  				for _, f := range schemaType.Fields {
   223  					dir := f.Directives.ForName("requires")
   224  					if dir == nil {
   225  						continue
   226  					}
   227  					fields := strings.Split(dir.Arguments[0].Value.Raw, " ")
   228  					requireFields := []*RequireField{}
   229  					for _, f := range fields {
   230  						requireFields = append(requireFields, &RequireField{
   231  							Name: f,
   232  						})
   233  					}
   234  					requires = append(requires, &Requires{
   235  						Name:   f.Name,
   236  						Fields: requireFields,
   237  					})
   238  				}
   239  
   240  				fieldNames := strings.Split(fieldName, " ")
   241  				keyFields := make([]*KeyField, len(fieldNames))
   242  				resolverName := fmt.Sprintf("find%sBy", schemaType.Name)
   243  				for i, f := range fieldNames {
   244  					field := schemaType.Fields.ForName(f)
   245  
   246  					keyFields[i] = &KeyField{Field: field}
   247  					if i > 0 {
   248  						resolverName += "And"
   249  					}
   250  					resolverName += templates.ToGo(f)
   251  
   252  				}
   253  
   254  				f.Entities = append(f.Entities, &Entity{
   255  					Name:         schemaType.Name,
   256  					KeyFields:    keyFields,
   257  					Def:          schemaType,
   258  					ResolverName: resolverName,
   259  					Requires:     requires,
   260  				})
   261  			}
   262  		}
   263  	}
   264  
   265  	// make sure order remains stable across multiple builds
   266  	sort.Slice(f.Entities, func(i, j int) bool {
   267  		return f.Entities[i].Name < f.Entities[j].Name
   268  	})
   269  }