github.com/josephspurrier/go-swagger@v0.2.1-0.20221129144919-1f672a142a00/codescan/parameters.go (about)

     1  package codescan
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/types"
     7  	"strings"
     8  
     9  	"golang.org/x/tools/go/ast/astutil"
    10  
    11  	"github.com/pkg/errors"
    12  
    13  	"github.com/go-openapi/spec"
    14  )
    15  
    16  type paramTypable struct {
    17  	param *spec.Parameter
    18  }
    19  
    20  func (pt paramTypable) Level() int { return 0 }
    21  
    22  func (pt paramTypable) Typed(tpe, format string) {
    23  	pt.param.Typed(tpe, format)
    24  }
    25  
    26  func (pt paramTypable) SetRef(ref spec.Ref) {
    27  	pt.param.Ref = ref
    28  }
    29  
    30  func (pt paramTypable) Items() swaggerTypable {
    31  	bdt, schema := bodyTypable(pt.param.In, pt.param.Schema)
    32  	if bdt != nil {
    33  		pt.param.Schema = schema
    34  		return bdt
    35  	}
    36  
    37  	if pt.param.Items == nil {
    38  		pt.param.Items = new(spec.Items)
    39  	}
    40  	pt.param.Type = "array"
    41  	return itemsTypable{pt.param.Items, 1}
    42  }
    43  
    44  func (pt paramTypable) Schema() *spec.Schema {
    45  	if pt.param.In != "body" {
    46  		return nil
    47  	}
    48  	if pt.param.Schema == nil {
    49  		pt.param.Schema = new(spec.Schema)
    50  	}
    51  	return pt.param.Schema
    52  }
    53  
    54  func (pt paramTypable) AddExtension(key string, value interface{}) {
    55  	if pt.param.In == "body" {
    56  		pt.Schema().AddExtension(key, value)
    57  	} else {
    58  		pt.param.AddExtension(key, value)
    59  	}
    60  }
    61  
    62  func (pt paramTypable) WithEnum(values ...interface{}) {
    63  	pt.param.WithEnum(values...)
    64  }
    65  
    66  func (pt paramTypable) WithEnumDescription(desc string) {
    67  	if desc == "" {
    68  		return
    69  	}
    70  	pt.param.AddExtension(extEnumDesc, desc)
    71  }
    72  
    73  type itemsTypable struct {
    74  	items *spec.Items
    75  	level int
    76  }
    77  
    78  func (pt itemsTypable) Level() int { return pt.level }
    79  
    80  func (pt itemsTypable) Typed(tpe, format string) {
    81  	pt.items.Typed(tpe, format)
    82  }
    83  
    84  func (pt itemsTypable) SetRef(ref spec.Ref) {
    85  	pt.items.Ref = ref
    86  }
    87  
    88  func (pt itemsTypable) Schema() *spec.Schema {
    89  	return nil
    90  }
    91  
    92  func (pt itemsTypable) Items() swaggerTypable {
    93  	if pt.items.Items == nil {
    94  		pt.items.Items = new(spec.Items)
    95  	}
    96  	pt.items.Type = "array"
    97  	return itemsTypable{pt.items.Items, pt.level + 1}
    98  }
    99  
   100  func (pt itemsTypable) AddExtension(key string, value interface{}) {
   101  	pt.items.AddExtension(key, value)
   102  }
   103  
   104  func (pt itemsTypable) WithEnum(values ...interface{}) {
   105  	pt.items.WithEnum(values...)
   106  }
   107  
   108  func (pt itemsTypable) WithEnumDescription(desc string) {
   109  	// no
   110  }
   111  
   112  type paramValidations struct {
   113  	current *spec.Parameter
   114  }
   115  
   116  func (sv paramValidations) SetMaximum(val float64, exclusive bool) {
   117  	sv.current.Maximum = &val
   118  	sv.current.ExclusiveMaximum = exclusive
   119  }
   120  func (sv paramValidations) SetMinimum(val float64, exclusive bool) {
   121  	sv.current.Minimum = &val
   122  	sv.current.ExclusiveMinimum = exclusive
   123  }
   124  func (sv paramValidations) SetMultipleOf(val float64)      { sv.current.MultipleOf = &val }
   125  func (sv paramValidations) SetMinItems(val int64)          { sv.current.MinItems = &val }
   126  func (sv paramValidations) SetMaxItems(val int64)          { sv.current.MaxItems = &val }
   127  func (sv paramValidations) SetMinLength(val int64)         { sv.current.MinLength = &val }
   128  func (sv paramValidations) SetMaxLength(val int64)         { sv.current.MaxLength = &val }
   129  func (sv paramValidations) SetPattern(val string)          { sv.current.Pattern = val }
   130  func (sv paramValidations) SetUnique(val bool)             { sv.current.UniqueItems = val }
   131  func (sv paramValidations) SetCollectionFormat(val string) { sv.current.CollectionFormat = val }
   132  func (sv paramValidations) SetEnum(val string) {
   133  	sv.current.Enum = parseEnum(val, &spec.SimpleSchema{Type: sv.current.Type, Format: sv.current.Format})
   134  }
   135  func (sv paramValidations) SetDefault(val interface{}) { sv.current.Default = val }
   136  func (sv paramValidations) SetExample(val interface{}) { sv.current.Example = val }
   137  
   138  type itemsValidations struct {
   139  	current *spec.Items
   140  }
   141  
   142  func (sv itemsValidations) SetMaximum(val float64, exclusive bool) {
   143  	sv.current.Maximum = &val
   144  	sv.current.ExclusiveMaximum = exclusive
   145  }
   146  func (sv itemsValidations) SetMinimum(val float64, exclusive bool) {
   147  	sv.current.Minimum = &val
   148  	sv.current.ExclusiveMinimum = exclusive
   149  }
   150  func (sv itemsValidations) SetMultipleOf(val float64)      { sv.current.MultipleOf = &val }
   151  func (sv itemsValidations) SetMinItems(val int64)          { sv.current.MinItems = &val }
   152  func (sv itemsValidations) SetMaxItems(val int64)          { sv.current.MaxItems = &val }
   153  func (sv itemsValidations) SetMinLength(val int64)         { sv.current.MinLength = &val }
   154  func (sv itemsValidations) SetMaxLength(val int64)         { sv.current.MaxLength = &val }
   155  func (sv itemsValidations) SetPattern(val string)          { sv.current.Pattern = val }
   156  func (sv itemsValidations) SetUnique(val bool)             { sv.current.UniqueItems = val }
   157  func (sv itemsValidations) SetCollectionFormat(val string) { sv.current.CollectionFormat = val }
   158  func (sv itemsValidations) SetEnum(val string) {
   159  	sv.current.Enum = parseEnum(val, &spec.SimpleSchema{Type: sv.current.Type, Format: sv.current.Format})
   160  }
   161  func (sv itemsValidations) SetDefault(val interface{}) { sv.current.Default = val }
   162  func (sv itemsValidations) SetExample(val interface{}) { sv.current.Example = val }
   163  
   164  type parameterBuilder struct {
   165  	ctx       *scanCtx
   166  	decl      *entityDecl
   167  	postDecls []*entityDecl
   168  }
   169  
   170  func (p *parameterBuilder) Build(operations map[string]*spec.Operation) error {
   171  
   172  	// check if there is a swagger:parameters tag that is followed by one or more words,
   173  	// these words are the ids of the operations this parameter struct applies to
   174  	// once type name is found convert it to a schema, by looking up the schema in the
   175  	// parameters dictionary that got passed into this parse method
   176  	for _, opid := range p.decl.OperationIDS() {
   177  		operation, ok := operations[opid]
   178  		if !ok {
   179  			operation = new(spec.Operation)
   180  			operations[opid] = operation
   181  			operation.ID = opid
   182  		}
   183  		debugLog("building parameters for: %s", opid)
   184  
   185  		// analyze struct body for fields etc
   186  		// each exported struct field:
   187  		// * gets a type mapped to a go primitive
   188  		// * perhaps gets a format
   189  		// * has to document the validations that apply for the type and the field
   190  		// * when the struct field points to a model it becomes a ref: #/definitions/ModelName
   191  		// * comments that aren't tags is used as the description
   192  		if err := p.buildFromType(p.decl.Type, operation, make(map[string]spec.Parameter)); err != nil {
   193  			return err
   194  		}
   195  	}
   196  	return nil
   197  }
   198  
   199  func (p *parameterBuilder) buildFromType(otpe types.Type, op *spec.Operation, seen map[string]spec.Parameter) error {
   200  	switch tpe := otpe.(type) {
   201  	case *types.Pointer:
   202  		return p.buildFromType(tpe.Elem(), op, seen)
   203  	case *types.Named:
   204  		o := tpe.Obj()
   205  		switch stpe := o.Type().Underlying().(type) {
   206  		case *types.Struct:
   207  			debugLog("build from type %s: %T", tpe.Obj().Name(), otpe)
   208  			if decl, found := p.ctx.DeclForType(o.Type()); found {
   209  				return p.buildFromStruct(decl, stpe, op, seen)
   210  			}
   211  			return p.buildFromStruct(p.decl, stpe, op, seen)
   212  		default:
   213  			return errors.Errorf("unhandled type (%T): %s", stpe, o.Type().Underlying().String())
   214  		}
   215  	default:
   216  		return errors.Errorf("unhandled type (%T): %s", otpe, tpe.String())
   217  	}
   218  }
   219  
   220  func (p *parameterBuilder) buildFromField(fld *types.Var, tpe types.Type, typable swaggerTypable, seen map[string]spec.Parameter) error {
   221  	debugLog("build from field %s: %T", fld.Name(), tpe)
   222  	switch ftpe := tpe.(type) {
   223  	case *types.Basic:
   224  		return swaggerSchemaForType(ftpe.Name(), typable)
   225  	case *types.Struct:
   226  		sb := schemaBuilder{
   227  			decl: p.decl,
   228  			ctx:  p.ctx,
   229  		}
   230  		if err := sb.buildFromType(tpe, typable); err != nil {
   231  			return err
   232  		}
   233  		p.postDecls = append(p.postDecls, sb.postDecls...)
   234  		return nil
   235  	case *types.Pointer:
   236  		return p.buildFromField(fld, ftpe.Elem(), typable, seen)
   237  	case *types.Interface:
   238  		sb := schemaBuilder{
   239  			decl: p.decl,
   240  			ctx:  p.ctx,
   241  		}
   242  		if err := sb.buildFromType(tpe, typable); err != nil {
   243  			return err
   244  		}
   245  		p.postDecls = append(p.postDecls, sb.postDecls...)
   246  		return nil
   247  	case *types.Array:
   248  		return p.buildFromField(fld, ftpe.Elem(), typable.Items(), seen)
   249  	case *types.Slice:
   250  		return p.buildFromField(fld, ftpe.Elem(), typable.Items(), seen)
   251  	case *types.Map:
   252  		schema := new(spec.Schema)
   253  		typable.Schema().Typed("object", "").AdditionalProperties = &spec.SchemaOrBool{
   254  			Schema: schema,
   255  		}
   256  		sb := schemaBuilder{
   257  			decl: p.decl,
   258  			ctx:  p.ctx,
   259  		}
   260  		if err := sb.buildFromType(ftpe.Elem(), schemaTypable{schema, typable.Level() + 1}); err != nil {
   261  			return err
   262  		}
   263  		return nil
   264  	case *types.Named:
   265  		if decl, found := p.ctx.DeclForType(ftpe.Obj().Type()); found {
   266  			if decl.Type.Obj().Pkg().Path() == "time" && decl.Type.Obj().Name() == "Time" {
   267  				typable.Typed("string", "date-time")
   268  				return nil
   269  			}
   270  			if sfnm, isf := strfmtName(decl.Comments); isf {
   271  				typable.Typed("string", sfnm)
   272  				return nil
   273  			}
   274  			sb := &schemaBuilder{ctx: p.ctx, decl: decl}
   275  			sb.inferNames()
   276  			if err := sb.buildFromType(decl.Type, typable); err != nil {
   277  				return err
   278  			}
   279  			p.postDecls = append(p.postDecls, sb.postDecls...)
   280  			return nil
   281  		}
   282  		return errors.Errorf("unable to find package and source file for: %s", ftpe.String())
   283  	default:
   284  		return errors.Errorf("unknown type for %s: %T", fld.String(), fld.Type())
   285  	}
   286  }
   287  
   288  func spExtensionsSetter(ps *spec.Parameter) func(*spec.Extensions) {
   289  	return func(exts *spec.Extensions) {
   290  		for name, value := range *exts {
   291  			addExtension(&ps.VendorExtensible, name, value)
   292  		}
   293  	}
   294  }
   295  
   296  func (p *parameterBuilder) buildFromStruct(decl *entityDecl, tpe *types.Struct, op *spec.Operation, seen map[string]spec.Parameter) error {
   297  	if tpe.NumFields() == 0 {
   298  		return nil
   299  	}
   300  
   301  	var sequence []string
   302  
   303  	for i := 0; i < tpe.NumFields(); i++ {
   304  		fld := tpe.Field(i)
   305  
   306  		if fld.Embedded() {
   307  			if err := p.buildFromType(fld.Type(), op, seen); err != nil {
   308  				return err
   309  			}
   310  			continue
   311  		}
   312  
   313  		if !fld.Exported() {
   314  			debugLog("skipping field %s because it's not exported", fld.Name())
   315  			continue
   316  		}
   317  
   318  		tg := tpe.Tag(i)
   319  
   320  		var afld *ast.Field
   321  		ans, _ := astutil.PathEnclosingInterval(decl.File, fld.Pos(), fld.Pos())
   322  		for _, an := range ans {
   323  			at, valid := an.(*ast.Field)
   324  			if !valid {
   325  				continue
   326  			}
   327  
   328  			debugLog("field %s: %s(%T) [%q] ==> %s", fld.Name(), fld.Type().String(), fld.Type(), tg, at.Doc.Text())
   329  			afld = at
   330  			break
   331  		}
   332  
   333  		if afld == nil {
   334  			debugLog("can't find source associated with %s for %s", fld.String(), tpe.String())
   335  			continue
   336  		}
   337  
   338  		// if the field is annotated with swagger:ignore, ignore it
   339  		if ignored(afld.Doc) {
   340  			continue
   341  		}
   342  
   343  		name, ignore, _, err := parseJSONTag(afld)
   344  		if err != nil {
   345  			return err
   346  		}
   347  		if ignore {
   348  			continue
   349  		}
   350  
   351  		in := "query"
   352  		// scan for param location first, this changes some behavior down the line
   353  		if afld.Doc != nil {
   354  			for _, cmt := range afld.Doc.List {
   355  				for _, line := range strings.Split(cmt.Text, "\n") {
   356  					matches := rxIn.FindStringSubmatch(line)
   357  					if len(matches) > 0 && len(strings.TrimSpace(matches[1])) > 0 {
   358  						in = strings.TrimSpace(matches[1])
   359  					}
   360  				}
   361  			}
   362  		}
   363  
   364  		ps := seen[name]
   365  		ps.In = in
   366  		var pty swaggerTypable = paramTypable{&ps}
   367  		if in == "body" {
   368  			pty = schemaTypable{pty.Schema(), 0}
   369  		}
   370  		if in == "formData" && afld.Doc != nil && fileParam(afld.Doc) {
   371  			pty.Typed("file", "")
   372  		} else if err := p.buildFromField(fld, fld.Type(), pty, seen); err != nil {
   373  			return err
   374  		}
   375  
   376  		if strfmtName, ok := strfmtName(afld.Doc); ok {
   377  			ps.Typed("string", strfmtName)
   378  			ps.Ref = spec.Ref{}
   379  			ps.Items = nil
   380  		}
   381  
   382  		sp := new(sectionedParser)
   383  		sp.setDescription = func(lines []string) {
   384  			ps.Description = joinDropLast(lines)
   385  			enumDesc := getEnumDesc(ps.Extensions)
   386  			if enumDesc != "" {
   387  				ps.Description += "\n" + enumDesc
   388  			}
   389  		}
   390  		if ps.Ref.String() == "" {
   391  			sp.taggers = []tagParser{
   392  				newSingleLineTagParser("in", &matchOnlyParam{&ps, rxIn}),
   393  				newSingleLineTagParser("maximum", &setMaximum{paramValidations{&ps}, rxf(rxMaximumFmt, "")}),
   394  				newSingleLineTagParser("minimum", &setMinimum{paramValidations{&ps}, rxf(rxMinimumFmt, "")}),
   395  				newSingleLineTagParser("multipleOf", &setMultipleOf{paramValidations{&ps}, rxf(rxMultipleOfFmt, "")}),
   396  				newSingleLineTagParser("minLength", &setMinLength{paramValidations{&ps}, rxf(rxMinLengthFmt, "")}),
   397  				newSingleLineTagParser("maxLength", &setMaxLength{paramValidations{&ps}, rxf(rxMaxLengthFmt, "")}),
   398  				newSingleLineTagParser("pattern", &setPattern{paramValidations{&ps}, rxf(rxPatternFmt, "")}),
   399  				newSingleLineTagParser("collectionFormat", &setCollectionFormat{paramValidations{&ps}, rxf(rxCollectionFormatFmt, "")}),
   400  				newSingleLineTagParser("minItems", &setMinItems{paramValidations{&ps}, rxf(rxMinItemsFmt, "")}),
   401  				newSingleLineTagParser("maxItems", &setMaxItems{paramValidations{&ps}, rxf(rxMaxItemsFmt, "")}),
   402  				newSingleLineTagParser("unique", &setUnique{paramValidations{&ps}, rxf(rxUniqueFmt, "")}),
   403  				newSingleLineTagParser("enum", &setEnum{paramValidations{&ps}, rxf(rxEnumFmt, "")}),
   404  				newSingleLineTagParser("default", &setDefault{&ps.SimpleSchema, paramValidations{&ps}, rxf(rxDefaultFmt, "")}),
   405  				newSingleLineTagParser("example", &setExample{&ps.SimpleSchema, paramValidations{&ps}, rxf(rxExampleFmt, "")}),
   406  				newSingleLineTagParser("required", &setRequiredParam{&ps}),
   407  				newMultiLineTagParser("Extensions", newSetExtensions(spExtensionsSetter(&ps)), true),
   408  			}
   409  
   410  			itemsTaggers := func(items *spec.Items, level int) []tagParser {
   411  				// the expression is 1-index based not 0-index
   412  				itemsPrefix := fmt.Sprintf(rxItemsPrefixFmt, level+1)
   413  
   414  				return []tagParser{
   415  					newSingleLineTagParser(fmt.Sprintf("items%dMaximum", level), &setMaximum{itemsValidations{items}, rxf(rxMaximumFmt, itemsPrefix)}),
   416  					newSingleLineTagParser(fmt.Sprintf("items%dMinimum", level), &setMinimum{itemsValidations{items}, rxf(rxMinimumFmt, itemsPrefix)}),
   417  					newSingleLineTagParser(fmt.Sprintf("items%dMultipleOf", level), &setMultipleOf{itemsValidations{items}, rxf(rxMultipleOfFmt, itemsPrefix)}),
   418  					newSingleLineTagParser(fmt.Sprintf("items%dMinLength", level), &setMinLength{itemsValidations{items}, rxf(rxMinLengthFmt, itemsPrefix)}),
   419  					newSingleLineTagParser(fmt.Sprintf("items%dMaxLength", level), &setMaxLength{itemsValidations{items}, rxf(rxMaxLengthFmt, itemsPrefix)}),
   420  					newSingleLineTagParser(fmt.Sprintf("items%dPattern", level), &setPattern{itemsValidations{items}, rxf(rxPatternFmt, itemsPrefix)}),
   421  					newSingleLineTagParser(fmt.Sprintf("items%dCollectionFormat", level), &setCollectionFormat{itemsValidations{items}, rxf(rxCollectionFormatFmt, itemsPrefix)}),
   422  					newSingleLineTagParser(fmt.Sprintf("items%dMinItems", level), &setMinItems{itemsValidations{items}, rxf(rxMinItemsFmt, itemsPrefix)}),
   423  					newSingleLineTagParser(fmt.Sprintf("items%dMaxItems", level), &setMaxItems{itemsValidations{items}, rxf(rxMaxItemsFmt, itemsPrefix)}),
   424  					newSingleLineTagParser(fmt.Sprintf("items%dUnique", level), &setUnique{itemsValidations{items}, rxf(rxUniqueFmt, itemsPrefix)}),
   425  					newSingleLineTagParser(fmt.Sprintf("items%dEnum", level), &setEnum{itemsValidations{items}, rxf(rxEnumFmt, itemsPrefix)}),
   426  					newSingleLineTagParser(fmt.Sprintf("items%dDefault", level), &setDefault{&items.SimpleSchema, itemsValidations{items}, rxf(rxDefaultFmt, itemsPrefix)}),
   427  					newSingleLineTagParser(fmt.Sprintf("items%dExample", level), &setExample{&items.SimpleSchema, itemsValidations{items}, rxf(rxExampleFmt, itemsPrefix)}),
   428  				}
   429  			}
   430  
   431  			var parseArrayTypes func(expr ast.Expr, items *spec.Items, level int) ([]tagParser, error)
   432  			parseArrayTypes = func(expr ast.Expr, items *spec.Items, level int) ([]tagParser, error) {
   433  				if items == nil {
   434  					return []tagParser{}, nil
   435  				}
   436  				switch iftpe := expr.(type) {
   437  				case *ast.ArrayType:
   438  					eleTaggers := itemsTaggers(items, level)
   439  					sp.taggers = append(eleTaggers, sp.taggers...)
   440  					otherTaggers, err := parseArrayTypes(iftpe.Elt, items.Items, level+1)
   441  					if err != nil {
   442  						return nil, err
   443  					}
   444  					return otherTaggers, nil
   445  				case *ast.SelectorExpr:
   446  					otherTaggers, err := parseArrayTypes(iftpe.Sel, items.Items, level+1)
   447  					if err != nil {
   448  						return nil, err
   449  					}
   450  					return otherTaggers, nil
   451  				case *ast.Ident:
   452  					taggers := []tagParser{}
   453  					if iftpe.Obj == nil {
   454  						taggers = itemsTaggers(items, level)
   455  					}
   456  					otherTaggers, err := parseArrayTypes(expr, items.Items, level+1)
   457  					if err != nil {
   458  						return nil, err
   459  					}
   460  					return append(taggers, otherTaggers...), nil
   461  				case *ast.StarExpr:
   462  					otherTaggers, err := parseArrayTypes(iftpe.X, items, level)
   463  					if err != nil {
   464  						return nil, err
   465  					}
   466  					return otherTaggers, nil
   467  				default:
   468  					return nil, fmt.Errorf("unknown field type ele for %q", name)
   469  				}
   470  			}
   471  
   472  			// check if this is a primitive, if so parse the validations from the
   473  			// doc comments of the slice declaration.
   474  			if ftped, ok := afld.Type.(*ast.ArrayType); ok {
   475  				taggers, err := parseArrayTypes(ftped.Elt, ps.Items, 0)
   476  				if err != nil {
   477  					return err
   478  				}
   479  				sp.taggers = append(taggers, sp.taggers...)
   480  			}
   481  
   482  		} else {
   483  			sp.taggers = []tagParser{
   484  				newSingleLineTagParser("in", &matchOnlyParam{&ps, rxIn}),
   485  				newSingleLineTagParser("required", &matchOnlyParam{&ps, rxRequired}),
   486  				newMultiLineTagParser("Extensions", newSetExtensions(spExtensionsSetter(&ps)), true),
   487  			}
   488  		}
   489  		if err := sp.Parse(afld.Doc); err != nil {
   490  			return err
   491  		}
   492  		if ps.In == "path" {
   493  			ps.Required = true
   494  		}
   495  
   496  		if ps.Name == "" {
   497  			ps.Name = name
   498  		}
   499  
   500  		if name != fld.Name() {
   501  			addExtension(&ps.VendorExtensible, "x-go-name", fld.Name())
   502  		}
   503  		seen[name] = ps
   504  		sequence = append(sequence, name)
   505  	}
   506  
   507  	for _, k := range sequence {
   508  		p := seen[k]
   509  		for i, v := range op.Parameters {
   510  			if v.Name == k {
   511  				op.Parameters = append(op.Parameters[:i], op.Parameters[i+1:]...)
   512  				break
   513  			}
   514  		}
   515  		op.Parameters = append(op.Parameters, p)
   516  	}
   517  	return nil
   518  }