github.com/woocoos/entco@v0.0.0-20240411071658-1e7b23d4df15/genx/decimal.go (about) 1 package genx 2 3 import ( 4 "entgo.io/ent/entc" 5 "entgo.io/ent/entc/gen" 6 "entgo.io/ent/schema/field" 7 "text/template" 8 ) 9 10 type DecimalExtension struct { 11 entc.DefaultExtension 12 } 13 14 func (DecimalExtension) Hooks() []gen.Hook { 15 return []gen.Hook{ 16 DecimalHook(), 17 } 18 } 19 20 func (DecimalExtension) Func() template.FuncMap { 21 return template.FuncMap{ 22 "isCustomerField": func(f *gen.Field) bool { 23 if !f.HasGoType() { 24 return false 25 } 26 if f.Type.Numeric() && f.Type.RType != nil && f.Type.RType.PkgPath == "github.com/shopspring/decimal" { 27 return true 28 } 29 return false 30 }, 31 } 32 } 33 34 func (d DecimalExtension) Templates() []*gen.Template { 35 return []*gen.Template{ 36 gen.MustParse(gen.NewTemplate("runtime"). 37 Funcs(d.Func()). 38 ParseFS(_templates, "template/runtime.tmpl")), 39 gen.MustParse(gen.NewTemplate("meta"). 40 Funcs(d.Func()). 41 ParseFS(_templates, "template/meta.tmpl")), 42 gen.MustParse(gen.NewTemplate("create"). 43 Funcs(d.Func()). 44 ParseFS(_templates, "template/create.tmpl")), 45 gen.MustParse(gen.NewTemplate("update"). 46 Funcs(d.Func()). 47 ParseFS(_templates, "template/update.tmpl")), 48 } 49 } 50 51 func DecimalHook() gen.Hook { 52 return func(next gen.Generator) gen.Generator { 53 return gen.GenerateFunc(func(g *gen.Graph) error { 54 for _, nodes := range g.Nodes { 55 for _, f := range nodes.Fields { 56 if f.Type.RType != nil && f.Type.RType.String() == "decimal.Decimal" { 57 f.Type.Type = field.TypeFloat64 58 } 59 } 60 } 61 return next.Generate(g) 62 }) 63 } 64 }