github.com/cbroglie/openapi2proto@v0.0.0-20171004221549-76b8501da882/proto.go (about) 1 package openapi2proto 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "io/ioutil" 8 "log" 9 "net/http" 10 "os" 11 "path" 12 "regexp" 13 "sort" 14 "strconv" 15 "strings" 16 "text/template" 17 18 "github.com/pkg/errors" 19 20 yaml "gopkg.in/yaml.v2" 21 ) 22 23 func getPathItems(p *Path) []*Items { 24 var items []*Items 25 if p.Get != nil { 26 items = append(items, getEndpointItems(p.Get)...) 27 } 28 if p.Put != nil { 29 items = append(items, getEndpointItems(p.Put)...) 30 } 31 if p.Post != nil { 32 items = append(items, getEndpointItems(p.Post)...) 33 } 34 if p.Delete != nil { 35 items = append(items, getEndpointItems(p.Delete)...) 36 } 37 return items 38 } 39 40 func getEndpointItems(e *Endpoint) []*Items { 41 items := make([]*Items, len(e.Parameters)) 42 for i, itm := range e.Parameters { 43 // add the request params 44 items[i] = itm 45 } 46 // and the response 47 var ok bool 48 var res *Response 49 res, ok = e.Responses["200"] 50 if !ok { 51 res, ok = e.Responses["201"] 52 } 53 if !ok { 54 return items 55 } 56 if res.Schema != nil { 57 items = append(items, res.Schema) 58 } 59 return items 60 } 61 62 func LoadDefinition(pth string) (*APIDefinition, error) { 63 var ( 64 b []byte 65 err error 66 ) 67 // url? fetch it 68 if strings.HasPrefix(pth, "http") { 69 res, err := http.Get(pth) 70 if err != nil { 71 log.Printf("unable to fetch path: %s - %s", pth, err) 72 os.Exit(1) 73 } 74 defer res.Body.Close() 75 76 b, err = ioutil.ReadAll(res.Body) 77 if err != nil { 78 log.Printf("unable to read from path: %s - %s", pth, err) 79 os.Exit(1) 80 } 81 if res.StatusCode != http.StatusOK { 82 log.Print("unable to get remote definition: ", string(b)) 83 os.Exit(1) 84 } 85 } else { 86 b, err = ioutil.ReadFile(pth) 87 if err != nil { 88 log.Print("unable to read spec file: ", err) 89 os.Exit(1) 90 } 91 92 } 93 94 var api *APIDefinition 95 isYaml := path.Ext(pth) == ".yaml" 96 if isYaml { 97 err = yaml.Unmarshal(b, &api) 98 } else { 99 err = json.Unmarshal(b, &api) 100 } 101 if err != nil { 102 return nil, errors.Wrap(err, "unable to parse referened file") 103 } 104 105 // no paths or defs declared? 106 // check if this is a plain map[name]*Items (definitions) 107 if len(api.Paths) == 0 && len(api.Definitions) == 0 { 108 var defs map[string]*Items 109 if isYaml { 110 err = yaml.Unmarshal(b, &defs) 111 } else { 112 err = json.Unmarshal(b, &defs) 113 } 114 _, nok := defs["type"] 115 if err == nil && !nok { 116 api.Definitions = defs 117 } 118 } 119 120 // _still_ no defs? try to see if this is a single item 121 // check if its just an *Item 122 if len(api.Paths) == 0 && len(api.Definitions) == 0 { 123 var item Items 124 if isYaml { 125 err = yaml.Unmarshal(b, &item) 126 } else { 127 err = json.Unmarshal(b, &item) 128 } 129 if err != nil { 130 return nil, errors.Wrap(err, "unable to load referenced item") 131 } 132 api.Definitions = map[string]*Items{strings.TrimSuffix(path.Base(pth), path.Ext(pth)): &item} 133 } 134 135 api.FileName = pth 136 137 return api, nil 138 } 139 140 // GenerateProto will attempt to generate an protobuf version 3 141 // schema from the given OpenAPI definition. 142 func GenerateProto(api *APIDefinition, annotate bool) ([]byte, error) { 143 if api.Definitions == nil { 144 api.Definitions = map[string]*Items{} 145 } 146 // jam all the parameters into the normal 'definitions' for easier reference. 147 for name, param := range api.Parameters { 148 api.Definitions[name] = param 149 } 150 151 // at this point, traverse imports to find possible nested definition references 152 // inline external $refs 153 imports, err := importsAndRefs(api) 154 if err != nil { 155 log.Fatal(err) 156 } 157 158 // if no package name given, default to filename 159 if api.Info.Title == "" { 160 api.Info.Title = strings.TrimSuffix(path.Base(api.FileName), 161 path.Ext(api.FileName)) 162 } 163 164 var out bytes.Buffer 165 data := struct { 166 *APIDefinition 167 Annotate bool 168 Imports []string 169 }{ 170 api, annotate, imports, 171 } 172 err = protoFileTmpl.Execute(&out, data) 173 if err != nil { 174 return nil, fmt.Errorf("unable to generate protobuf schema: %s", err) 175 } 176 return cleanSpacing(addImports(out.Bytes())), nil 177 } 178 179 func importsAndRefs(api *APIDefinition) ([]string, error) { 180 var imports []string 181 // determine external imports by traversing struct, looking for $refs 182 for _, def := range api.Definitions { 183 defs, err := replaceExternalRefs(def) 184 if err != nil { 185 return imports, errors.Wrap(err, "unable to replace external refs in definitions") 186 } 187 for k, v := range defs { 188 api.Definitions[k] = v 189 } 190 imports = append(imports, traverseItemsForImports(def, api.Definitions)...) 191 } 192 193 for _, pth := range api.Paths { 194 for _, itm := range getPathItems(pth) { 195 defs, err := replaceExternalRefs(itm) 196 if err != nil { 197 return imports, errors.Wrap(err, "unable to replace external refs in path") 198 } 199 for k, v := range defs { 200 api.Definitions[k] = v 201 } 202 imports = append(imports, traverseItemsForImports(itm, api.Definitions)...) 203 } 204 } 205 sort.Strings(imports) 206 var impts []string 207 // dedupe 208 var last string 209 for _, i := range imports { 210 if i != last { 211 impts = append(impts, i) 212 } 213 last = i 214 } 215 return imports, nil 216 } 217 218 func replaceExternalRefs(item *Items) (map[string]*Items, error) { 219 defs := map[string]*Items{} 220 if item.Ref != "" { 221 possSpecPath, name := refDatas(item.Ref) 222 // if it's an OpenAPI spec, try reading it in 223 if name == "" { // path#/type 224 name = strings.TrimSuffix(name, path.Ext(name)) 225 } 226 if possSpecPath != "" && (path.Ext(possSpecPath) != ".proto") { 227 def, err := LoadDefinition(possSpecPath) 228 if err == nil { 229 if len(def.Definitions) > 0 { 230 for nam, v := range def.Definitions { 231 if name == nam { 232 *item = *v 233 } 234 if v.Type == "object" { 235 defs[nam] = v 236 } 237 } 238 } 239 } 240 } 241 } 242 if item.Schema != nil && item.Schema.Ref != "" { 243 possSpecPath, name := refDatas(item.Schema.Ref) 244 // if it's an OpenAPI spec, try reading it in 245 if name == "" { // path#/type 246 name = strings.Title(strings.TrimSuffix(item.Schema.Ref, path.Ext(item.Schema.Ref))) 247 } 248 if possSpecPath != "" && (path.Ext(possSpecPath) != ".proto") { 249 def, err := LoadDefinition(possSpecPath) 250 if err == nil { 251 item.Schema.Ref = "#/definitions/" + name 252 for k, v := range def.Definitions { 253 defs[k] = v 254 } 255 } 256 } 257 } 258 for _, itm := range item.Model.Properties { 259 ds, err := replaceExternalRefs(itm) 260 if err != nil { 261 return nil, errors.Wrap(err, "unable to replace external spec refs") 262 } 263 for k, v := range ds { 264 defs[k] = v 265 } 266 } 267 if item.Items != nil { 268 ds, err := replaceExternalRefs(item.Items) 269 if err != nil { 270 return nil, errors.Wrap(err, "unable to replace external spec refs") 271 } 272 for k, v := range ds { 273 defs[k] = v 274 } 275 } 276 if item.AdditionalProperties != nil { 277 ds, err := replaceExternalRefs(item.AdditionalProperties) 278 if err != nil { 279 return nil, errors.Wrap(err, "unable to replace external spec refs") 280 } 281 for k, v := range ds { 282 defs[k] = v 283 } 284 } 285 return defs, nil 286 } 287 288 func traverseItemsForImports(item *Items, defs map[string]*Items) []string { 289 imports := map[string]struct{}{} 290 if item.Ref != "" { 291 _, pkg := refType(item.Ref, defs) 292 impt, _ := refDatas(item.Ref) 293 pext := path.Ext(impt) 294 if (pkg != "" && (path.Ext(item.Ref) == "")) || pext == ".proto" { 295 imports[pkg] = struct{}{} 296 } 297 } 298 for _, itm := range item.Model.Properties { 299 for _, impt := range traverseItemsForImports(itm, defs) { 300 imports[impt] = struct{}{} 301 } 302 } 303 if item.Items != nil { 304 for _, impt := range traverseItemsForImports(item.Items, defs) { 305 imports[impt] = struct{}{} 306 } 307 } 308 if item.AdditionalProperties != nil { 309 for _, impt := range traverseItemsForImports(item.AdditionalProperties, defs) { 310 imports[impt] = struct{}{} 311 } 312 } 313 var out []string 314 for impt, _ := range imports { 315 out = append(out, impt) 316 } 317 return out 318 } 319 320 const protoFileTmplStr = `syntax = "proto3"; 321 {{ $defs := .Definitions }}{{ $annotate := .Annotate }}{{ if $annotate }} 322 import "google/api/annotations.proto"; 323 {{ end }}{{ range $import := .Imports }} 324 import "{{ $import }}"; 325 {{ end }} 326 package {{ packageName .Info.Title }}; 327 {{ range $path, $endpoint := .Paths }} 328 {{ $endpoint.ProtoMessages $path $defs }} 329 {{ end }} 330 {{ range $modelName, $model := $defs }} 331 {{ $model.ProtoMessage "" $modelName $defs counter -1 }} 332 {{ end }}{{ $basePath := .BasePath }} 333 {{ if len .Paths }}service {{ serviceName .Info.Title }} {{"{"}}{{ range $path, $endpoint := .Paths }} 334 {{ $endpoint.ProtoEndpoints $annotate $basePath $path }}{{ end }} 335 }{{ end }} 336 ` 337 338 const protoEndpointTmplStr = `{{ if .HasComment }}{{ .Comment }}{{ end }} rpc {{ .Name }}({{ .RequestName }}) returns ({{ .ResponseName }}) {{"{"}}{{ if .Annotate }} 339 option (google.api.http) = { 340 {{ .Method }}: "{{ .Path }}"{{ if .IncludeBody }} 341 body: "{{ .BodyAttr }}"{{ end }} 342 }; 343 {{ end }}{{"}"}}` 344 345 const protoMsgTmplStr = `{{ $i := counter }}{{ $defs := .Defs }}{{ $msgName := .Name }}{{ $depth := .Depth }}message {{ .Name }} {{"{"}}{{ range $propName, $prop := .Properties }} 346 {{ indent $depth }}{{ if $prop.HasComment }}{{ indent $depth }}{{ $prop.Comment }}{{ end }} {{ $prop.ProtoMessage $msgName $propName $defs $i $depth }};{{ end }} 347 {{ indent $depth }}}` 348 349 const protoEnumTmplStr = `{{ $i := zcounter }}{{ $depth := .Depth }}{{ $name := .Name }}enum {{ .Name }} {{"{"}}{{ range $index, $pName := .Enum }} 350 {{ indent $depth }} {{ toEnum $name $pName $depth }} = {{ inc $i }};{{ end }} 351 {{ indent $depth }}}` 352 353 var funcMap = template.FuncMap{ 354 "inc": inc, 355 "counter": counter, 356 "zcounter": zcounter, 357 "indent": indent, 358 "toEnum": toEnum, 359 "packageName": packageName, 360 "serviceName": serviceName, 361 "PathMethodToName": PathMethodToName, 362 } 363 364 func packageName(t string) string { 365 return strings.ToLower(strings.Join(strings.Fields(t), "")) 366 } 367 368 func serviceName(t string) string { 369 var name string 370 for _, nme := range strings.Fields(t) { 371 name += strings.Title(nme) 372 } 373 return name + "Service" 374 } 375 376 func counter() *int { 377 i := 0 378 return &i 379 } 380 func zcounter() *int { 381 i := -1 382 return &i 383 } 384 385 func inc(i *int) int { 386 *i++ 387 return *i 388 } 389 390 func indent(depth int) string { 391 var out string 392 for i := 0; i < depth; i++ { 393 out += " " 394 } 395 return out 396 } 397 398 func toEnum(name, enum string, depth int) string { 399 if strings.TrimSpace(enum) == "" { 400 enum = "empty" 401 } 402 e := enum 403 if _, err := strconv.Atoi(enum); err == nil || depth > 0 { 404 e = name + "_" + enum 405 } 406 e = strings.Replace(e, " & ", " AND ", -1) 407 e = strings.Replace(e, "&", "_AND_", -1) 408 e = strings.Replace(e, " ", "_", -1) 409 re := regexp.MustCompile(`[%\{\}\[\]()/\.'’-]`) 410 e = re.ReplaceAllString(e, "") 411 return strings.ToUpper(e) 412 } 413 414 var ( 415 protoFileTmpl = template.Must(template.New("protoFile").Funcs(funcMap).Parse(protoFileTmplStr)) 416 protoMsgTmpl = template.Must(template.New("protoMsg").Funcs(funcMap).Parse(protoMsgTmplStr)) 417 protoEndpointTmpl = template.Must(template.New("protoEndpoint").Funcs(funcMap).Parse(protoEndpointTmplStr)) 418 protoEnumTmpl = template.Must(template.New("protoEnum").Funcs(funcMap).Parse(protoEnumTmplStr)) 419 ) 420 421 func cleanSpacing(output []byte) []byte { 422 re := regexp.MustCompile(`}\n*message `) 423 output = re.ReplaceAll(output, []byte("}\n\nmessage ")) 424 re = regexp.MustCompile(`}\n*enum `) 425 output = re.ReplaceAll(output, []byte("}\n\nenum ")) 426 re = regexp.MustCompile(`;\n*message `) 427 output = re.ReplaceAll(output, []byte(";\n\nmessage ")) 428 re = regexp.MustCompile(`}\n*service `) 429 return re.ReplaceAll(output, []byte("}\n\nservice ")) 430 } 431 432 func addImports(output []byte) []byte { 433 if bytes.Contains(output, []byte("google.protobuf.Any")) { 434 output = bytes.Replace(output, []byte(`"proto3";`), []byte(`"proto3"; 435 436 import "google/protobuf/any.proto";`), 1) 437 } 438 439 if bytes.Contains(output, []byte("google.protobuf.Empty")) { 440 output = bytes.Replace(output, []byte(`"proto3";`), []byte(`"proto3"; 441 442 import "google/protobuf/empty.proto";`), 1) 443 } 444 445 if bytes.Contains(output, []byte("google.protobuf.NullValue")) { 446 output = bytes.Replace(output, []byte(`"proto3";`), []byte(`"proto3"; 447 448 import "google/protobuf/struct.proto";`), 1) 449 } 450 451 match, err := regexp.Match("google.protobuf.(String|Bytes|Int.*|UInt.*|Float|Double)Value", output) 452 if err != nil { 453 log.Fatal("unable to find wrapper values: ", err) 454 } 455 if match { 456 output = bytes.Replace(output, []byte(`"proto3";`), []byte(`"proto3"; 457 458 import "google/protobuf/wrappers.proto";`), 1) 459 } 460 461 return output 462 }