github.com/ManabuSeki/goa-v1@v1.4.3/goagen/gen_app/generator.go (about)

     1  package genapp
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"sort"
     9  
    10  	"github.com/goadesign/goa/design"
    11  	"github.com/goadesign/goa/goagen/codegen"
    12  	"github.com/goadesign/goa/goagen/utils"
    13  )
    14  
    15  //NewGenerator returns an initialized instance of an Application Generator
    16  func NewGenerator(options ...Option) *Generator {
    17  	g := &Generator{}
    18  	g.validator = codegen.NewValidator()
    19  
    20  	for _, option := range options {
    21  		option(g)
    22  	}
    23  
    24  	return g
    25  }
    26  
// Generator is the application code generator. Its Generate method writes the
// generated Go package (named Target) into OutDir and records every path it
// creates so that Cleanup can remove them on failure.
type Generator struct {
	API       *design.APIDefinition // API definition to generate code for; Generate fails if nil
	OutDir    string                // Output directory; removed and recreated by Generate
	Target    string                // Name of the generated Go package
	NoTest    bool                  // Skip generation of the resource test code when true
	genfiles  []string              // Paths generated so far (OutDir first); Cleanup is a no-op while empty
	validator *codegen.Validator    // Validation code generator
}
    36  
    37  // Generate is the generator entry point called by the meta generator.
    38  func Generate() (files []string, err error) {
    39  	var (
    40  		outDir, toolDir, target, ver string
    41  		notest, notool, regen        bool
    42  	)
    43  
    44  	set := flag.NewFlagSet("app", flag.PanicOnError)
    45  	set.String("design", "", "")
    46  	set.StringVar(&outDir, "out", "", "")
    47  	set.StringVar(&target, "pkg", "app", "")
    48  	set.StringVar(&ver, "version", "", "")
    49  	set.StringVar(&toolDir, "tooldir", "tool", "")
    50  	set.BoolVar(&notest, "notest", false, "")
    51  	set.BoolVar(&notool, "notool", false, "")
    52  	set.BoolVar(&regen, "regen", false, "")
    53  	set.Bool("force", false, "")
    54  	set.Parse(os.Args[1:])
    55  	outDir = filepath.Join(outDir, target)
    56  
    57  	if err := codegen.CheckVersion(ver); err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	target = codegen.Goify(target, false)
    62  	g := &Generator{OutDir: outDir, Target: target, NoTest: notest, API: design.Design, validator: codegen.NewValidator()}
    63  
    64  	return g.Generate()
    65  }
    66  
    67  // Generate the application code, implement codegen.Generator.
    68  func (g *Generator) Generate() (_ []string, err error) {
    69  	if g.API == nil {
    70  		return nil, fmt.Errorf("missing API definition, make sure design is properly initialized")
    71  	}
    72  
    73  	go utils.Catch(nil, func() { g.Cleanup() })
    74  
    75  	defer func() {
    76  		if err != nil {
    77  			g.Cleanup()
    78  		}
    79  	}()
    80  
    81  	codegen.Reserved[g.Target] = true
    82  
    83  	os.RemoveAll(g.OutDir)
    84  
    85  	if err := os.MkdirAll(g.OutDir, 0755); err != nil {
    86  		return nil, err
    87  	}
    88  	g.genfiles = []string{g.OutDir}
    89  	if err := g.generateContexts(); err != nil {
    90  		return nil, err
    91  	}
    92  	if err := g.generateControllers(); err != nil {
    93  		return nil, err
    94  	}
    95  	if err := g.generateSecurity(); err != nil {
    96  		return nil, err
    97  	}
    98  	if err := g.generateHrefs(); err != nil {
    99  		return nil, err
   100  	}
   101  	if err := g.generateMediaTypes(); err != nil {
   102  		return nil, err
   103  	}
   104  	if err := g.generateUserTypes(); err != nil {
   105  		return nil, err
   106  	}
   107  	if !g.NoTest {
   108  		if err := g.generateResourceTest(); err != nil {
   109  			return nil, err
   110  		}
   111  	}
   112  
   113  	return g.genfiles, nil
   114  }
   115  
   116  // Cleanup removes the entire "app" directory if it was created by this generator.
   117  func (g *Generator) Cleanup() {
   118  	if len(g.genfiles) == 0 {
   119  		return
   120  	}
   121  	os.RemoveAll(g.OutDir)
   122  	g.genfiles = nil
   123  }
   124  
// generateContexts iterates through the API resources and actions and generates the action
// contexts into contexts.go under g.OutDir, one context type per action.
//
// The file is added to g.genfiles once its header is written. The writer is always closed,
// and the generated code is formatted only when no error occurred.
func (g *Generator) generateContexts() (err error) {
	var (
		ctxFile string
		ctxWr   *ContextsWriter
	)
	{
		ctxFile = filepath.Join(g.OutDir, "contexts.go")
		ctxWr, err = NewContextsWriter(ctxFile)
		if err != nil {
			return
		}
	}
	// Runs on every exit path. Formatting is skipped on error so the original error
	// (held in the named result) is returned unchanged.
	defer func() {
		ctxWr.Close()
		if err == nil {
			err = ctxWr.FormatCode()
		}
	}()
	title := fmt.Sprintf("%s: Application Contexts", g.API.Context())
	imports := []*codegen.ImportSpec{
		codegen.SimpleImport("fmt"),
		codegen.SimpleImport("net/http"),
		codegen.SimpleImport("strconv"),
		codegen.SimpleImport("strings"),
		codegen.SimpleImport("time"),
		codegen.SimpleImport("unicode/utf8"),
		codegen.SimpleImport("github.com/goadesign/goa"),
		codegen.NewImport("uuid", "github.com/gofrs/uuid"),
		codegen.SimpleImport("context"),
	}
	// Add the imports required by every action payload attribute. The callbacks always
	// return nil, so the iteration error is not checked here.
	g.API.IterateResources(func(r *design.ResourceDefinition) error {
		return r.IterateActions(func(a *design.ActionDefinition) error {
			if a.Payload != nil {
				imports = codegen.AttributeImports(a.Payload.AttributeDefinition, imports, nil)
			}
			return nil
		})
	})

	g.genfiles = append(g.genfiles, ctxFile)
	if err = ctxWr.WriteHeader(title, g.Target, imports); err != nil {
		return
	}
	err = g.API.IterateResources(func(r *design.ResourceDefinition) error {
		return r.IterateActions(func(a *design.ActionDefinition) error {
			// Context type name: <Action><Resource>Context.
			ctxName := codegen.Goify(a.Name, true) + codegen.Goify(a.Parent.Name, true) + "Context"
			// Combine the resource headers and then the action headers into one object.
			// headers.Validation is overwritten by each assignment below, so when the action
			// declares headers only a.Headers.Validation is kept.
			// NOTE(review): confirm resource-level header validations are meant to be dropped then.
			headers := &design.AttributeDefinition{
				Type: design.Object{},
			}
			if r.Headers != nil {
				headers.Merge(r.Headers)
				headers.Validation = r.Headers.Validation
			}
			if a.Headers != nil {
				headers.Merge(a.Headers)
				headers.Validation = a.Headers.Validation
			}
			// NOTE(review): headers is always non-nil here (allocated above); the nil check is redundant.
			if headers != nil && len(headers.Type.ToObject()) == 0 {
				headers = nil // So that {{if .Headers}} returns false in templates
			}
			params := a.AllParams()
			if params != nil && len(params.Type.ToObject()) == 0 {
				params = nil // So that {{if .Params}} returns false in templates
			}

			// Responses with status 101 (Switching Protocols) are excluded from the context,
			// presumably because they describe protocol upgrades rather than regular replies.
			non101 := make(map[string]*design.ResponseDefinition)
			for k, v := range a.Responses {
				if v.Status != 101 {
					non101[k] = v
				}
			}
			ctxData := ContextTemplateData{
				Name:         ctxName,
				ResourceName: r.Name,
				ActionName:   a.Name,
				Payload:      a.Payload,
				Params:       params,
				Headers:      headers,
				Routes:       a.Routes,
				Responses:    non101,
				API:          g.API,
				DefaultPkg:   g.Target,
				Security:     a.Security,
			}
			return ctxWr.Execute(&ctxData)
		})
	})
	return
}
   216  
   217  // generateControllers iterates through the API resources and generates the low level
   218  // controllers.
   219  func (g *Generator) generateControllers() (err error) {
   220  	var (
   221  		ctlFile string
   222  		ctlWr   *ControllersWriter
   223  	)
   224  	{
   225  		ctlFile = filepath.Join(g.OutDir, "controllers.go")
   226  		ctlWr, err = NewControllersWriter(ctlFile)
   227  		if err != nil {
   228  			return
   229  		}
   230  	}
   231  	defer func() {
   232  		ctlWr.Close()
   233  		if err == nil {
   234  			err = ctlWr.FormatCode()
   235  		}
   236  	}()
   237  	title := fmt.Sprintf("%s: Application Controllers", g.API.Context())
   238  	imports := []*codegen.ImportSpec{
   239  		codegen.SimpleImport("net/http"),
   240  		codegen.SimpleImport("fmt"),
   241  		codegen.SimpleImport("context"),
   242  		codegen.SimpleImport("github.com/goadesign/goa"),
   243  		codegen.SimpleImport("github.com/goadesign/goa/cors"),
   244  		codegen.SimpleImport("regexp"),
   245  		codegen.SimpleImport("strconv"),
   246  		codegen.SimpleImport("time"),
   247  		codegen.NewImport("uuid", "github.com/gofrs/uuid"),
   248  	}
   249  	encoders, err := BuildEncoders(g.API.Produces, true)
   250  	if err != nil {
   251  		return err
   252  	}
   253  	decoders, err := BuildEncoders(g.API.Consumes, false)
   254  	if err != nil {
   255  		return err
   256  	}
   257  	encoderImports := make(map[string]bool)
   258  	for _, data := range encoders {
   259  		encoderImports[data.PackagePath] = true
   260  	}
   261  	for _, data := range decoders {
   262  		encoderImports[data.PackagePath] = true
   263  	}
   264  	var packagePaths []string
   265  	for packagePath := range encoderImports {
   266  		if packagePath != "github.com/goadesign/goa" {
   267  			packagePaths = append(packagePaths, packagePath)
   268  		}
   269  	}
   270  	sort.Strings(packagePaths)
   271  	for _, packagePath := range packagePaths {
   272  		imports = append(imports, codegen.SimpleImport(packagePath))
   273  	}
   274  	if err = ctlWr.WriteHeader(title, g.Target, imports); err != nil {
   275  		return err
   276  	}
   277  	if err = ctlWr.WriteInitService(encoders, decoders); err != nil {
   278  		return err
   279  	}
   280  
   281  	g.genfiles = append(g.genfiles, ctlFile)
   282  	var controllersData []*ControllerTemplateData
   283  	g.API.IterateResources(func(r *design.ResourceDefinition) error {
   284  		// Create file servers for all directory file servers that serve index.html.
   285  		fileServers := r.FileServers
   286  		for _, fs := range r.FileServers {
   287  			if fs.IsDir() {
   288  				rpath := design.WildcardRegex.ReplaceAllLiteralString(fs.RequestPath, "")
   289  				rpath += "/"
   290  				fileServers = append(fileServers, &design.FileServerDefinition{
   291  					Parent:      fs.Parent,
   292  					Description: fs.Description,
   293  					Docs:        fs.Docs,
   294  					FilePath:    filepath.Join(fs.FilePath, "index.html"),
   295  					RequestPath: rpath,
   296  					Metadata:    fs.Metadata,
   297  					Security:    fs.Security,
   298  				})
   299  			}
   300  		}
   301  		data := &ControllerTemplateData{
   302  			API:            g.API,
   303  			Resource:       codegen.Goify(r.Name, true),
   304  			PreflightPaths: r.PreflightPaths(),
   305  			FileServers:    fileServers,
   306  		}
   307  		r.IterateActions(func(a *design.ActionDefinition) error {
   308  			context := fmt.Sprintf("%s%sContext", codegen.Goify(a.Name, true), codegen.Goify(r.Name, true))
   309  			unmarshal := fmt.Sprintf("unmarshal%s%sPayload", codegen.Goify(a.Name, true), codegen.Goify(r.Name, true))
   310  			action := map[string]interface{}{
   311  				"Name":             codegen.Goify(a.Name, true),
   312  				"DesignName":       a.Name,
   313  				"Routes":           a.Routes,
   314  				"Context":          context,
   315  				"Unmarshal":        unmarshal,
   316  				"Payload":          a.Payload,
   317  				"PayloadOptional":  a.PayloadOptional,
   318  				"PayloadMultipart": a.PayloadMultipart,
   319  				"Security":         a.Security,
   320  			}
   321  			data.Actions = append(data.Actions, action)
   322  			return nil
   323  		})
   324  		if len(data.Actions) > 0 || len(data.FileServers) > 0 {
   325  			data.Encoders = encoders
   326  			data.Decoders = decoders
   327  			data.Origins = r.AllOrigins()
   328  			controllersData = append(controllersData, data)
   329  		}
   330  		return nil
   331  	})
   332  	err = ctlWr.Execute(controllersData)
   333  	return
   334  }
   335  
   336  // generateControllers iterates through the API resources and generates the low level
   337  // controllers.
   338  func (g *Generator) generateSecurity() (err error) {
   339  	if len(g.API.SecuritySchemes) == 0 {
   340  		return nil
   341  	}
   342  
   343  	var (
   344  		secFile string
   345  		secWr   *SecurityWriter
   346  	)
   347  	{
   348  		secFile = filepath.Join(g.OutDir, "security.go")
   349  		secWr, err = NewSecurityWriter(secFile)
   350  		if err != nil {
   351  			return
   352  		}
   353  	}
   354  	defer func() {
   355  		secWr.Close()
   356  		if err == nil {
   357  			err = secWr.FormatCode()
   358  		}
   359  	}()
   360  	title := fmt.Sprintf("%s: Application Security", g.API.Context())
   361  	imports := []*codegen.ImportSpec{
   362  		codegen.SimpleImport("net/http"),
   363  		codegen.SimpleImport("errors"),
   364  		codegen.SimpleImport("context"),
   365  		codegen.SimpleImport("github.com/goadesign/goa"),
   366  	}
   367  	if err = secWr.WriteHeader(title, g.Target, imports); err != nil {
   368  		return err
   369  	}
   370  	g.genfiles = append(g.genfiles, secFile)
   371  	err = secWr.Execute(design.Design.SecuritySchemes)
   372  
   373  	return
   374  }
   375  
   376  // generateHrefs iterates through the API resources and generates the href factory methods.
   377  func (g *Generator) generateHrefs() (err error) {
   378  	var (
   379  		hrefFile string
   380  		resWr    *ResourcesWriter
   381  	)
   382  	{
   383  		hrefFile = filepath.Join(g.OutDir, "hrefs.go")
   384  		resWr, err = NewResourcesWriter(hrefFile)
   385  		if err != nil {
   386  			return
   387  		}
   388  	}
   389  	defer func() {
   390  		resWr.Close()
   391  		if err == nil {
   392  			err = resWr.FormatCode()
   393  		}
   394  	}()
   395  	title := fmt.Sprintf("%s: Application Resource Href Factories", g.API.Context())
   396  	imports := []*codegen.ImportSpec{
   397  		codegen.SimpleImport("fmt"),
   398  		codegen.SimpleImport("strings"),
   399  	}
   400  	if err = resWr.WriteHeader(title, g.Target, imports); err != nil {
   401  		return err
   402  	}
   403  	g.genfiles = append(g.genfiles, hrefFile)
   404  	err = g.API.IterateResources(func(r *design.ResourceDefinition) error {
   405  		m := g.API.MediaTypeWithIdentifier(r.MediaType)
   406  		var identifier string
   407  		if m != nil {
   408  			identifier = m.Identifier
   409  		} else {
   410  			identifier = "text/plain"
   411  		}
   412  		data := ResourceData{
   413  			Name:              codegen.Goify(r.Name, true),
   414  			Identifier:        identifier,
   415  			Description:       r.Description,
   416  			Type:              m,
   417  			CanonicalTemplate: codegen.CanonicalTemplate(r),
   418  			CanonicalParams:   codegen.CanonicalParams(r),
   419  		}
   420  		return resWr.Execute(&data)
   421  	})
   422  	return
   423  }
   424  
   425  // generateMediaTypes iterates through the media types and generate the data structures and
   426  // marshaling code.
   427  func (g *Generator) generateMediaTypes() (err error) {
   428  	var (
   429  		mtFile string
   430  		mtWr   *MediaTypesWriter
   431  	)
   432  	{
   433  		mtFile = filepath.Join(g.OutDir, "media_types.go")
   434  		mtWr, err = NewMediaTypesWriter(mtFile)
   435  		if err != nil {
   436  			return
   437  		}
   438  	}
   439  	defer func() {
   440  		mtWr.Close()
   441  		if err == nil {
   442  			err = mtWr.FormatCode()
   443  		}
   444  	}()
   445  	title := fmt.Sprintf("%s: Application Media Types", g.API.Context())
   446  	imports := []*codegen.ImportSpec{
   447  		codegen.SimpleImport("github.com/goadesign/goa"),
   448  		codegen.SimpleImport("fmt"),
   449  		codegen.SimpleImport("time"),
   450  		codegen.SimpleImport("unicode/utf8"),
   451  		codegen.NewImport("uuid", "github.com/gofrs/uuid"),
   452  	}
   453  	for _, v := range g.API.MediaTypes {
   454  		imports = codegen.AttributeImports(v.AttributeDefinition, imports, nil)
   455  	}
   456  	if err = mtWr.WriteHeader(title, g.Target, imports); err != nil {
   457  		return err
   458  	}
   459  	g.genfiles = append(g.genfiles, mtFile)
   460  	err = g.API.IterateMediaTypes(func(mt *design.MediaTypeDefinition) error {
   461  		if mt.IsError() {
   462  			return nil
   463  		}
   464  		if mt.Type.IsObject() || mt.Type.IsArray() {
   465  			return mtWr.Execute(mt)
   466  		}
   467  		return nil
   468  	})
   469  	return
   470  }
   471  
   472  // generateUserTypes iterates through the user types and generates the data structures and
   473  // marshaling code.
   474  func (g *Generator) generateUserTypes() (err error) {
   475  	var (
   476  		utFile string
   477  		utWr   *UserTypesWriter
   478  	)
   479  	{
   480  		utFile = filepath.Join(g.OutDir, "user_types.go")
   481  		utWr, err = NewUserTypesWriter(utFile)
   482  		if err != nil {
   483  			return
   484  		}
   485  	}
   486  	defer func() {
   487  		utWr.Close()
   488  		if err == nil {
   489  			err = utWr.FormatCode()
   490  		}
   491  	}()
   492  	title := fmt.Sprintf("%s: Application User Types", g.API.Context())
   493  	imports := []*codegen.ImportSpec{
   494  		codegen.SimpleImport("fmt"),
   495  		codegen.SimpleImport("mime/multipart"),
   496  		codegen.SimpleImport("time"),
   497  		codegen.SimpleImport("unicode/utf8"),
   498  		codegen.SimpleImport("github.com/goadesign/goa"),
   499  		codegen.NewImport("uuid", "github.com/gofrs/uuid"),
   500  	}
   501  	for _, v := range g.API.Types {
   502  		imports = codegen.AttributeImports(v.AttributeDefinition, imports, nil)
   503  	}
   504  	if err = utWr.WriteHeader(title, g.Target, imports); err != nil {
   505  		return err
   506  	}
   507  	g.genfiles = append(g.genfiles, utFile)
   508  	err = g.API.IterateUserTypes(func(t *design.UserTypeDefinition) error {
   509  		return utWr.Execute(t)
   510  	})
   511  	return
   512  }