github.com/kaleido-io/firefly@v0.0.0-20210622132723-8b4b6aacb971/internal/oapispec/openapi3.go (about)

     1  // Copyright © 2021 Kaleido, Inc.
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package oapispec
    18  
    19  import (
    20  	"context"
    21  	"encoding/json"
    22  	"fmt"
    23  	"log"
    24  	"net/http"
    25  	"reflect"
    26  	"strconv"
    27  	"strings"
    28  
    29  	"github.com/getkin/kin-openapi/openapi3"
    30  	"github.com/getkin/kin-openapi/openapi3gen"
    31  	"github.com/kaleido-io/firefly/internal/config"
    32  	"github.com/kaleido-io/firefly/internal/i18n"
    33  )
    34  
    35  func getHost() string {
    36  	proto := "https"
    37  	if !config.GetBool(config.HTTPTLSEnabled) {
    38  		proto = "http"
    39  	}
    40  	return fmt.Sprintf("%s://%s:%s", proto, config.GetString(config.HTTPAddress), config.GetString(config.HTTPPort))
    41  }
    42  
    43  func SwaggerGen(ctx context.Context, routes []*Route) *openapi3.T {
    44  
    45  	doc := &openapi3.T{
    46  		OpenAPI: "3.0.2",
    47  		Servers: openapi3.Servers{
    48  			{URL: fmt.Sprintf("%s/api/v1", getHost())},
    49  		},
    50  		Info: &openapi3.Info{
    51  			Title:       "Firefly",
    52  			Version:     "1.0",
    53  			Description: "Copyright © 2021 Kaleido, Inc.",
    54  		},
    55  	}
    56  	opIds := make(map[string]bool)
    57  	for _, route := range routes {
    58  		if route.Name == "" || opIds[route.Name] {
    59  			log.Panicf("Duplicate/invalid name (used as operation ID in swagger): %s", route.Name)
    60  		}
    61  		addRoute(ctx, doc, route)
    62  		opIds[route.Name] = true
    63  	}
    64  	return doc
    65  }
    66  
    67  func getPathItem(doc *openapi3.T, path string) *openapi3.PathItem {
    68  	if !strings.HasPrefix(path, "/") {
    69  		path = "/" + path
    70  	}
    71  	if doc.Paths == nil {
    72  		doc.Paths = openapi3.Paths{}
    73  	}
    74  	pi, ok := doc.Paths[path]
    75  	if ok {
    76  		return pi
    77  	}
    78  	pi = &openapi3.PathItem{}
    79  	doc.Paths[path] = pi
    80  	return pi
    81  }
    82  
    83  func initInput(op *openapi3.Operation) {
    84  	op.RequestBody = &openapi3.RequestBodyRef{
    85  		Value: &openapi3.RequestBody{
    86  			Content: openapi3.Content{},
    87  		},
    88  	}
    89  }
    90  
    91  func addInput(input interface{}, mask []string, schemaDef string, op *openapi3.Operation) {
    92  	var schemaRef *openapi3.SchemaRef
    93  	if schemaDef != "" {
    94  		err := json.Unmarshal([]byte(schemaDef), &schemaRef)
    95  		if err != nil {
    96  			panic(fmt.Sprintf("invalid schema for %T: %s", input, err))
    97  		}
    98  	}
    99  	if schemaRef == nil {
   100  		schemaRef, _, _ = openapi3gen.NewSchemaRefForValue(maskFields(input, mask))
   101  	}
   102  	op.RequestBody.Value.Content["application/json"] = &openapi3.MediaType{
   103  		Schema: schemaRef,
   104  	}
   105  }
   106  
   107  func addFormInput(op *openapi3.Operation) {
   108  	op.RequestBody.Value.Content["multipart/form-data"] = &openapi3.MediaType{
   109  		Schema: &openapi3.SchemaRef{
   110  			Value: &openapi3.Schema{
   111  				Type: "object",
   112  				Properties: openapi3.Schemas{
   113  					"filename.ext": &openapi3.SchemaRef{
   114  						Value: &openapi3.Schema{
   115  							Type:   "string",
   116  							Format: "binary",
   117  						},
   118  					},
   119  				},
   120  			},
   121  		},
   122  	}
   123  }
   124  
   125  func addOutput(ctx context.Context, route *Route, output interface{}, op *openapi3.Operation) {
   126  	schemaRef, _, _ := openapi3gen.NewSchemaRefForValue(output)
   127  	s := i18n.Expand(ctx, i18n.MsgSuccessResponse)
   128  	op.Responses[strconv.FormatInt(int64(route.JSONOutputCode), 10)] = &openapi3.ResponseRef{
   129  		Value: &openapi3.Response{
   130  			Description: &s,
   131  			Content: openapi3.Content{
   132  				"application/json": &openapi3.MediaType{
   133  					Schema: schemaRef,
   134  				},
   135  			},
   136  		},
   137  	}
   138  }
   139  
   140  func addParam(ctx context.Context, op *openapi3.Operation, in, name, def, example string, description i18n.MessageKey, msgArgs ...interface{}) {
   141  	required := false
   142  	if in == "path" {
   143  		required = true
   144  	}
   145  	var defValue interface{}
   146  	if def != "" {
   147  		defValue = &def
   148  	}
   149  	var exampleValue interface{}
   150  	if example != "" {
   151  		exampleValue = example
   152  	}
   153  	op.Parameters = append(op.Parameters, &openapi3.ParameterRef{
   154  		Value: &openapi3.Parameter{
   155  			In:          in,
   156  			Name:        name,
   157  			Required:    required,
   158  			Description: i18n.Expand(ctx, description, msgArgs...),
   159  			Schema: &openapi3.SchemaRef{
   160  				Value: &openapi3.Schema{
   161  					Type:    "string",
   162  					Default: defValue,
   163  					Example: exampleValue,
   164  				},
   165  			},
   166  		},
   167  	})
   168  }
   169  
   170  func addRoute(ctx context.Context, doc *openapi3.T, route *Route) {
   171  	pi := getPathItem(doc, route.Path)
   172  	op := &openapi3.Operation{
   173  		Description: i18n.Expand(ctx, route.Description),
   174  		OperationID: route.Name,
   175  		Responses:   openapi3.NewResponses(),
   176  	}
   177  	if route.Method != http.MethodGet && route.Method != http.MethodDelete {
   178  		var input interface{}
   179  		if route.JSONInputValue != nil {
   180  			input = route.JSONInputValue()
   181  		}
   182  		initInput(op)
   183  		if input != nil {
   184  			addInput(input, route.JSONInputMask, route.JSONInputSchema, op)
   185  		}
   186  		if route.FormUploadHandler != nil {
   187  			addFormInput(op)
   188  		}
   189  	}
   190  	var output interface{}
   191  	if route.JSONOutputValue != nil {
   192  		output = route.JSONOutputValue()
   193  	}
   194  	if output != nil {
   195  		addOutput(ctx, route, output, op)
   196  	}
   197  	for _, p := range route.PathParams {
   198  		example := p.Example
   199  		if p.ExampleFromConf != "" {
   200  			example = config.GetString(p.ExampleFromConf)
   201  		}
   202  		addParam(ctx, op, "path", p.Name, p.Default, example, p.Description)
   203  	}
   204  	for _, q := range route.QueryParams {
   205  		example := q.Example
   206  		if q.ExampleFromConf != "" {
   207  			example = config.GetString(q.ExampleFromConf)
   208  		}
   209  		addParam(ctx, op, "query", q.Name, q.Default, example, q.Description)
   210  	}
   211  	if route.FilterFactory != nil {
   212  		for _, field := range route.FilterFactory.NewFilter(ctx).Fields() {
   213  			addParam(ctx, op, "query", field, "", "", i18n.MsgFilterParamDesc)
   214  		}
   215  		addParam(ctx, op, "query", "sort", "", "", i18n.MsgFilterSortDesc)
   216  		addParam(ctx, op, "query", "descending", "", "", i18n.MsgFilterDescendingDesc)
   217  		addParam(ctx, op, "query", "skip", "", "", i18n.MsgFilterSkipDesc, config.GetUint(config.APIMaxFilterSkip))
   218  		addParam(ctx, op, "query", "limit", "", config.GetString(config.APIDefaultFilterLimit), i18n.MsgFilterLimitDesc, config.GetUint(config.APIMaxFilterLimit))
   219  	}
   220  	switch route.Method {
   221  	case http.MethodGet:
   222  		pi.Get = op
   223  	case http.MethodPut:
   224  		pi.Put = op
   225  	case http.MethodPost:
   226  		pi.Post = op
   227  	case http.MethodDelete:
   228  		pi.Delete = op
   229  	}
   230  }
   231  
   232  func maskFieldsOnStruct(t reflect.Type, mask []string) reflect.Type {
   233  	fieldCount := t.NumField()
   234  	newFields := make([]reflect.StructField, fieldCount)
   235  	for i := 0; i < fieldCount; i++ {
   236  		field := t.FieldByIndex([]int{i})
   237  		if field.Type.Kind() == reflect.Struct {
   238  			field.Type = maskFieldsOnStruct(field.Type, mask)
   239  		} else {
   240  			for _, m := range mask {
   241  				if strings.EqualFold(field.Name, m) {
   242  					field.Tag = "`json:-`"
   243  				}
   244  			}
   245  		}
   246  		newFields[i] = field
   247  	}
   248  	return reflect.StructOf(newFields)
   249  }
   250  
   251  func maskFields(input interface{}, mask []string) interface{} {
   252  	t := reflect.TypeOf(input)
   253  	newStruct := maskFieldsOnStruct(t.Elem(), mask)
   254  	i := reflect.New(newStruct).Interface()
   255  	return i
   256  }