github.com/cbroglie/openapi2proto@v0.0.0-20171004221549-76b8501da882/proto.go (about)

     1  package openapi2proto
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"log"
     9  	"net/http"
    10  	"os"
    11  	"path"
    12  	"regexp"
    13  	"sort"
    14  	"strconv"
    15  	"strings"
    16  	"text/template"
    17  
    18  	"github.com/pkg/errors"
    19  
    20  	yaml "gopkg.in/yaml.v2"
    21  )
    22  
    23  func getPathItems(p *Path) []*Items {
    24  	var items []*Items
    25  	if p.Get != nil {
    26  		items = append(items, getEndpointItems(p.Get)...)
    27  	}
    28  	if p.Put != nil {
    29  		items = append(items, getEndpointItems(p.Put)...)
    30  	}
    31  	if p.Post != nil {
    32  		items = append(items, getEndpointItems(p.Post)...)
    33  	}
    34  	if p.Delete != nil {
    35  		items = append(items, getEndpointItems(p.Delete)...)
    36  	}
    37  	return items
    38  }
    39  
    40  func getEndpointItems(e *Endpoint) []*Items {
    41  	items := make([]*Items, len(e.Parameters))
    42  	for i, itm := range e.Parameters {
    43  		// add the request params
    44  		items[i] = itm
    45  	}
    46  	// and the response
    47  	var ok bool
    48  	var res *Response
    49  	res, ok = e.Responses["200"]
    50  	if !ok {
    51  		res, ok = e.Responses["201"]
    52  	}
    53  	if !ok {
    54  		return items
    55  	}
    56  	if res.Schema != nil {
    57  		items = append(items, res.Schema)
    58  	}
    59  	return items
    60  }
    61  
    62  func LoadDefinition(pth string) (*APIDefinition, error) {
    63  	var (
    64  		b   []byte
    65  		err error
    66  	)
    67  	// url? fetch it
    68  	if strings.HasPrefix(pth, "http") {
    69  		res, err := http.Get(pth)
    70  		if err != nil {
    71  			log.Printf("unable to fetch path: %s - %s", pth, err)
    72  			os.Exit(1)
    73  		}
    74  		defer res.Body.Close()
    75  
    76  		b, err = ioutil.ReadAll(res.Body)
    77  		if err != nil {
    78  			log.Printf("unable to read from path: %s - %s", pth, err)
    79  			os.Exit(1)
    80  		}
    81  		if res.StatusCode != http.StatusOK {
    82  			log.Print("unable to get remote definition: ", string(b))
    83  			os.Exit(1)
    84  		}
    85  	} else {
    86  		b, err = ioutil.ReadFile(pth)
    87  		if err != nil {
    88  			log.Print("unable to read spec file: ", err)
    89  			os.Exit(1)
    90  		}
    91  
    92  	}
    93  
    94  	var api *APIDefinition
    95  	isYaml := path.Ext(pth) == ".yaml"
    96  	if isYaml {
    97  		err = yaml.Unmarshal(b, &api)
    98  	} else {
    99  		err = json.Unmarshal(b, &api)
   100  	}
   101  	if err != nil {
   102  		return nil, errors.Wrap(err, "unable to parse referened file")
   103  	}
   104  
   105  	// no paths or defs declared?
   106  	// check if this is a plain map[name]*Items (definitions)
   107  	if len(api.Paths) == 0 && len(api.Definitions) == 0 {
   108  		var defs map[string]*Items
   109  		if isYaml {
   110  			err = yaml.Unmarshal(b, &defs)
   111  		} else {
   112  			err = json.Unmarshal(b, &defs)
   113  		}
   114  		_, nok := defs["type"]
   115  		if err == nil && !nok {
   116  			api.Definitions = defs
   117  		}
   118  	}
   119  
   120  	// _still_ no defs? try to see if this is a single item
   121  	// check if its just an *Item
   122  	if len(api.Paths) == 0 && len(api.Definitions) == 0 {
   123  		var item Items
   124  		if isYaml {
   125  			err = yaml.Unmarshal(b, &item)
   126  		} else {
   127  			err = json.Unmarshal(b, &item)
   128  		}
   129  		if err != nil {
   130  			return nil, errors.Wrap(err, "unable to load referenced item")
   131  		}
   132  		api.Definitions = map[string]*Items{strings.TrimSuffix(path.Base(pth), path.Ext(pth)): &item}
   133  	}
   134  
   135  	api.FileName = pth
   136  
   137  	return api, nil
   138  }
   139  
   140  // GenerateProto will attempt to generate an protobuf version 3
   141  // schema from the given OpenAPI definition.
   142  func GenerateProto(api *APIDefinition, annotate bool) ([]byte, error) {
   143  	if api.Definitions == nil {
   144  		api.Definitions = map[string]*Items{}
   145  	}
   146  	// jam all the parameters into the normal 'definitions' for easier reference.
   147  	for name, param := range api.Parameters {
   148  		api.Definitions[name] = param
   149  	}
   150  
   151  	// at this point, traverse imports to find possible nested definition references
   152  	// inline external $refs
   153  	imports, err := importsAndRefs(api)
   154  	if err != nil {
   155  		log.Fatal(err)
   156  	}
   157  
   158  	// if no package name given, default to filename
   159  	if api.Info.Title == "" {
   160  		api.Info.Title = strings.TrimSuffix(path.Base(api.FileName),
   161  			path.Ext(api.FileName))
   162  	}
   163  
   164  	var out bytes.Buffer
   165  	data := struct {
   166  		*APIDefinition
   167  		Annotate bool
   168  		Imports  []string
   169  	}{
   170  		api, annotate, imports,
   171  	}
   172  	err = protoFileTmpl.Execute(&out, data)
   173  	if err != nil {
   174  		return nil, fmt.Errorf("unable to generate protobuf schema: %s", err)
   175  	}
   176  	return cleanSpacing(addImports(out.Bytes())), nil
   177  }
   178  
   179  func importsAndRefs(api *APIDefinition) ([]string, error) {
   180  	var imports []string
   181  	// determine external imports by traversing struct, looking for $refs
   182  	for _, def := range api.Definitions {
   183  		defs, err := replaceExternalRefs(def)
   184  		if err != nil {
   185  			return imports, errors.Wrap(err, "unable to replace external refs in definitions")
   186  		}
   187  		for k, v := range defs {
   188  			api.Definitions[k] = v
   189  		}
   190  		imports = append(imports, traverseItemsForImports(def, api.Definitions)...)
   191  	}
   192  
   193  	for _, pth := range api.Paths {
   194  		for _, itm := range getPathItems(pth) {
   195  			defs, err := replaceExternalRefs(itm)
   196  			if err != nil {
   197  				return imports, errors.Wrap(err, "unable to replace external refs in path")
   198  			}
   199  			for k, v := range defs {
   200  				api.Definitions[k] = v
   201  			}
   202  			imports = append(imports, traverseItemsForImports(itm, api.Definitions)...)
   203  		}
   204  	}
   205  	sort.Strings(imports)
   206  	var impts []string
   207  	// dedupe
   208  	var last string
   209  	for _, i := range imports {
   210  		if i != last {
   211  			impts = append(impts, i)
   212  		}
   213  		last = i
   214  	}
   215  	return imports, nil
   216  }
   217  
   218  func replaceExternalRefs(item *Items) (map[string]*Items, error) {
   219  	defs := map[string]*Items{}
   220  	if item.Ref != "" {
   221  		possSpecPath, name := refDatas(item.Ref)
   222  		// if it's an OpenAPI spec, try reading it in
   223  		if name == "" { // path#/type
   224  			name = strings.TrimSuffix(name, path.Ext(name))
   225  		}
   226  		if possSpecPath != "" && (path.Ext(possSpecPath) != ".proto") {
   227  			def, err := LoadDefinition(possSpecPath)
   228  			if err == nil {
   229  				if len(def.Definitions) > 0 {
   230  					for nam, v := range def.Definitions {
   231  						if name == nam {
   232  							*item = *v
   233  						}
   234  						if v.Type == "object" {
   235  							defs[nam] = v
   236  						}
   237  					}
   238  				}
   239  			}
   240  		}
   241  	}
   242  	if item.Schema != nil && item.Schema.Ref != "" {
   243  		possSpecPath, name := refDatas(item.Schema.Ref)
   244  		// if it's an OpenAPI spec, try reading it in
   245  		if name == "" { // path#/type
   246  			name = strings.Title(strings.TrimSuffix(item.Schema.Ref, path.Ext(item.Schema.Ref)))
   247  		}
   248  		if possSpecPath != "" && (path.Ext(possSpecPath) != ".proto") {
   249  			def, err := LoadDefinition(possSpecPath)
   250  			if err == nil {
   251  				item.Schema.Ref = "#/definitions/" + name
   252  				for k, v := range def.Definitions {
   253  					defs[k] = v
   254  				}
   255  			}
   256  		}
   257  	}
   258  	for _, itm := range item.Model.Properties {
   259  		ds, err := replaceExternalRefs(itm)
   260  		if err != nil {
   261  			return nil, errors.Wrap(err, "unable to replace external spec refs")
   262  		}
   263  		for k, v := range ds {
   264  			defs[k] = v
   265  		}
   266  	}
   267  	if item.Items != nil {
   268  		ds, err := replaceExternalRefs(item.Items)
   269  		if err != nil {
   270  			return nil, errors.Wrap(err, "unable to replace external spec refs")
   271  		}
   272  		for k, v := range ds {
   273  			defs[k] = v
   274  		}
   275  	}
   276  	if item.AdditionalProperties != nil {
   277  		ds, err := replaceExternalRefs(item.AdditionalProperties)
   278  		if err != nil {
   279  			return nil, errors.Wrap(err, "unable to replace external spec refs")
   280  		}
   281  		for k, v := range ds {
   282  			defs[k] = v
   283  		}
   284  	}
   285  	return defs, nil
   286  }
   287  
   288  func traverseItemsForImports(item *Items, defs map[string]*Items) []string {
   289  	imports := map[string]struct{}{}
   290  	if item.Ref != "" {
   291  		_, pkg := refType(item.Ref, defs)
   292  		impt, _ := refDatas(item.Ref)
   293  		pext := path.Ext(impt)
   294  		if (pkg != "" && (path.Ext(item.Ref) == "")) || pext == ".proto" {
   295  			imports[pkg] = struct{}{}
   296  		}
   297  	}
   298  	for _, itm := range item.Model.Properties {
   299  		for _, impt := range traverseItemsForImports(itm, defs) {
   300  			imports[impt] = struct{}{}
   301  		}
   302  	}
   303  	if item.Items != nil {
   304  		for _, impt := range traverseItemsForImports(item.Items, defs) {
   305  			imports[impt] = struct{}{}
   306  		}
   307  	}
   308  	if item.AdditionalProperties != nil {
   309  		for _, impt := range traverseItemsForImports(item.AdditionalProperties, defs) {
   310  			imports[impt] = struct{}{}
   311  		}
   312  	}
   313  	var out []string
   314  	for impt, _ := range imports {
   315  		out = append(out, impt)
   316  	}
   317  	return out
   318  }
   319  
   320  const protoFileTmplStr = `syntax = "proto3";
   321  {{ $defs := .Definitions }}{{ $annotate := .Annotate }}{{ if $annotate }}
   322  import "google/api/annotations.proto";
   323  {{ end }}{{ range $import := .Imports }}
   324  import "{{ $import }}";
   325  {{ end }}
   326  package {{ packageName .Info.Title }};
   327  {{ range $path, $endpoint := .Paths }}
   328  {{ $endpoint.ProtoMessages $path $defs }}
   329  {{ end }}
   330  {{ range $modelName, $model := $defs }}
   331  {{ $model.ProtoMessage "" $modelName $defs counter -1 }}
   332  {{ end }}{{ $basePath := .BasePath }}
   333  {{ if len .Paths }}service {{ serviceName .Info.Title }} {{"{"}}{{ range $path, $endpoint := .Paths }}
   334  {{ $endpoint.ProtoEndpoints $annotate $basePath $path }}{{ end }}
   335  }{{ end }}
   336  `
   337  
   338  const protoEndpointTmplStr = `{{ if .HasComment }}{{ .Comment }}{{ end }}    rpc {{ .Name }}({{ .RequestName }}) returns ({{ .ResponseName }}) {{"{"}}{{ if .Annotate }}
   339        option (google.api.http) = {
   340          {{ .Method }}: "{{ .Path }}"{{ if .IncludeBody }}
   341          body: "{{ .BodyAttr }}"{{ end }}
   342        };
   343      {{ end }}{{"}"}}`
   344  
   345  const protoMsgTmplStr = `{{ $i := counter }}{{ $defs := .Defs }}{{ $msgName := .Name }}{{ $depth := .Depth }}message {{ .Name }} {{"{"}}{{ range $propName, $prop := .Properties }}
   346  {{ indent $depth }}{{ if $prop.HasComment }}{{ indent $depth }}{{ $prop.Comment }}{{ end }}    {{ $prop.ProtoMessage $msgName $propName $defs $i $depth }};{{ end }}
   347  {{ indent $depth }}}`
   348  
   349  const protoEnumTmplStr = `{{ $i := zcounter }}{{ $depth := .Depth }}{{ $name := .Name }}enum {{ .Name }} {{"{"}}{{ range $index, $pName := .Enum }}
   350  {{ indent $depth }}    {{ toEnum $name $pName $depth }} = {{ inc $i }};{{ end }}
   351  {{ indent $depth }}}`
   352  
   353  var funcMap = template.FuncMap{
   354  	"inc":              inc,
   355  	"counter":          counter,
   356  	"zcounter":         zcounter,
   357  	"indent":           indent,
   358  	"toEnum":           toEnum,
   359  	"packageName":      packageName,
   360  	"serviceName":      serviceName,
   361  	"PathMethodToName": PathMethodToName,
   362  }
   363  
   364  func packageName(t string) string {
   365  	return strings.ToLower(strings.Join(strings.Fields(t), ""))
   366  }
   367  
   368  func serviceName(t string) string {
   369  	var name string
   370  	for _, nme := range strings.Fields(t) {
   371  		name += strings.Title(nme)
   372  	}
   373  	return name + "Service"
   374  }
   375  
   376  func counter() *int {
   377  	i := 0
   378  	return &i
   379  }
   380  func zcounter() *int {
   381  	i := -1
   382  	return &i
   383  }
   384  
   385  func inc(i *int) int {
   386  	*i++
   387  	return *i
   388  }
   389  
   390  func indent(depth int) string {
   391  	var out string
   392  	for i := 0; i < depth; i++ {
   393  		out += "    "
   394  	}
   395  	return out
   396  }
   397  
   398  func toEnum(name, enum string, depth int) string {
   399  	if strings.TrimSpace(enum) == "" {
   400  		enum = "empty"
   401  	}
   402  	e := enum
   403  	if _, err := strconv.Atoi(enum); err == nil || depth > 0 {
   404  		e = name + "_" + enum
   405  	}
   406  	e = strings.Replace(e, " & ", " AND ", -1)
   407  	e = strings.Replace(e, "&", "_AND_", -1)
   408  	e = strings.Replace(e, " ", "_", -1)
   409  	re := regexp.MustCompile(`[%\{\}\[\]()/\.'’-]`)
   410  	e = re.ReplaceAllString(e, "")
   411  	return strings.ToUpper(e)
   412  }
   413  
   414  var (
   415  	protoFileTmpl     = template.Must(template.New("protoFile").Funcs(funcMap).Parse(protoFileTmplStr))
   416  	protoMsgTmpl      = template.Must(template.New("protoMsg").Funcs(funcMap).Parse(protoMsgTmplStr))
   417  	protoEndpointTmpl = template.Must(template.New("protoEndpoint").Funcs(funcMap).Parse(protoEndpointTmplStr))
   418  	protoEnumTmpl     = template.Must(template.New("protoEnum").Funcs(funcMap).Parse(protoEnumTmplStr))
   419  )
   420  
   421  func cleanSpacing(output []byte) []byte {
   422  	re := regexp.MustCompile(`}\n*message `)
   423  	output = re.ReplaceAll(output, []byte("}\n\nmessage "))
   424  	re = regexp.MustCompile(`}\n*enum `)
   425  	output = re.ReplaceAll(output, []byte("}\n\nenum "))
   426  	re = regexp.MustCompile(`;\n*message `)
   427  	output = re.ReplaceAll(output, []byte(";\n\nmessage "))
   428  	re = regexp.MustCompile(`}\n*service `)
   429  	return re.ReplaceAll(output, []byte("}\n\nservice "))
   430  }
   431  
   432  func addImports(output []byte) []byte {
   433  	if bytes.Contains(output, []byte("google.protobuf.Any")) {
   434  		output = bytes.Replace(output, []byte(`"proto3";`), []byte(`"proto3";
   435  
   436  import "google/protobuf/any.proto";`), 1)
   437  	}
   438  
   439  	if bytes.Contains(output, []byte("google.protobuf.Empty")) {
   440  		output = bytes.Replace(output, []byte(`"proto3";`), []byte(`"proto3";
   441  
   442  import "google/protobuf/empty.proto";`), 1)
   443  	}
   444  
   445  	if bytes.Contains(output, []byte("google.protobuf.NullValue")) {
   446  		output = bytes.Replace(output, []byte(`"proto3";`), []byte(`"proto3";
   447  
   448  import "google/protobuf/struct.proto";`), 1)
   449  	}
   450  
   451  	match, err := regexp.Match("google.protobuf.(String|Bytes|Int.*|UInt.*|Float|Double)Value", output)
   452  	if err != nil {
   453  		log.Fatal("unable to find wrapper values: ", err)
   454  	}
   455  	if match {
   456  		output = bytes.Replace(output, []byte(`"proto3";`), []byte(`"proto3";
   457  
   458  import "google/protobuf/wrappers.proto";`), 1)
   459  	}
   460  
   461  	return output
   462  }