github.com/woocoos/entco@v0.0.0-20240411071658-1e7b23d4df15/genx/decimal.go (about)

     1  package genx
     2  
     3  import (
     4  	"entgo.io/ent/entc"
     5  	"entgo.io/ent/entc/gen"
     6  	"entgo.io/ent/schema/field"
     7  	"text/template"
     8  )
     9  
    10  type DecimalExtension struct {
    11  	entc.DefaultExtension
    12  }
    13  
    14  func (DecimalExtension) Hooks() []gen.Hook {
    15  	return []gen.Hook{
    16  		DecimalHook(),
    17  	}
    18  }
    19  
    20  func (DecimalExtension) Func() template.FuncMap {
    21  	return template.FuncMap{
    22  		"isCustomerField": func(f *gen.Field) bool {
    23  			if !f.HasGoType() {
    24  				return false
    25  			}
    26  			if f.Type.Numeric() && f.Type.RType != nil && f.Type.RType.PkgPath == "github.com/shopspring/decimal" {
    27  				return true
    28  			}
    29  			return false
    30  		},
    31  	}
    32  }
    33  
    34  func (d DecimalExtension) Templates() []*gen.Template {
    35  	return []*gen.Template{
    36  		gen.MustParse(gen.NewTemplate("runtime").
    37  			Funcs(d.Func()).
    38  			ParseFS(_templates, "template/runtime.tmpl")),
    39  		gen.MustParse(gen.NewTemplate("meta").
    40  			Funcs(d.Func()).
    41  			ParseFS(_templates, "template/meta.tmpl")),
    42  		gen.MustParse(gen.NewTemplate("create").
    43  			Funcs(d.Func()).
    44  			ParseFS(_templates, "template/create.tmpl")),
    45  		gen.MustParse(gen.NewTemplate("update").
    46  			Funcs(d.Func()).
    47  			ParseFS(_templates, "template/update.tmpl")),
    48  	}
    49  }
    50  
    51  func DecimalHook() gen.Hook {
    52  	return func(next gen.Generator) gen.Generator {
    53  		return gen.GenerateFunc(func(g *gen.Graph) error {
    54  			for _, nodes := range g.Nodes {
    55  				for _, f := range nodes.Fields {
    56  					if f.Type.RType != nil && f.Type.RType.String() == "decimal.Decimal" {
    57  						f.Type.Type = field.TypeFloat64
    58  					}
    59  				}
    60  			}
    61  			return next.Generate(g)
    62  		})
    63  	}
    64  }