github.com/operandinc/gqlgen@v0.16.1/plugin/federation/fieldset/fieldset.go (about) 1 package fieldset 2 3 import ( 4 "fmt" 5 "strings" 6 7 "github.com/operandinc/gqlgen/codegen" 8 "github.com/operandinc/gqlgen/codegen/templates" 9 "github.com/vektah/gqlparser/v2/ast" 10 ) 11 12 // Set represents a FieldSet that is used in federation directives @key and @requires. 13 // Would be happier to reuse FieldSet parsing from gqlparser, but this suits for now. 14 // 15 type Set []Field 16 17 // Field represents a single field in a FieldSet 18 // 19 type Field []string 20 21 // New parses a FieldSet string into a TinyFieldSet. 22 // 23 func New(raw string, prefix []string) Set { 24 if !strings.Contains(raw, "{") { 25 return parseUnnestedKeyFieldSet(raw, prefix) 26 } 27 28 var ( 29 ret = Set{} 30 subPrefix = prefix 31 ) 32 before, during, after := extractSubs(raw) 33 34 if before != "" { 35 befores := New(before, prefix) 36 if len(befores) > 0 { 37 subPrefix = befores[len(befores)-1] 38 ret = append(ret, befores[:len(befores)-1]...) 39 } 40 } 41 if during != "" { 42 ret = append(ret, New(during, subPrefix)...) 43 } 44 if after != "" { 45 ret = append(ret, New(after, prefix)...) 46 } 47 return ret 48 } 49 50 // FieldDefinition looks up a field in the type. 51 // 52 func (f Field) FieldDefinition(schemaType *ast.Definition, schema *ast.Schema) *ast.FieldDefinition { 53 objType := schemaType 54 def := objType.Fields.ForName(f[0]) 55 56 for _, part := range f[1:] { 57 if objType.Kind != ast.Object { 58 panic(fmt.Sprintf(`invalid sub-field reference "%s" in %v: `, objType.Name, f)) 59 } 60 x := def.Type.Name() 61 objType = schema.Types[x] 62 if objType == nil { 63 panic("invalid schema type: " + x) 64 } 65 def = objType.Fields.ForName(part) 66 } 67 if def == nil { 68 return nil 69 } 70 ret := *def // shallow copy 71 ret.Name = f.ToGoPrivate() 72 73 return &ret 74 } 75 76 // TypeReference looks up the type of a field. 77 // 78 func (f Field) TypeReference(obj *codegen.Object, objects codegen.Objects) *codegen.Field { 79 var def *codegen.Field 80 81 for _, part := range f { 82 def = fieldByName(obj, part) 83 if def == nil { 84 panic("unable to find field " + f[0]) 85 } 86 obj = objects.ByName(def.TypeReference.Definition.Name) 87 } 88 return def 89 } 90 91 // ToGo converts a (possibly nested) field into a proper public Go name. 92 // 93 func (f Field) ToGo() string { 94 var ret string 95 96 for _, field := range f { 97 ret += templates.ToGo(field) 98 } 99 return ret 100 } 101 102 // ToGoPrivate converts a (possibly nested) field into a proper private Go name. 103 // 104 func (f Field) ToGoPrivate() string { 105 var ret string 106 107 for i, field := range f { 108 if i == 0 { 109 ret += templates.ToGoPrivate(field) 110 continue 111 } 112 ret += templates.ToGo(field) 113 } 114 return ret 115 } 116 117 // Join concatenates the field parts with a string separator between. Useful in templates. 118 // 119 func (f Field) Join(str string) string { 120 return strings.Join(f, str) 121 } 122 123 // JoinGo concatenates the Go name of field parts with a string separator between. Useful in templates. 124 // 125 func (f Field) JoinGo(str string) string { 126 strs := []string{} 127 128 for _, s := range f { 129 strs = append(strs, templates.ToGo(s)) 130 } 131 return strings.Join(strs, str) 132 } 133 134 func (f Field) LastIndex() int { 135 return len(f) - 1 136 } 137 138 // local functions 139 140 // parseUnnestedKeyFieldSet // handles simple case where none of the fields are nested. 141 // 142 func parseUnnestedKeyFieldSet(raw string, prefix []string) Set { 143 ret := Set{} 144 145 for _, s := range strings.Fields(raw) { 146 next := append(prefix[:], s) //nolint:gocritic // slicing out on purpose 147 ret = append(ret, next) 148 } 149 return ret 150 } 151 152 // extractSubs splits out and trims sub-expressions from before, inside, and after "{}". 153 // 154 func extractSubs(str string) (string, string, string) { 155 start := strings.Index(str, "{") 156 end := matchingBracketIndex(str, start) 157 158 if start < 0 || end < 0 { 159 panic("invalid key fieldSet: " + str) 160 } 161 return strings.TrimSpace(str[:start]), strings.TrimSpace(str[start+1 : end]), strings.TrimSpace(str[end+1:]) 162 } 163 164 // matchingBracketIndex returns the index of the closing bracket, assuming an open bracket at start. 165 // 166 func matchingBracketIndex(str string, start int) int { 167 if start < 0 || len(str) <= start+1 { 168 return -1 169 } 170 var depth int 171 172 for i, c := range str[start+1:] { 173 switch c { 174 case '{': 175 depth++ 176 case '}': 177 if depth == 0 { 178 return start + 1 + i 179 } 180 depth-- 181 } 182 } 183 return -1 184 } 185 186 func fieldByName(obj *codegen.Object, name string) *codegen.Field { 187 for _, field := range obj.Fields { 188 if field.Name == name { 189 return field 190 } 191 } 192 return nil 193 }