github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/courier/swagger/gen/enum_scanner.go (about) 1 package gen 2 3 import ( 4 "go/ast" 5 "go/constant" 6 "go/types" 7 "reflect" 8 "sort" 9 "strconv" 10 "strings" 11 12 "github.com/go-courier/oas" 13 "golang.org/x/tools/go/loader" 14 15 "github.com/artisanhe/tools/codegen" 16 "github.com/artisanhe/tools/courier/enumeration" 17 ) 18 19 func NewEnumScanner(program *loader.Program) *EnumScanner { 20 return &EnumScanner{ 21 program: program, 22 } 23 } 24 25 type EnumScanner struct { 26 program *loader.Program 27 Enums map[*types.TypeName]Enum 28 } 29 30 func (scanner *EnumScanner) HasOffset(typeName *types.TypeName) bool { 31 pkgInfo := scanner.program.Package(typeName.Pkg().Path()) 32 if pkgInfo == nil { 33 return false 34 } 35 for _, def := range pkgInfo.Defs { 36 if typeConst, ok := def.(*types.Const); ok { 37 if typeConst.Name() == codegen.ToUpperSnakeCase(typeName.Name())+"_OFFSET" { 38 return true 39 } 40 } 41 } 42 return false 43 } 44 45 func (scanner *EnumScanner) Enum(typeName *types.TypeName) Enum { 46 if enumOptions, ok := scanner.Enums[typeName]; ok { 47 return enumOptions.Sort() 48 } 49 50 pkgInfo := scanner.program.Package(typeName.Pkg().Path()) 51 if pkgInfo == nil { 52 return nil 53 } 54 55 typeNameString := typeName.Name() 56 57 for ident, def := range pkgInfo.Defs { 58 if typeConst, ok := def.(*types.Const); ok { 59 if typeConst.Type() == typeName.Type() { 60 name := typeConst.Name() 61 62 if name != "_" { 63 val := typeConst.Val() 64 label := strings.TrimSpace(ident.Obj.Decl.(*ast.ValueSpec).Comment.Text()) 65 66 if strings.HasPrefix(name, codegen.ToUpperSnakeCase(typeNameString)) { 67 var values = strings.SplitN(name, "__", 2) 68 if len(values) == 2 { 69 scanner.addEnum(typeName, values[1], getConstVal(val), label) 70 } 71 } else { 72 v := getConstVal(val) 73 scanner.addEnum(typeName, v, v, label) 74 } 75 } 76 } 77 } 78 } 79 80 return scanner.Enums[typeName].Sort() 81 } 82 83 func (scanner *EnumScanner) addEnum(typeName *types.TypeName, value interface{}, val interface{}, label string) { 84 if scanner.Enums == nil { 85 scanner.Enums = map[*types.TypeName]Enum{} 86 } 87 scanner.Enums[typeName] = append(scanner.Enums[typeName], enumeration.EnumOption{ 88 Value: value, 89 Val: val, 90 Label: label, 91 }) 92 } 93 94 type Enum enumeration.Enum 95 96 func (enum Enum) Sort() Enum { 97 sort.Slice(enum, func(i, j int) bool { 98 switch enum[i].Value.(type) { 99 case string: 100 return enum[i].Value.(string) < enum[j].Value.(string) 101 case int64: 102 return enum[i].Value.(int64) < enum[j].Value.(int64) 103 case float64: 104 return enum[i].Value.(float64) < enum[j].Value.(float64) 105 } 106 return true 107 }) 108 return enum 109 } 110 111 func (enum Enum) Labels() (labels []string) { 112 for _, e := range enum { 113 labels = append(labels, e.Label) 114 } 115 return 116 } 117 118 func (enum Enum) Vals() (vals []interface{}) { 119 for _, e := range enum { 120 vals = append(vals, e.Val) 121 } 122 return 123 } 124 125 func (enum Enum) Values() (values []interface{}) { 126 for _, e := range enum { 127 values = append(values, e.Value) 128 } 129 return 130 } 131 132 func (enum Enum) ToSchema() *oas.Schema { 133 values := enum.Values() 134 135 // nullable bool 136 if len(enum) == 2 && reflect.DeepEqual(values, []string{"FALSE", "TRUE"}) { 137 return oas.Boolean() 138 } 139 140 typeName, _ := getSchemaTypeFromBasicType(reflect.TypeOf(values[0]).Name()) 141 142 s := oas.NewSchema(typeName, "").WithValidation(&oas.SchemaValidation{ 143 Enum: values, 144 }) 145 s.AddExtension(XEnumLabels, enum.Labels()) 146 s.AddExtension(XEnumVals, enum.Vals()) 147 s.AddExtension(XEnumValues, values) 148 return s 149 } 150 151 func getConstVal(constVal constant.Value) interface{} { 152 switch constVal.Kind() { 153 case constant.String: 154 stringVal, _ := strconv.Unquote(constVal.String()) 155 return stringVal 156 case constant.Int: 157 intVal, _ := strconv.ParseInt(constVal.String(), 10, 64) 158 return intVal 159 case constant.Float: 160 floatVal, _ := strconv.ParseFloat(constVal.String(), 10) 161 return floatVal 162 } 163 return nil 164 }