github.com/operandinc/gqlgen@v0.16.1/plugin/federation/fieldset/fieldset.go (about)

     1  package fieldset
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"github.com/operandinc/gqlgen/codegen"
     8  	"github.com/operandinc/gqlgen/codegen/templates"
     9  	"github.com/vektah/gqlparser/v2/ast"
    10  )
    11  
// Set represents a FieldSet that is used in federation directives @key and @requires.
// Would be happier to reuse FieldSet parsing from gqlparser, but this suits for now.
//
// Each element is one (possibly nested) field of the set, in the order it was parsed.
type Set []Field

// Field represents a single field in a FieldSet.
//
// It is stored as the list of name parts leading to the field: a top-level
// field has a single part, while a field selected inside "{ ... }" carries the
// parts of its enclosing fields first.
type Field []string
    20  
    21  // New parses a FieldSet string into a TinyFieldSet.
    22  //
    23  func New(raw string, prefix []string) Set {
    24  	if !strings.Contains(raw, "{") {
    25  		return parseUnnestedKeyFieldSet(raw, prefix)
    26  	}
    27  
    28  	var (
    29  		ret       = Set{}
    30  		subPrefix = prefix
    31  	)
    32  	before, during, after := extractSubs(raw)
    33  
    34  	if before != "" {
    35  		befores := New(before, prefix)
    36  		if len(befores) > 0 {
    37  			subPrefix = befores[len(befores)-1]
    38  			ret = append(ret, befores[:len(befores)-1]...)
    39  		}
    40  	}
    41  	if during != "" {
    42  		ret = append(ret, New(during, subPrefix)...)
    43  	}
    44  	if after != "" {
    45  		ret = append(ret, New(after, prefix)...)
    46  	}
    47  	return ret
    48  }
    49  
    50  // FieldDefinition looks up a field in the type.
    51  //
    52  func (f Field) FieldDefinition(schemaType *ast.Definition, schema *ast.Schema) *ast.FieldDefinition {
    53  	objType := schemaType
    54  	def := objType.Fields.ForName(f[0])
    55  
    56  	for _, part := range f[1:] {
    57  		if objType.Kind != ast.Object {
    58  			panic(fmt.Sprintf(`invalid sub-field reference "%s" in %v: `, objType.Name, f))
    59  		}
    60  		x := def.Type.Name()
    61  		objType = schema.Types[x]
    62  		if objType == nil {
    63  			panic("invalid schema type: " + x)
    64  		}
    65  		def = objType.Fields.ForName(part)
    66  	}
    67  	if def == nil {
    68  		return nil
    69  	}
    70  	ret := *def // shallow copy
    71  	ret.Name = f.ToGoPrivate()
    72  
    73  	return &ret
    74  }
    75  
    76  // TypeReference looks up the type of a field.
    77  //
    78  func (f Field) TypeReference(obj *codegen.Object, objects codegen.Objects) *codegen.Field {
    79  	var def *codegen.Field
    80  
    81  	for _, part := range f {
    82  		def = fieldByName(obj, part)
    83  		if def == nil {
    84  			panic("unable to find field " + f[0])
    85  		}
    86  		obj = objects.ByName(def.TypeReference.Definition.Name)
    87  	}
    88  	return def
    89  }
    90  
    91  // ToGo converts a (possibly nested) field into a proper public Go name.
    92  //
    93  func (f Field) ToGo() string {
    94  	var ret string
    95  
    96  	for _, field := range f {
    97  		ret += templates.ToGo(field)
    98  	}
    99  	return ret
   100  }
   101  
   102  // ToGoPrivate converts a (possibly nested) field into a proper private Go name.
   103  //
   104  func (f Field) ToGoPrivate() string {
   105  	var ret string
   106  
   107  	for i, field := range f {
   108  		if i == 0 {
   109  			ret += templates.ToGoPrivate(field)
   110  			continue
   111  		}
   112  		ret += templates.ToGo(field)
   113  	}
   114  	return ret
   115  }
   116  
   117  // Join concatenates the field parts with a string separator between. Useful in templates.
   118  //
   119  func (f Field) Join(str string) string {
   120  	return strings.Join(f, str)
   121  }
   122  
   123  // JoinGo concatenates the Go name of field parts with a string separator between. Useful in templates.
   124  //
   125  func (f Field) JoinGo(str string) string {
   126  	strs := []string{}
   127  
   128  	for _, s := range f {
   129  		strs = append(strs, templates.ToGo(s))
   130  	}
   131  	return strings.Join(strs, str)
   132  }
   133  
   134  func (f Field) LastIndex() int {
   135  	return len(f) - 1
   136  }
   137  
   138  // local functions
   139  
   140  // parseUnnestedKeyFieldSet // handles simple case where none of the fields are nested.
   141  //
   142  func parseUnnestedKeyFieldSet(raw string, prefix []string) Set {
   143  	ret := Set{}
   144  
   145  	for _, s := range strings.Fields(raw) {
   146  		next := append(prefix[:], s) //nolint:gocritic // slicing out on purpose
   147  		ret = append(ret, next)
   148  	}
   149  	return ret
   150  }
   151  
   152  // extractSubs splits out and trims sub-expressions from before, inside, and after "{}".
   153  //
   154  func extractSubs(str string) (string, string, string) {
   155  	start := strings.Index(str, "{")
   156  	end := matchingBracketIndex(str, start)
   157  
   158  	if start < 0 || end < 0 {
   159  		panic("invalid key fieldSet: " + str)
   160  	}
   161  	return strings.TrimSpace(str[:start]), strings.TrimSpace(str[start+1 : end]), strings.TrimSpace(str[end+1:])
   162  }
   163  
   164  // matchingBracketIndex returns the index of the closing bracket, assuming an open bracket at start.
   165  //
   166  func matchingBracketIndex(str string, start int) int {
   167  	if start < 0 || len(str) <= start+1 {
   168  		return -1
   169  	}
   170  	var depth int
   171  
   172  	for i, c := range str[start+1:] {
   173  		switch c {
   174  		case '{':
   175  			depth++
   176  		case '}':
   177  			if depth == 0 {
   178  				return start + 1 + i
   179  			}
   180  			depth--
   181  		}
   182  	}
   183  	return -1
   184  }
   185  
   186  func fieldByName(obj *codegen.Object, name string) *codegen.Field {
   187  	for _, field := range obj.Fields {
   188  		if field.Name == name {
   189  			return field
   190  		}
   191  	}
   192  	return nil
   193  }