github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/courier/swagger/gen/enum_scanner.go (about)

     1  package gen
     2  
     3  import (
     4  	"go/ast"
     5  	"go/constant"
     6  	"go/types"
     7  	"reflect"
     8  	"sort"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/go-courier/oas"
    13  	"golang.org/x/tools/go/loader"
    14  
    15  	"github.com/artisanhe/tools/codegen"
    16  	"github.com/artisanhe/tools/courier/enumeration"
    17  )
    18  
    19  func NewEnumScanner(program *loader.Program) *EnumScanner {
    20  	return &EnumScanner{
    21  		program: program,
    22  	}
    23  }
    24  
    25  type EnumScanner struct {
    26  	program *loader.Program
    27  	Enums   map[*types.TypeName]Enum
    28  }
    29  
    30  func (scanner *EnumScanner) HasOffset(typeName *types.TypeName) bool {
    31  	pkgInfo := scanner.program.Package(typeName.Pkg().Path())
    32  	if pkgInfo == nil {
    33  		return false
    34  	}
    35  	for _, def := range pkgInfo.Defs {
    36  		if typeConst, ok := def.(*types.Const); ok {
    37  			if typeConst.Name() == codegen.ToUpperSnakeCase(typeName.Name())+"_OFFSET" {
    38  				return true
    39  			}
    40  		}
    41  	}
    42  	return false
    43  }
    44  
    45  func (scanner *EnumScanner) Enum(typeName *types.TypeName) Enum {
    46  	if enumOptions, ok := scanner.Enums[typeName]; ok {
    47  		return enumOptions.Sort()
    48  	}
    49  
    50  	pkgInfo := scanner.program.Package(typeName.Pkg().Path())
    51  	if pkgInfo == nil {
    52  		return nil
    53  	}
    54  
    55  	typeNameString := typeName.Name()
    56  
    57  	for ident, def := range pkgInfo.Defs {
    58  		if typeConst, ok := def.(*types.Const); ok {
    59  			if typeConst.Type() == typeName.Type() {
    60  				name := typeConst.Name()
    61  
    62  				if name != "_" {
    63  					val := typeConst.Val()
    64  					label := strings.TrimSpace(ident.Obj.Decl.(*ast.ValueSpec).Comment.Text())
    65  
    66  					if strings.HasPrefix(name, codegen.ToUpperSnakeCase(typeNameString)) {
    67  						var values = strings.SplitN(name, "__", 2)
    68  						if len(values) == 2 {
    69  							scanner.addEnum(typeName, values[1], getConstVal(val), label)
    70  						}
    71  					} else {
    72  						v := getConstVal(val)
    73  						scanner.addEnum(typeName, v, v, label)
    74  					}
    75  				}
    76  			}
    77  		}
    78  	}
    79  
    80  	return scanner.Enums[typeName].Sort()
    81  }
    82  
    83  func (scanner *EnumScanner) addEnum(typeName *types.TypeName, value interface{}, val interface{}, label string) {
    84  	if scanner.Enums == nil {
    85  		scanner.Enums = map[*types.TypeName]Enum{}
    86  	}
    87  	scanner.Enums[typeName] = append(scanner.Enums[typeName], enumeration.EnumOption{
    88  		Value: value,
    89  		Val:   val,
    90  		Label: label,
    91  	})
    92  }
    93  
    94  type Enum enumeration.Enum
    95  
    96  func (enum Enum) Sort() Enum {
    97  	sort.Slice(enum, func(i, j int) bool {
    98  		switch enum[i].Value.(type) {
    99  		case string:
   100  			return enum[i].Value.(string) < enum[j].Value.(string)
   101  		case int64:
   102  			return enum[i].Value.(int64) < enum[j].Value.(int64)
   103  		case float64:
   104  			return enum[i].Value.(float64) < enum[j].Value.(float64)
   105  		}
   106  		return true
   107  	})
   108  	return enum
   109  }
   110  
   111  func (enum Enum) Labels() (labels []string) {
   112  	for _, e := range enum {
   113  		labels = append(labels, e.Label)
   114  	}
   115  	return
   116  }
   117  
   118  func (enum Enum) Vals() (vals []interface{}) {
   119  	for _, e := range enum {
   120  		vals = append(vals, e.Val)
   121  	}
   122  	return
   123  }
   124  
   125  func (enum Enum) Values() (values []interface{}) {
   126  	for _, e := range enum {
   127  		values = append(values, e.Value)
   128  	}
   129  	return
   130  }
   131  
   132  func (enum Enum) ToSchema() *oas.Schema {
   133  	values := enum.Values()
   134  
   135  	// nullable bool
   136  	if len(enum) == 2 && reflect.DeepEqual(values, []string{"FALSE", "TRUE"}) {
   137  		return oas.Boolean()
   138  	}
   139  
   140  	typeName, _ := getSchemaTypeFromBasicType(reflect.TypeOf(values[0]).Name())
   141  
   142  	s := oas.NewSchema(typeName, "").WithValidation(&oas.SchemaValidation{
   143  		Enum: values,
   144  	})
   145  	s.AddExtension(XEnumLabels, enum.Labels())
   146  	s.AddExtension(XEnumVals, enum.Vals())
   147  	s.AddExtension(XEnumValues, values)
   148  	return s
   149  }
   150  
   151  func getConstVal(constVal constant.Value) interface{} {
   152  	switch constVal.Kind() {
   153  	case constant.String:
   154  		stringVal, _ := strconv.Unquote(constVal.String())
   155  		return stringVal
   156  	case constant.Int:
   157  		intVal, _ := strconv.ParseInt(constVal.String(), 10, 64)
   158  		return intVal
   159  	case constant.Float:
   160  		floatVal, _ := strconv.ParseFloat(constVal.String(), 10)
   161  		return floatVal
   162  	}
   163  	return nil
   164  }