github.com/woocoos/entco@v0.0.0-20240411071658-1e7b23d4df15/genx/entc.go (about)

     1  package genx
     2  
     3  import (
     4  	atlas "ariga.io/atlas/sql/schema"
     5  	"embed"
     6  	"entgo.io/contrib/entgql"
     7  	"entgo.io/ent/dialect/sql/schema"
     8  	"entgo.io/ent/entc"
     9  	"entgo.io/ent/entc/gen"
    10  	"github.com/vektah/gqlparser/v2/ast"
    11  )
    12  
    13  var (
    14  	//go:embed template/*
    15  	_templates embed.FS
    16  )
    17  
    18  // GlobalID is a global id template for Noder Query. Use with ChangeRelayNodeType().
    19  //
    20  // if you use GlobalID, must use GID as scalar type.
    21  // and use ChangeRelayNodeType() in entgql.WithSchemaHook()
    22  func GlobalID() entc.Option {
    23  	return func(g *gen.Config) error {
    24  		g.Templates = append(g.Templates, gen.MustParse(gen.NewTemplate("gql_globalid").
    25  			Funcs(entgql.TemplateFuncs).
    26  			ParseFS(_templates, "template/globalid.tmpl")))
    27  		return nil
    28  	}
    29  }
    30  
    31  func SimplePagination() entc.Option {
    32  	return func(g *gen.Config) error {
    33  		g.Templates = append(g.Templates, gen.MustParse(gen.NewTemplate("gql_pagination_simple").
    34  			Funcs(entgql.TemplateFuncs).
    35  			ParseFS(_templates, "template/gql_pagination_simple.tmpl")))
    36  		return nil
    37  	}
    38  }
    39  
    40  // ChangeRelayNodeType is a schema hook for change relay node type to GID. Use with GlobalID().
    41  //
    42  // add it to entgql.WithSchemaHook()
    43  func ChangeRelayNodeType() entgql.SchemaHook {
    44  	idType := ast.NonNullNamedType("GID", nil)
    45  	found := false
    46  	return func(graph *gen.Graph, schema *ast.Schema) error {
    47  		for _, field := range schema.Types["Query"].Fields {
    48  			if field.Name == "node" {
    49  				field.Arguments[0].Type = idType
    50  				found = true
    51  			}
    52  			if field.Name == "nodes" {
    53  				field.Arguments[0].Type = ast.NonNullListType(idType, nil)
    54  				found = true
    55  			}
    56  		}
    57  		if found && schema.Types["GID"] == nil {
    58  			schema.Types["GID"] = &ast.Definition{
    59  				Kind:        ast.Scalar,
    60  				Name:        "GID",
    61  				Description: "An object with a Global ID,for using in Noder interface.",
    62  			}
    63  		}
    64  		return nil
    65  	}
    66  }
    67  
    68  // WithGqlWithTemplates is a schema hook for replace entgql default template.
    69  // Note: this option must put before WithWhereInputs or which changed entgql templates option.
    70  //
    71  // extensions:
    72  //  1. NodeTemplate:
    73  //     Noder: add entcache context
    74  func WithGqlWithTemplates() entgql.ExtensionOption {
    75  	nodeTpl := gen.MustParse(gen.NewTemplate("node").
    76  		Funcs(entgql.TemplateFuncs).ParseFS(_templates, "template/node.tmpl"))
    77  	return entgql.WithTemplates(append(entgql.AllTemplates, nodeTpl)...)
    78  }
    79  
    80  // ReplaceGqlMutationInput is a schema hook for replace gql mutation input template.
    81  // Deprecated: not use
    82  func ReplaceGqlMutationInput() entgql.ExtensionOption {
    83  	rt := gen.MustParse(gen.NewTemplate("gql_mutation_input").
    84  		Funcs(entgql.TemplateFuncs).
    85  		ParseFS(_templates, "template/gql_mutation_input.tmpl")).SkipIf(skipMutationTemplate)
    86  	return entgql.WithTemplates([]*gen.Template{
    87  		entgql.CollectionTemplate,
    88  		entgql.EnumTemplate,
    89  		entgql.NodeTemplate,
    90  		entgql.PaginationTemplate,
    91  		entgql.TransactionTemplate,
    92  		entgql.EdgeTemplate,
    93  		entgql.WhereTemplate,
    94  		rt,
    95  	}...)
    96  }
    97  
    98  func skipMutationTemplate(g *gen.Graph) bool {
    99  	for _, n := range g.Nodes {
   100  		ant, err := annotation(n.Annotations)
   101  		if err != nil {
   102  			continue
   103  		}
   104  		for _, i := range ant.MutationInputs {
   105  			if (i.IsCreate && !ant.Skip.Is(entgql.SkipMutationCreateInput)) ||
   106  				(!i.IsCreate && !ant.Skip.Is(entgql.SkipMutationUpdateInput)) {
   107  				return false
   108  			}
   109  		}
   110  	}
   111  	return true
   112  }
   113  
   114  // annotation extracts the entgql.Annotation or returns its empty value.
   115  func annotation(ants gen.Annotations) (*entgql.Annotation, error) {
   116  	ant := &entgql.Annotation{}
   117  	if ants != nil && ants[ant.Name()] != nil {
   118  		if err := ant.Decode(ants[ant.Name()]); err != nil {
   119  			return nil, err
   120  		}
   121  	}
   122  	return ant, nil
   123  }
   124  
   125  // SkipTablesDiffHook is a schema migration hook for skip tables diff thus skip migration.
   126  // the table name is database name,not the ent schema struct name.
   127  //
   128  //	err = client.Schema.Create(ctx,SkipTablesDiffHook("table1","table2"))
   129  func SkipTablesDiffHook(tables ...string) schema.MigrateOption {
   130  	return schema.WithDiffHook(func(next schema.Differ) schema.Differ {
   131  		return schema.DiffFunc(func(current, desired *atlas.Schema) ([]atlas.Change, error) {
   132  			var dt []*atlas.Table
   133  		LOOP:
   134  			for i, table := range desired.Tables {
   135  				for _, t := range tables {
   136  					if table.Name == t {
   137  						continue LOOP
   138  					}
   139  				}
   140  				dt = append(dt, desired.Tables[i])
   141  			}
   142  			desired.Tables = dt
   143  			// Before calculating changes.
   144  			changes, err := next.Diff(current, desired)
   145  			if err != nil {
   146  				return nil, err
   147  			}
   148  			// After diff, you can filter
   149  			// changes or return new ones.
   150  			return changes, nil
   151  		})
   152  	})
   153  }