github.com/grpc-ecosystem/grpc-gateway/v2@v2.19.1/protoc-gen-openapiv2/internal/genopenapi/generator.go (about)

     1  package genopenapi
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"path/filepath"
     9  	"reflect"
    10  	"sort"
    11  	"strings"
    12  
    13  	"github.com/grpc-ecosystem/grpc-gateway/v2/internal/descriptor"
    14  	gen "github.com/grpc-ecosystem/grpc-gateway/v2/internal/generator"
    15  	openapioptions "github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2/options"
    16  	statuspb "google.golang.org/genproto/googleapis/rpc/status"
    17  	"google.golang.org/grpc/grpclog"
    18  	"google.golang.org/protobuf/proto"
    19  	"google.golang.org/protobuf/reflect/protodesc"
    20  	"google.golang.org/protobuf/types/descriptorpb"
    21  	"google.golang.org/protobuf/types/known/anypb"
    22  	"google.golang.org/protobuf/types/pluginpb"
    23  	"gopkg.in/yaml.v3"
    24  )
    25  
    26  var errNoTargetService = errors.New("no target service defined in the file")
    27  
    28  type generator struct {
    29  	reg    *descriptor.Registry
    30  	format Format
    31  }
    32  
    33  type wrapper struct {
    34  	fileName string
    35  	swagger  *openapiSwaggerObject
    36  }
    37  
    38  type GeneratorOptions struct {
    39  	Registry       *descriptor.Registry
    40  	RecursiveDepth int
    41  }
    42  
    43  // New returns a new generator which generates grpc gateway files.
    44  func New(reg *descriptor.Registry, format Format) gen.Generator {
    45  	return &generator{
    46  		reg:    reg,
    47  		format: format,
    48  	}
    49  }
    50  
    51  // Merge a lot of OpenAPI file (wrapper) to single one OpenAPI file
    52  func mergeTargetFile(targets []*wrapper, mergeFileName string) *wrapper {
    53  	var mergedTarget *wrapper
    54  	for _, f := range targets {
    55  		if mergedTarget == nil {
    56  			mergedTarget = &wrapper{
    57  				fileName: mergeFileName,
    58  				swagger:  f.swagger,
    59  			}
    60  		} else {
    61  			for k, v := range f.swagger.Definitions {
    62  				mergedTarget.swagger.Definitions[k] = v
    63  			}
    64  			for k, v := range f.swagger.SecurityDefinitions {
    65  				mergedTarget.swagger.SecurityDefinitions[k] = v
    66  			}
    67  			copy(mergedTarget.swagger.Paths, f.swagger.Paths)
    68  			mergedTarget.swagger.Security = append(mergedTarget.swagger.Security, f.swagger.Security...)
    69  		}
    70  	}
    71  	return mergedTarget
    72  }
    73  
    74  // Q: What's up with the alias types here?
    75  // A: We don't want to completely override how these structs are marshaled into
    76  // JSON, we only want to add fields (see below, extensionMarshalJSON).
    77  // An infinite recursion would happen if we'd call json.Marshal on the struct
    78  // that has swaggerObject as an embedded field. To avoid that, we'll create
    79  // type aliases, and those don't have the custom MarshalJSON methods defined
    80  // on them. See http://choly.ca/post/go-json-marshalling/ (or, if it ever
    81  // goes away, use
    82  // https://web.archive.org/web/20190806073003/http://choly.ca/post/go-json-marshalling/.
    83  func (so openapiSwaggerObject) MarshalJSON() ([]byte, error) {
    84  	type alias openapiSwaggerObject
    85  	return extensionMarshalJSON(alias(so), so.extensions)
    86  }
    87  
    88  // MarshalYAML implements yaml.Marshaler interface.
    89  //
    90  // It is required in order to pass extensions inline.
    91  //
    92  // Example:
    93  //
    94  //	extensions: {x-key: x-value}
    95  //	type: string
    96  //
    97  // It will be rendered as:
    98  //
    99  //	x-key: x-value
   100  //	type: string
   101  //
   102  // Use generics when the project will be upgraded to go 1.18+.
   103  func (so openapiSwaggerObject) MarshalYAML() (interface{}, error) {
   104  	type Alias openapiSwaggerObject
   105  
   106  	return struct {
   107  		Extension map[string]interface{} `yaml:",inline"`
   108  		Alias     `yaml:",inline"`
   109  	}{
   110  		Extension: extensionsToMap(so.extensions),
   111  		Alias:     Alias(so),
   112  	}, nil
   113  }
   114  
   115  // Custom json marshaller for openapiPathsObject. Ensures
   116  // openapiPathsObject is marshalled into expected format in generated
   117  // swagger.json.
   118  func (po openapiPathsObject) MarshalJSON() ([]byte, error) {
   119  	var buf bytes.Buffer
   120  
   121  	buf.WriteString("{")
   122  	for i, pd := range po {
   123  		if i != 0 {
   124  			buf.WriteString(",")
   125  		}
   126  		// marshal key
   127  		key, err := json.Marshal(pd.Path)
   128  		if err != nil {
   129  			return nil, err
   130  		}
   131  		buf.Write(key)
   132  		buf.WriteString(":")
   133  		// marshal value
   134  		val, err := json.Marshal(pd.PathItemObject)
   135  		if err != nil {
   136  			return nil, err
   137  		}
   138  		buf.Write(val)
   139  	}
   140  
   141  	buf.WriteString("}")
   142  	return buf.Bytes(), nil
   143  }
   144  
   145  // Custom yaml marshaller for openapiPathsObject. Ensures
   146  // openapiPathsObject is marshalled into expected format in generated
   147  // swagger.yaml.
   148  func (po openapiPathsObject) MarshalYAML() (interface{}, error) {
   149  	var pathObjectNode yaml.Node
   150  	pathObjectNode.Kind = yaml.MappingNode
   151  
   152  	for _, pathData := range po {
   153  		var pathNode yaml.Node
   154  
   155  		pathNode.SetString(pathData.Path)
   156  		pathItemObjectNode, err := pathData.PathItemObject.toYAMLNode()
   157  		if err != nil {
   158  			return nil, err
   159  		}
   160  		pathObjectNode.Content = append(pathObjectNode.Content, &pathNode, pathItemObjectNode)
   161  	}
   162  
   163  	return pathObjectNode, nil
   164  }
   165  
   166  // We can simplify this implementation once the go-yaml bug is resolved. See: https://github.com/go-yaml/yaml/issues/643.
   167  //
   168  //	func (pio *openapiPathItemObject) toYAMLNode() (*yaml.Node, error) {
   169  //		var node yaml.Node
   170  //		if err := node.Encode(pio); err != nil {
   171  //			return nil, err
   172  //		}
   173  //		return &node, nil
   174  //	}
   175  func (pio *openapiPathItemObject) toYAMLNode() (*yaml.Node, error) {
   176  	var doc yaml.Node
   177  	var buf bytes.Buffer
   178  	ec := yaml.NewEncoder(&buf)
   179  	ec.SetIndent(2)
   180  	if err := ec.Encode(pio); err != nil {
   181  		return nil, err
   182  	}
   183  	if err := yaml.Unmarshal(buf.Bytes(), &doc); err != nil {
   184  		return nil, err
   185  	}
   186  	if len(doc.Content) == 0 {
   187  		return nil, errors.New("unexpected number of yaml nodes")
   188  	}
   189  	return doc.Content[0], nil
   190  }
   191  
   192  func (so openapiInfoObject) MarshalJSON() ([]byte, error) {
   193  	type alias openapiInfoObject
   194  	return extensionMarshalJSON(alias(so), so.extensions)
   195  }
   196  
   197  func (so openapiInfoObject) MarshalYAML() (interface{}, error) {
   198  	type Alias openapiInfoObject
   199  
   200  	return struct {
   201  		Extension map[string]interface{} `yaml:",inline"`
   202  		Alias     `yaml:",inline"`
   203  	}{
   204  		Extension: extensionsToMap(so.extensions),
   205  		Alias:     Alias(so),
   206  	}, nil
   207  }
   208  
   209  func (so openapiSecuritySchemeObject) MarshalJSON() ([]byte, error) {
   210  	type alias openapiSecuritySchemeObject
   211  	return extensionMarshalJSON(alias(so), so.extensions)
   212  }
   213  
   214  func (so openapiSecuritySchemeObject) MarshalYAML() (interface{}, error) {
   215  	type Alias openapiSecuritySchemeObject
   216  
   217  	return struct {
   218  		Extension map[string]interface{} `yaml:",inline"`
   219  		Alias     `yaml:",inline"`
   220  	}{
   221  		Extension: extensionsToMap(so.extensions),
   222  		Alias:     Alias(so),
   223  	}, nil
   224  }
   225  
   226  func (so openapiOperationObject) MarshalJSON() ([]byte, error) {
   227  	type alias openapiOperationObject
   228  	return extensionMarshalJSON(alias(so), so.extensions)
   229  }
   230  
   231  func (so openapiOperationObject) MarshalYAML() (interface{}, error) {
   232  	type Alias openapiOperationObject
   233  
   234  	return struct {
   235  		Extension map[string]interface{} `yaml:",inline"`
   236  		Alias     `yaml:",inline"`
   237  	}{
   238  		Extension: extensionsToMap(so.extensions),
   239  		Alias:     Alias(so),
   240  	}, nil
   241  }
   242  
   243  func (so openapiResponseObject) MarshalJSON() ([]byte, error) {
   244  	type alias openapiResponseObject
   245  	return extensionMarshalJSON(alias(so), so.extensions)
   246  }
   247  
   248  func (so openapiResponseObject) MarshalYAML() (interface{}, error) {
   249  	type Alias openapiResponseObject
   250  
   251  	return struct {
   252  		Extension map[string]interface{} `yaml:",inline"`
   253  		Alias     `yaml:",inline"`
   254  	}{
   255  		Extension: extensionsToMap(so.extensions),
   256  		Alias:     Alias(so),
   257  	}, nil
   258  }
   259  
   260  func (so openapiSchemaObject) MarshalJSON() ([]byte, error) {
   261  	type alias openapiSchemaObject
   262  	return extensionMarshalJSON(alias(so), so.extensions)
   263  }
   264  
   265  func (so openapiSchemaObject) MarshalYAML() (interface{}, error) {
   266  	type Alias openapiSchemaObject
   267  
   268  	return struct {
   269  		Extension map[string]interface{} `yaml:",inline"`
   270  		Alias     `yaml:",inline"`
   271  	}{
   272  		Extension: extensionsToMap(so.extensions),
   273  		Alias:     Alias(so),
   274  	}, nil
   275  }
   276  
   277  func (so openapiParameterObject) MarshalJSON() ([]byte, error) {
   278  	type alias openapiParameterObject
   279  	return extensionMarshalJSON(alias(so), so.extensions)
   280  }
   281  
   282  func (so openapiParameterObject) MarshalYAML() (interface{}, error) {
   283  	type Alias openapiParameterObject
   284  
   285  	return struct {
   286  		Extension map[string]interface{} `yaml:",inline"`
   287  		Alias     `yaml:",inline"`
   288  	}{
   289  		Extension: extensionsToMap(so.extensions),
   290  		Alias:     Alias(so),
   291  	}, nil
   292  }
   293  
   294  func (so openapiTagObject) MarshalJSON() ([]byte, error) {
   295  	type alias openapiTagObject
   296  	return extensionMarshalJSON(alias(so), so.extensions)
   297  }
   298  
   299  func (so openapiTagObject) MarshalYAML() (interface{}, error) {
   300  	type Alias openapiTagObject
   301  
   302  	return struct {
   303  		Extension map[string]interface{} `yaml:",inline"`
   304  		Alias     `yaml:",inline"`
   305  	}{
   306  		Extension: extensionsToMap(so.extensions),
   307  		Alias:     Alias(so),
   308  	}, nil
   309  }
   310  
   311  func extensionMarshalJSON(so interface{}, extensions []extension) ([]byte, error) {
   312  	// To append arbitrary keys to the struct we'll render into json,
   313  	// we're creating another struct that embeds the original one, and
   314  	// its extra fields:
   315  	//
   316  	// The struct will look like
   317  	// struct {
   318  	//   *openapiCore
   319  	//   XGrpcGatewayFoo json.RawMessage `json:"x-grpc-gateway-foo"`
   320  	//   XGrpcGatewayBar json.RawMessage `json:"x-grpc-gateway-bar"`
   321  	// }
   322  	// and thus render into what we want -- the JSON of openapiCore with the
   323  	// extensions appended.
   324  	fields := []reflect.StructField{
   325  		{ // embedded
   326  			Name:      "Embedded",
   327  			Type:      reflect.TypeOf(so),
   328  			Anonymous: true,
   329  		},
   330  	}
   331  	for _, ext := range extensions {
   332  		fields = append(fields, reflect.StructField{
   333  			Name: fieldName(ext.key),
   334  			Type: reflect.TypeOf(ext.value),
   335  			Tag:  reflect.StructTag(fmt.Sprintf("json:\"%s\"", ext.key)),
   336  		})
   337  	}
   338  
   339  	t := reflect.StructOf(fields)
   340  	s := reflect.New(t).Elem()
   341  	s.Field(0).Set(reflect.ValueOf(so))
   342  	for _, ext := range extensions {
   343  		s.FieldByName(fieldName(ext.key)).Set(reflect.ValueOf(ext.value))
   344  	}
   345  	return json.Marshal(s.Interface())
   346  }
   347  
   348  // encodeOpenAPI converts OpenAPI file obj to pluginpb.CodeGeneratorResponse_File
   349  func encodeOpenAPI(file *wrapper, format Format) (*descriptor.ResponseFile, error) {
   350  	var contentBuf bytes.Buffer
   351  	enc, err := format.NewEncoder(&contentBuf)
   352  	if err != nil {
   353  		return nil, err
   354  	}
   355  
   356  	if err := enc.Encode(*file.swagger); err != nil {
   357  		return nil, err
   358  	}
   359  
   360  	name := file.fileName
   361  	ext := filepath.Ext(name)
   362  	base := strings.TrimSuffix(name, ext)
   363  	output := fmt.Sprintf("%s.swagger."+string(format), base)
   364  	return &descriptor.ResponseFile{
   365  		CodeGeneratorResponse_File: &pluginpb.CodeGeneratorResponse_File{
   366  			Name:    proto.String(output),
   367  			Content: proto.String(contentBuf.String()),
   368  		},
   369  	}, nil
   370  }
   371  
   372  func (g *generator) Generate(targets []*descriptor.File) ([]*descriptor.ResponseFile, error) {
   373  	var files []*descriptor.ResponseFile
   374  	if g.reg.IsAllowMerge() {
   375  		var mergedTarget *descriptor.File
   376  		// try to find proto leader
   377  		for _, f := range targets {
   378  			if proto.HasExtension(f.Options, openapioptions.E_Openapiv2Swagger) {
   379  				mergedTarget = f
   380  				break
   381  			}
   382  		}
   383  		// merge protos to leader
   384  		for _, f := range targets {
   385  			if mergedTarget == nil {
   386  				mergedTarget = f
   387  			} else if mergedTarget != f {
   388  				mergedTarget.Enums = append(mergedTarget.Enums, f.Enums...)
   389  				mergedTarget.Messages = append(mergedTarget.Messages, f.Messages...)
   390  				mergedTarget.Services = append(mergedTarget.Services, f.Services...)
   391  			}
   392  		}
   393  
   394  		targets = nil
   395  		targets = append(targets, mergedTarget)
   396  	}
   397  
   398  	var openapis []*wrapper
   399  	for _, file := range targets {
   400  		if grpclog.V(1) {
   401  			grpclog.Infof("Processing %s", file.GetName())
   402  		}
   403  		swagger, err := applyTemplate(param{File: file, reg: g.reg})
   404  		if errors.Is(err, errNoTargetService) {
   405  			if grpclog.V(1) {
   406  				grpclog.Infof("%s: %v", file.GetName(), err)
   407  			}
   408  			continue
   409  		}
   410  		if err != nil {
   411  			return nil, err
   412  		}
   413  		openapis = append(openapis, &wrapper{
   414  			fileName: file.GetName(),
   415  			swagger:  swagger,
   416  		})
   417  	}
   418  
   419  	if g.reg.IsAllowMerge() {
   420  		targetOpenAPI := mergeTargetFile(openapis, g.reg.GetMergeFileName())
   421  		if !g.reg.IsPreserveRPCOrder() {
   422  			targetOpenAPI.swagger.sortPathsAlphabetically()
   423  		}
   424  		f, err := encodeOpenAPI(targetOpenAPI, g.format)
   425  		if err != nil {
   426  			return nil, fmt.Errorf("failed to encode OpenAPI for %s: %w", g.reg.GetMergeFileName(), err)
   427  		}
   428  		files = append(files, f)
   429  		if grpclog.V(1) {
   430  			grpclog.Infof("New OpenAPI file will emit")
   431  		}
   432  	} else {
   433  		for _, file := range openapis {
   434  			if !g.reg.IsPreserveRPCOrder() {
   435  				file.swagger.sortPathsAlphabetically()
   436  			}
   437  			f, err := encodeOpenAPI(file, g.format)
   438  			if err != nil {
   439  				return nil, fmt.Errorf("failed to encode OpenAPI for %s: %w", file.fileName, err)
   440  			}
   441  			files = append(files, f)
   442  			if grpclog.V(1) {
   443  				grpclog.Infof("New OpenAPI file will emit")
   444  			}
   445  		}
   446  	}
   447  	return files, nil
   448  }
   449  
   450  func (so openapiSwaggerObject) sortPathsAlphabetically() {
   451  	sort.Slice(so.Paths, func(i, j int) bool {
   452  		return so.Paths[i].Path < so.Paths[j].Path
   453  	})
   454  }
   455  
   456  // AddErrorDefs Adds google.rpc.Status and google.protobuf.Any
   457  // to registry (used for error-related API responses)
   458  func AddErrorDefs(reg *descriptor.Registry) error {
   459  	// load internal protos
   460  	any := protodesc.ToFileDescriptorProto((&anypb.Any{}).ProtoReflect().Descriptor().ParentFile())
   461  	any.SourceCodeInfo = new(descriptorpb.SourceCodeInfo)
   462  	status := protodesc.ToFileDescriptorProto((&statuspb.Status{}).ProtoReflect().Descriptor().ParentFile())
   463  	status.SourceCodeInfo = new(descriptorpb.SourceCodeInfo)
   464  	return reg.Load(&pluginpb.CodeGeneratorRequest{
   465  		ProtoFile: []*descriptorpb.FileDescriptorProto{
   466  			any,
   467  			status,
   468  		},
   469  	})
   470  }
   471  
   472  func extensionsToMap(extensions []extension) map[string]interface{} {
   473  	m := make(map[string]interface{}, len(extensions))
   474  
   475  	for _, v := range extensions {
   476  		m[v.key] = RawExample(v.value)
   477  	}
   478  
   479  	return m
   480  }