github.com/grpc-ecosystem/grpc-gateway/v2@v2.19.1/protoc-gen-openapiv2/internal/genopenapi/generator.go (about) 1 package genopenapi 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "path/filepath" 9 "reflect" 10 "sort" 11 "strings" 12 13 "github.com/grpc-ecosystem/grpc-gateway/v2/internal/descriptor" 14 gen "github.com/grpc-ecosystem/grpc-gateway/v2/internal/generator" 15 openapioptions "github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2/options" 16 statuspb "google.golang.org/genproto/googleapis/rpc/status" 17 "google.golang.org/grpc/grpclog" 18 "google.golang.org/protobuf/proto" 19 "google.golang.org/protobuf/reflect/protodesc" 20 "google.golang.org/protobuf/types/descriptorpb" 21 "google.golang.org/protobuf/types/known/anypb" 22 "google.golang.org/protobuf/types/pluginpb" 23 "gopkg.in/yaml.v3" 24 ) 25 26 var errNoTargetService = errors.New("no target service defined in the file") 27 28 type generator struct { 29 reg *descriptor.Registry 30 format Format 31 } 32 33 type wrapper struct { 34 fileName string 35 swagger *openapiSwaggerObject 36 } 37 38 type GeneratorOptions struct { 39 Registry *descriptor.Registry 40 RecursiveDepth int 41 } 42 43 // New returns a new generator which generates grpc gateway files. 44 func New(reg *descriptor.Registry, format Format) gen.Generator { 45 return &generator{ 46 reg: reg, 47 format: format, 48 } 49 } 50 51 // Merge a lot of OpenAPI file (wrapper) to single one OpenAPI file 52 func mergeTargetFile(targets []*wrapper, mergeFileName string) *wrapper { 53 var mergedTarget *wrapper 54 for _, f := range targets { 55 if mergedTarget == nil { 56 mergedTarget = &wrapper{ 57 fileName: mergeFileName, 58 swagger: f.swagger, 59 } 60 } else { 61 for k, v := range f.swagger.Definitions { 62 mergedTarget.swagger.Definitions[k] = v 63 } 64 for k, v := range f.swagger.SecurityDefinitions { 65 mergedTarget.swagger.SecurityDefinitions[k] = v 66 } 67 copy(mergedTarget.swagger.Paths, f.swagger.Paths) 68 mergedTarget.swagger.Security = append(mergedTarget.swagger.Security, f.swagger.Security...) 69 } 70 } 71 return mergedTarget 72 } 73 74 // Q: What's up with the alias types here? 75 // A: We don't want to completely override how these structs are marshaled into 76 // JSON, we only want to add fields (see below, extensionMarshalJSON). 77 // An infinite recursion would happen if we'd call json.Marshal on the struct 78 // that has swaggerObject as an embedded field. To avoid that, we'll create 79 // type aliases, and those don't have the custom MarshalJSON methods defined 80 // on them. See http://choly.ca/post/go-json-marshalling/ (or, if it ever 81 // goes away, use 82 // https://web.archive.org/web/20190806073003/http://choly.ca/post/go-json-marshalling/. 83 func (so openapiSwaggerObject) MarshalJSON() ([]byte, error) { 84 type alias openapiSwaggerObject 85 return extensionMarshalJSON(alias(so), so.extensions) 86 } 87 88 // MarshalYAML implements yaml.Marshaler interface. 89 // 90 // It is required in order to pass extensions inline. 91 // 92 // Example: 93 // 94 // extensions: {x-key: x-value} 95 // type: string 96 // 97 // It will be rendered as: 98 // 99 // x-key: x-value 100 // type: string 101 // 102 // Use generics when the project will be upgraded to go 1.18+. 103 func (so openapiSwaggerObject) MarshalYAML() (interface{}, error) { 104 type Alias openapiSwaggerObject 105 106 return struct { 107 Extension map[string]interface{} `yaml:",inline"` 108 Alias `yaml:",inline"` 109 }{ 110 Extension: extensionsToMap(so.extensions), 111 Alias: Alias(so), 112 }, nil 113 } 114 115 // Custom json marshaller for openapiPathsObject. Ensures 116 // openapiPathsObject is marshalled into expected format in generated 117 // swagger.json. 118 func (po openapiPathsObject) MarshalJSON() ([]byte, error) { 119 var buf bytes.Buffer 120 121 buf.WriteString("{") 122 for i, pd := range po { 123 if i != 0 { 124 buf.WriteString(",") 125 } 126 // marshal key 127 key, err := json.Marshal(pd.Path) 128 if err != nil { 129 return nil, err 130 } 131 buf.Write(key) 132 buf.WriteString(":") 133 // marshal value 134 val, err := json.Marshal(pd.PathItemObject) 135 if err != nil { 136 return nil, err 137 } 138 buf.Write(val) 139 } 140 141 buf.WriteString("}") 142 return buf.Bytes(), nil 143 } 144 145 // Custom yaml marshaller for openapiPathsObject. Ensures 146 // openapiPathsObject is marshalled into expected format in generated 147 // swagger.yaml. 148 func (po openapiPathsObject) MarshalYAML() (interface{}, error) { 149 var pathObjectNode yaml.Node 150 pathObjectNode.Kind = yaml.MappingNode 151 152 for _, pathData := range po { 153 var pathNode yaml.Node 154 155 pathNode.SetString(pathData.Path) 156 pathItemObjectNode, err := pathData.PathItemObject.toYAMLNode() 157 if err != nil { 158 return nil, err 159 } 160 pathObjectNode.Content = append(pathObjectNode.Content, &pathNode, pathItemObjectNode) 161 } 162 163 return pathObjectNode, nil 164 } 165 166 // We can simplify this implementation once the go-yaml bug is resolved. See: https://github.com/go-yaml/yaml/issues/643. 167 // 168 // func (pio *openapiPathItemObject) toYAMLNode() (*yaml.Node, error) { 169 // var node yaml.Node 170 // if err := node.Encode(pio); err != nil { 171 // return nil, err 172 // } 173 // return &node, nil 174 // } 175 func (pio *openapiPathItemObject) toYAMLNode() (*yaml.Node, error) { 176 var doc yaml.Node 177 var buf bytes.Buffer 178 ec := yaml.NewEncoder(&buf) 179 ec.SetIndent(2) 180 if err := ec.Encode(pio); err != nil { 181 return nil, err 182 } 183 if err := yaml.Unmarshal(buf.Bytes(), &doc); err != nil { 184 return nil, err 185 } 186 if len(doc.Content) == 0 { 187 return nil, errors.New("unexpected number of yaml nodes") 188 } 189 return doc.Content[0], nil 190 } 191 192 func (so openapiInfoObject) MarshalJSON() ([]byte, error) { 193 type alias openapiInfoObject 194 return extensionMarshalJSON(alias(so), so.extensions) 195 } 196 197 func (so openapiInfoObject) MarshalYAML() (interface{}, error) { 198 type Alias openapiInfoObject 199 200 return struct { 201 Extension map[string]interface{} `yaml:",inline"` 202 Alias `yaml:",inline"` 203 }{ 204 Extension: extensionsToMap(so.extensions), 205 Alias: Alias(so), 206 }, nil 207 } 208 209 func (so openapiSecuritySchemeObject) MarshalJSON() ([]byte, error) { 210 type alias openapiSecuritySchemeObject 211 return extensionMarshalJSON(alias(so), so.extensions) 212 } 213 214 func (so openapiSecuritySchemeObject) MarshalYAML() (interface{}, error) { 215 type Alias openapiSecuritySchemeObject 216 217 return struct { 218 Extension map[string]interface{} `yaml:",inline"` 219 Alias `yaml:",inline"` 220 }{ 221 Extension: extensionsToMap(so.extensions), 222 Alias: Alias(so), 223 }, nil 224 } 225 226 func (so openapiOperationObject) MarshalJSON() ([]byte, error) { 227 type alias openapiOperationObject 228 return extensionMarshalJSON(alias(so), so.extensions) 229 } 230 231 func (so openapiOperationObject) MarshalYAML() (interface{}, error) { 232 type Alias openapiOperationObject 233 234 return struct { 235 Extension map[string]interface{} `yaml:",inline"` 236 Alias `yaml:",inline"` 237 }{ 238 Extension: extensionsToMap(so.extensions), 239 Alias: Alias(so), 240 }, nil 241 } 242 243 func (so openapiResponseObject) MarshalJSON() ([]byte, error) { 244 type alias openapiResponseObject 245 return extensionMarshalJSON(alias(so), so.extensions) 246 } 247 248 func (so openapiResponseObject) MarshalYAML() (interface{}, error) { 249 type Alias openapiResponseObject 250 251 return struct { 252 Extension map[string]interface{} `yaml:",inline"` 253 Alias `yaml:",inline"` 254 }{ 255 Extension: extensionsToMap(so.extensions), 256 Alias: Alias(so), 257 }, nil 258 } 259 260 func (so openapiSchemaObject) MarshalJSON() ([]byte, error) { 261 type alias openapiSchemaObject 262 return extensionMarshalJSON(alias(so), so.extensions) 263 } 264 265 func (so openapiSchemaObject) MarshalYAML() (interface{}, error) { 266 type Alias openapiSchemaObject 267 268 return struct { 269 Extension map[string]interface{} `yaml:",inline"` 270 Alias `yaml:",inline"` 271 }{ 272 Extension: extensionsToMap(so.extensions), 273 Alias: Alias(so), 274 }, nil 275 } 276 277 func (so openapiParameterObject) MarshalJSON() ([]byte, error) { 278 type alias openapiParameterObject 279 return extensionMarshalJSON(alias(so), so.extensions) 280 } 281 282 func (so openapiParameterObject) MarshalYAML() (interface{}, error) { 283 type Alias openapiParameterObject 284 285 return struct { 286 Extension map[string]interface{} `yaml:",inline"` 287 Alias `yaml:",inline"` 288 }{ 289 Extension: extensionsToMap(so.extensions), 290 Alias: Alias(so), 291 }, nil 292 } 293 294 func (so openapiTagObject) MarshalJSON() ([]byte, error) { 295 type alias openapiTagObject 296 return extensionMarshalJSON(alias(so), so.extensions) 297 } 298 299 func (so openapiTagObject) MarshalYAML() (interface{}, error) { 300 type Alias openapiTagObject 301 302 return struct { 303 Extension map[string]interface{} `yaml:",inline"` 304 Alias `yaml:",inline"` 305 }{ 306 Extension: extensionsToMap(so.extensions), 307 Alias: Alias(so), 308 }, nil 309 } 310 311 func extensionMarshalJSON(so interface{}, extensions []extension) ([]byte, error) { 312 // To append arbitrary keys to the struct we'll render into json, 313 // we're creating another struct that embeds the original one, and 314 // its extra fields: 315 // 316 // The struct will look like 317 // struct { 318 // *openapiCore 319 // XGrpcGatewayFoo json.RawMessage `json:"x-grpc-gateway-foo"` 320 // XGrpcGatewayBar json.RawMessage `json:"x-grpc-gateway-bar"` 321 // } 322 // and thus render into what we want -- the JSON of openapiCore with the 323 // extensions appended. 324 fields := []reflect.StructField{ 325 { // embedded 326 Name: "Embedded", 327 Type: reflect.TypeOf(so), 328 Anonymous: true, 329 }, 330 } 331 for _, ext := range extensions { 332 fields = append(fields, reflect.StructField{ 333 Name: fieldName(ext.key), 334 Type: reflect.TypeOf(ext.value), 335 Tag: reflect.StructTag(fmt.Sprintf("json:\"%s\"", ext.key)), 336 }) 337 } 338 339 t := reflect.StructOf(fields) 340 s := reflect.New(t).Elem() 341 s.Field(0).Set(reflect.ValueOf(so)) 342 for _, ext := range extensions { 343 s.FieldByName(fieldName(ext.key)).Set(reflect.ValueOf(ext.value)) 344 } 345 return json.Marshal(s.Interface()) 346 } 347 348 // encodeOpenAPI converts OpenAPI file obj to pluginpb.CodeGeneratorResponse_File 349 func encodeOpenAPI(file *wrapper, format Format) (*descriptor.ResponseFile, error) { 350 var contentBuf bytes.Buffer 351 enc, err := format.NewEncoder(&contentBuf) 352 if err != nil { 353 return nil, err 354 } 355 356 if err := enc.Encode(*file.swagger); err != nil { 357 return nil, err 358 } 359 360 name := file.fileName 361 ext := filepath.Ext(name) 362 base := strings.TrimSuffix(name, ext) 363 output := fmt.Sprintf("%s.swagger."+string(format), base) 364 return &descriptor.ResponseFile{ 365 CodeGeneratorResponse_File: &pluginpb.CodeGeneratorResponse_File{ 366 Name: proto.String(output), 367 Content: proto.String(contentBuf.String()), 368 }, 369 }, nil 370 } 371 372 func (g *generator) Generate(targets []*descriptor.File) ([]*descriptor.ResponseFile, error) { 373 var files []*descriptor.ResponseFile 374 if g.reg.IsAllowMerge() { 375 var mergedTarget *descriptor.File 376 // try to find proto leader 377 for _, f := range targets { 378 if proto.HasExtension(f.Options, openapioptions.E_Openapiv2Swagger) { 379 mergedTarget = f 380 break 381 } 382 } 383 // merge protos to leader 384 for _, f := range targets { 385 if mergedTarget == nil { 386 mergedTarget = f 387 } else if mergedTarget != f { 388 mergedTarget.Enums = append(mergedTarget.Enums, f.Enums...) 389 mergedTarget.Messages = append(mergedTarget.Messages, f.Messages...) 390 mergedTarget.Services = append(mergedTarget.Services, f.Services...) 391 } 392 } 393 394 targets = nil 395 targets = append(targets, mergedTarget) 396 } 397 398 var openapis []*wrapper 399 for _, file := range targets { 400 if grpclog.V(1) { 401 grpclog.Infof("Processing %s", file.GetName()) 402 } 403 swagger, err := applyTemplate(param{File: file, reg: g.reg}) 404 if errors.Is(err, errNoTargetService) { 405 if grpclog.V(1) { 406 grpclog.Infof("%s: %v", file.GetName(), err) 407 } 408 continue 409 } 410 if err != nil { 411 return nil, err 412 } 413 openapis = append(openapis, &wrapper{ 414 fileName: file.GetName(), 415 swagger: swagger, 416 }) 417 } 418 419 if g.reg.IsAllowMerge() { 420 targetOpenAPI := mergeTargetFile(openapis, g.reg.GetMergeFileName()) 421 if !g.reg.IsPreserveRPCOrder() { 422 targetOpenAPI.swagger.sortPathsAlphabetically() 423 } 424 f, err := encodeOpenAPI(targetOpenAPI, g.format) 425 if err != nil { 426 return nil, fmt.Errorf("failed to encode OpenAPI for %s: %w", g.reg.GetMergeFileName(), err) 427 } 428 files = append(files, f) 429 if grpclog.V(1) { 430 grpclog.Infof("New OpenAPI file will emit") 431 } 432 } else { 433 for _, file := range openapis { 434 if !g.reg.IsPreserveRPCOrder() { 435 file.swagger.sortPathsAlphabetically() 436 } 437 f, err := encodeOpenAPI(file, g.format) 438 if err != nil { 439 return nil, fmt.Errorf("failed to encode OpenAPI for %s: %w", file.fileName, err) 440 } 441 files = append(files, f) 442 if grpclog.V(1) { 443 grpclog.Infof("New OpenAPI file will emit") 444 } 445 } 446 } 447 return files, nil 448 } 449 450 func (so openapiSwaggerObject) sortPathsAlphabetically() { 451 sort.Slice(so.Paths, func(i, j int) bool { 452 return so.Paths[i].Path < so.Paths[j].Path 453 }) 454 } 455 456 // AddErrorDefs Adds google.rpc.Status and google.protobuf.Any 457 // to registry (used for error-related API responses) 458 func AddErrorDefs(reg *descriptor.Registry) error { 459 // load internal protos 460 any := protodesc.ToFileDescriptorProto((&anypb.Any{}).ProtoReflect().Descriptor().ParentFile()) 461 any.SourceCodeInfo = new(descriptorpb.SourceCodeInfo) 462 status := protodesc.ToFileDescriptorProto((&statuspb.Status{}).ProtoReflect().Descriptor().ParentFile()) 463 status.SourceCodeInfo = new(descriptorpb.SourceCodeInfo) 464 return reg.Load(&pluginpb.CodeGeneratorRequest{ 465 ProtoFile: []*descriptorpb.FileDescriptorProto{ 466 any, 467 status, 468 }, 469 }) 470 } 471 472 func extensionsToMap(extensions []extension) map[string]interface{} { 473 m := make(map[string]interface{}, len(extensions)) 474 475 for _, v := range extensions { 476 m[v.key] = RawExample(v.value) 477 } 478 479 return m 480 }