github.com/goldeneggg/goa@v1.3.1/goagen/gen_app/generator.go (about)

     1  package genapp
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"sort"
     9  
    10  	"github.com/goadesign/goa/design"
    11  	"github.com/goadesign/goa/goagen/codegen"
    12  	"github.com/goadesign/goa/goagen/utils"
    13  )
    14  
    15  //NewGenerator returns an initialized instance of an Application Generator
    16  func NewGenerator(options ...Option) *Generator {
    17  	g := &Generator{}
    18  	g.validator = codegen.NewValidator()
    19  
    20  	for _, option := range options {
    21  		option(g)
    22  	}
    23  
    24  	return g
    25  }
    26  
// Generator is the application code generator.
//
// Generate writes contexts.go, controllers.go, hrefs.go, media_types.go,
// user_types.go, security.go (only when the API defines security schemes)
// and, unless NoTest is set, the output of generateResourceTest into OutDir.
// Build one with NewGenerator so that its validator is initialized.
type Generator struct {
	// API is the design the code is generated from; Generate returns an
	// error when it is nil.
	API *design.APIDefinition
	// OutDir is the output directory; Generate deletes and recreates it.
	OutDir string
	// Target is the package name written in the generated file headers;
	// Generate also marks it as reserved in codegen.Reserved.
	Target string
	// NoTest skips the resource test generation step when true.
	NoTest bool
	// genfiles lists the paths generated so far, starting with OutDir.
	// Cleanup only removes OutDir when this list is non-empty.
	genfiles []string
	// validator is the validation code generator, set by NewGenerator and by
	// the package-level Generate.
	// NOTE(review): not read in this file; presumably used by the resource
	// test generation — confirm.
	validator *codegen.Validator
}
    36  
    37  // Generate is the generator entry point called by the meta generator.
    38  func Generate() (files []string, err error) {
    39  	var (
    40  		outDir, toolDir, target, ver string
    41  		notest, notool, regen        bool
    42  	)
    43  
    44  	set := flag.NewFlagSet("app", flag.PanicOnError)
    45  	set.String("design", "", "")
    46  	set.StringVar(&outDir, "out", "", "")
    47  	set.StringVar(&target, "pkg", "app", "")
    48  	set.StringVar(&ver, "version", "", "")
    49  	set.StringVar(&toolDir, "tooldir", "tool", "")
    50  	set.BoolVar(&notest, "notest", false, "")
    51  	set.BoolVar(&notool, "notool", false, "")
    52  	set.BoolVar(&regen, "regen", false, "")
    53  	set.Bool("force", false, "")
    54  	set.Parse(os.Args[1:])
    55  	outDir = filepath.Join(outDir, target)
    56  
    57  	if err := codegen.CheckVersion(ver); err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	target = codegen.Goify(target, false)
    62  	g := &Generator{OutDir: outDir, Target: target, NoTest: notest, API: design.Design, validator: codegen.NewValidator()}
    63  
    64  	return g.Generate()
    65  }
    66  
    67  // Generate the application code, implement codegen.Generator.
    68  func (g *Generator) Generate() (_ []string, err error) {
    69  	if g.API == nil {
    70  		return nil, fmt.Errorf("missing API definition, make sure design is properly initialized")
    71  	}
    72  
    73  	go utils.Catch(nil, func() { g.Cleanup() })
    74  
    75  	defer func() {
    76  		if err != nil {
    77  			g.Cleanup()
    78  		}
    79  	}()
    80  
    81  	codegen.Reserved[g.Target] = true
    82  
    83  	os.RemoveAll(g.OutDir)
    84  
    85  	if err := os.MkdirAll(g.OutDir, 0755); err != nil {
    86  		return nil, err
    87  	}
    88  	g.genfiles = []string{g.OutDir}
    89  	if err := g.generateContexts(); err != nil {
    90  		return nil, err
    91  	}
    92  	if err := g.generateControllers(); err != nil {
    93  		return nil, err
    94  	}
    95  	if err := g.generateSecurity(); err != nil {
    96  		return nil, err
    97  	}
    98  	if err := g.generateHrefs(); err != nil {
    99  		return nil, err
   100  	}
   101  	if err := g.generateMediaTypes(); err != nil {
   102  		return nil, err
   103  	}
   104  	if err := g.generateUserTypes(); err != nil {
   105  		return nil, err
   106  	}
   107  	if !g.NoTest {
   108  		if err := g.generateResourceTest(); err != nil {
   109  			return nil, err
   110  		}
   111  	}
   112  
   113  	return g.genfiles, nil
   114  }
   115  
   116  // Cleanup removes the entire "app" directory if it was created by this generator.
   117  func (g *Generator) Cleanup() {
   118  	if len(g.genfiles) == 0 {
   119  		return
   120  	}
   121  	os.RemoveAll(g.OutDir)
   122  	g.genfiles = nil
   123  }
   124  
   125  // generateContexts iterates through the API resources and actions and generates the action
   126  // contexts.
   127  func (g *Generator) generateContexts() (err error) {
   128  	var (
   129  		ctxFile string
   130  		ctxWr   *ContextsWriter
   131  	)
   132  	{
   133  		ctxFile = filepath.Join(g.OutDir, "contexts.go")
   134  		ctxWr, err = NewContextsWriter(ctxFile)
   135  		if err != nil {
   136  			return
   137  		}
   138  	}
   139  	defer func() {
   140  		ctxWr.Close()
   141  		if err == nil {
   142  			err = ctxWr.FormatCode()
   143  		}
   144  	}()
   145  	title := fmt.Sprintf("%s: Application Contexts", g.API.Context())
   146  	imports := []*codegen.ImportSpec{
   147  		codegen.SimpleImport("fmt"),
   148  		codegen.SimpleImport("net/http"),
   149  		codegen.SimpleImport("strconv"),
   150  		codegen.SimpleImport("strings"),
   151  		codegen.SimpleImport("time"),
   152  		codegen.SimpleImport("unicode/utf8"),
   153  		codegen.SimpleImport("github.com/goadesign/goa"),
   154  		codegen.NewImport("uuid", "github.com/satori/go.uuid"),
   155  		codegen.SimpleImport("context"),
   156  	}
   157  	g.API.IterateResources(func(r *design.ResourceDefinition) error {
   158  		return r.IterateActions(func(a *design.ActionDefinition) error {
   159  			if a.Payload != nil {
   160  				imports = codegen.AttributeImports(a.Payload.AttributeDefinition, imports, nil)
   161  			}
   162  			return nil
   163  		})
   164  	})
   165  
   166  	g.genfiles = append(g.genfiles, ctxFile)
   167  	if err = ctxWr.WriteHeader(title, g.Target, imports); err != nil {
   168  		return
   169  	}
   170  	err = g.API.IterateResources(func(r *design.ResourceDefinition) error {
   171  		return r.IterateActions(func(a *design.ActionDefinition) error {
   172  			ctxName := codegen.Goify(a.Name, true) + codegen.Goify(a.Parent.Name, true) + "Context"
   173  			headers := &design.AttributeDefinition{
   174  				Type: design.Object{},
   175  			}
   176  			if r.Headers != nil {
   177  				headers.Merge(r.Headers)
   178  				headers.Validation = r.Headers.Validation
   179  			}
   180  			if a.Headers != nil {
   181  				headers.Merge(a.Headers)
   182  				headers.Validation = a.Headers.Validation
   183  			}
   184  			if headers != nil && len(headers.Type.ToObject()) == 0 {
   185  				headers = nil // So that {{if .Headers}} returns false in templates
   186  			}
   187  			params := a.AllParams()
   188  			if params != nil && len(params.Type.ToObject()) == 0 {
   189  				params = nil // So that {{if .Params}} returns false in templates
   190  			}
   191  
   192  			non101 := make(map[string]*design.ResponseDefinition)
   193  			for k, v := range a.Responses {
   194  				if v.Status != 101 {
   195  					non101[k] = v
   196  				}
   197  			}
   198  			ctxData := ContextTemplateData{
   199  				Name:         ctxName,
   200  				ResourceName: r.Name,
   201  				ActionName:   a.Name,
   202  				Payload:      a.Payload,
   203  				Params:       params,
   204  				Headers:      headers,
   205  				Routes:       a.Routes,
   206  				Responses:    non101,
   207  				API:          g.API,
   208  				DefaultPkg:   g.Target,
   209  				Security:     a.Security,
   210  			}
   211  			return ctxWr.Execute(&ctxData)
   212  		})
   213  	})
   214  	return
   215  }
   216  
   217  // generateControllers iterates through the API resources and generates the low level
   218  // controllers.
   219  func (g *Generator) generateControllers() (err error) {
   220  	var (
   221  		ctlFile string
   222  		ctlWr   *ControllersWriter
   223  	)
   224  	{
   225  		ctlFile = filepath.Join(g.OutDir, "controllers.go")
   226  		ctlWr, err = NewControllersWriter(ctlFile)
   227  		if err != nil {
   228  			return
   229  		}
   230  	}
   231  	defer func() {
   232  		ctlWr.Close()
   233  		if err == nil {
   234  			err = ctlWr.FormatCode()
   235  		}
   236  	}()
   237  	title := fmt.Sprintf("%s: Application Controllers", g.API.Context())
   238  	imports := []*codegen.ImportSpec{
   239  		codegen.SimpleImport("net/http"),
   240  		codegen.SimpleImport("fmt"),
   241  		codegen.SimpleImport("context"),
   242  		codegen.SimpleImport("github.com/goadesign/goa"),
   243  		codegen.SimpleImport("github.com/goadesign/goa/cors"),
   244  		codegen.SimpleImport("regexp"),
   245  	}
   246  	encoders, err := BuildEncoders(g.API.Produces, true)
   247  	if err != nil {
   248  		return err
   249  	}
   250  	decoders, err := BuildEncoders(g.API.Consumes, false)
   251  	if err != nil {
   252  		return err
   253  	}
   254  	encoderImports := make(map[string]bool)
   255  	for _, data := range encoders {
   256  		encoderImports[data.PackagePath] = true
   257  	}
   258  	for _, data := range decoders {
   259  		encoderImports[data.PackagePath] = true
   260  	}
   261  	var packagePaths []string
   262  	for packagePath := range encoderImports {
   263  		if packagePath != "github.com/goadesign/goa" {
   264  			packagePaths = append(packagePaths, packagePath)
   265  		}
   266  	}
   267  	sort.Strings(packagePaths)
   268  	for _, packagePath := range packagePaths {
   269  		imports = append(imports, codegen.SimpleImport(packagePath))
   270  	}
   271  	if err = ctlWr.WriteHeader(title, g.Target, imports); err != nil {
   272  		return err
   273  	}
   274  	if err = ctlWr.WriteInitService(encoders, decoders); err != nil {
   275  		return err
   276  	}
   277  
   278  	g.genfiles = append(g.genfiles, ctlFile)
   279  	var controllersData []*ControllerTemplateData
   280  	g.API.IterateResources(func(r *design.ResourceDefinition) error {
   281  		// Create file servers for all directory file servers that serve index.html.
   282  		fileServers := r.FileServers
   283  		for _, fs := range r.FileServers {
   284  			if fs.IsDir() {
   285  				rpath := design.WildcardRegex.ReplaceAllLiteralString(fs.RequestPath, "")
   286  				rpath += "/"
   287  				fileServers = append(fileServers, &design.FileServerDefinition{
   288  					Parent:      fs.Parent,
   289  					Description: fs.Description,
   290  					Docs:        fs.Docs,
   291  					FilePath:    filepath.Join(fs.FilePath, "index.html"),
   292  					RequestPath: rpath,
   293  					Metadata:    fs.Metadata,
   294  					Security:    fs.Security,
   295  				})
   296  			}
   297  		}
   298  		data := &ControllerTemplateData{
   299  			API:            g.API,
   300  			Resource:       codegen.Goify(r.Name, true),
   301  			PreflightPaths: r.PreflightPaths(),
   302  			FileServers:    fileServers,
   303  		}
   304  		r.IterateActions(func(a *design.ActionDefinition) error {
   305  			context := fmt.Sprintf("%s%sContext", codegen.Goify(a.Name, true), codegen.Goify(r.Name, true))
   306  			unmarshal := fmt.Sprintf("unmarshal%s%sPayload", codegen.Goify(a.Name, true), codegen.Goify(r.Name, true))
   307  			action := map[string]interface{}{
   308  				"Name":            codegen.Goify(a.Name, true),
   309  				"DesignName":      a.Name,
   310  				"Routes":          a.Routes,
   311  				"Context":         context,
   312  				"Unmarshal":       unmarshal,
   313  				"Payload":         a.Payload,
   314  				"PayloadOptional": a.PayloadOptional,
   315  				"Security":        a.Security,
   316  			}
   317  			data.Actions = append(data.Actions, action)
   318  			return nil
   319  		})
   320  		if len(data.Actions) > 0 || len(data.FileServers) > 0 {
   321  			data.Encoders = encoders
   322  			data.Decoders = decoders
   323  			data.Origins = r.AllOrigins()
   324  			controllersData = append(controllersData, data)
   325  		}
   326  		return nil
   327  	})
   328  	err = ctlWr.Execute(controllersData)
   329  	return
   330  }
   331  
   332  // generateControllers iterates through the API resources and generates the low level
   333  // controllers.
   334  func (g *Generator) generateSecurity() (err error) {
   335  	if len(g.API.SecuritySchemes) == 0 {
   336  		return nil
   337  	}
   338  
   339  	var (
   340  		secFile string
   341  		secWr   *SecurityWriter
   342  	)
   343  	{
   344  		secFile = filepath.Join(g.OutDir, "security.go")
   345  		secWr, err = NewSecurityWriter(secFile)
   346  		if err != nil {
   347  			return
   348  		}
   349  	}
   350  	defer func() {
   351  		secWr.Close()
   352  		if err == nil {
   353  			err = secWr.FormatCode()
   354  		}
   355  	}()
   356  	title := fmt.Sprintf("%s: Application Security", g.API.Context())
   357  	imports := []*codegen.ImportSpec{
   358  		codegen.SimpleImport("net/http"),
   359  		codegen.SimpleImport("errors"),
   360  		codegen.SimpleImport("context"),
   361  		codegen.SimpleImport("github.com/goadesign/goa"),
   362  	}
   363  	if err = secWr.WriteHeader(title, g.Target, imports); err != nil {
   364  		return err
   365  	}
   366  	g.genfiles = append(g.genfiles, secFile)
   367  	err = secWr.Execute(design.Design.SecuritySchemes)
   368  
   369  	return
   370  }
   371  
   372  // generateHrefs iterates through the API resources and generates the href factory methods.
   373  func (g *Generator) generateHrefs() (err error) {
   374  	var (
   375  		hrefFile string
   376  		resWr    *ResourcesWriter
   377  	)
   378  	{
   379  		hrefFile = filepath.Join(g.OutDir, "hrefs.go")
   380  		resWr, err = NewResourcesWriter(hrefFile)
   381  		if err != nil {
   382  			return
   383  		}
   384  	}
   385  	defer func() {
   386  		resWr.Close()
   387  		if err == nil {
   388  			err = resWr.FormatCode()
   389  		}
   390  	}()
   391  	title := fmt.Sprintf("%s: Application Resource Href Factories", g.API.Context())
   392  	imports := []*codegen.ImportSpec{
   393  		codegen.SimpleImport("fmt"),
   394  		codegen.SimpleImport("strings"),
   395  	}
   396  	if err = resWr.WriteHeader(title, g.Target, imports); err != nil {
   397  		return err
   398  	}
   399  	g.genfiles = append(g.genfiles, hrefFile)
   400  	err = g.API.IterateResources(func(r *design.ResourceDefinition) error {
   401  		m := g.API.MediaTypeWithIdentifier(r.MediaType)
   402  		var identifier string
   403  		if m != nil {
   404  			identifier = m.Identifier
   405  		} else {
   406  			identifier = "text/plain"
   407  		}
   408  		data := ResourceData{
   409  			Name:              codegen.Goify(r.Name, true),
   410  			Identifier:        identifier,
   411  			Description:       r.Description,
   412  			Type:              m,
   413  			CanonicalTemplate: codegen.CanonicalTemplate(r),
   414  			CanonicalParams:   codegen.CanonicalParams(r),
   415  		}
   416  		return resWr.Execute(&data)
   417  	})
   418  	return
   419  }
   420  
   421  // generateMediaTypes iterates through the media types and generate the data structures and
   422  // marshaling code.
   423  func (g *Generator) generateMediaTypes() (err error) {
   424  	var (
   425  		mtFile string
   426  		mtWr   *MediaTypesWriter
   427  	)
   428  	{
   429  		mtFile = filepath.Join(g.OutDir, "media_types.go")
   430  		mtWr, err = NewMediaTypesWriter(mtFile)
   431  		if err != nil {
   432  			return
   433  		}
   434  	}
   435  	defer func() {
   436  		mtWr.Close()
   437  		if err == nil {
   438  			err = mtWr.FormatCode()
   439  		}
   440  	}()
   441  	title := fmt.Sprintf("%s: Application Media Types", g.API.Context())
   442  	imports := []*codegen.ImportSpec{
   443  		codegen.SimpleImport("github.com/goadesign/goa"),
   444  		codegen.SimpleImport("fmt"),
   445  		codegen.SimpleImport("time"),
   446  		codegen.SimpleImport("unicode/utf8"),
   447  		codegen.NewImport("uuid", "github.com/satori/go.uuid"),
   448  	}
   449  	for _, v := range g.API.MediaTypes {
   450  		imports = codegen.AttributeImports(v.AttributeDefinition, imports, nil)
   451  	}
   452  	if err = mtWr.WriteHeader(title, g.Target, imports); err != nil {
   453  		return err
   454  	}
   455  	g.genfiles = append(g.genfiles, mtFile)
   456  	err = g.API.IterateMediaTypes(func(mt *design.MediaTypeDefinition) error {
   457  		if mt.IsError() {
   458  			return nil
   459  		}
   460  		if mt.Type.IsObject() || mt.Type.IsArray() {
   461  			return mtWr.Execute(mt)
   462  		}
   463  		return nil
   464  	})
   465  	return
   466  }
   467  
   468  // generateUserTypes iterates through the user types and generates the data structures and
   469  // marshaling code.
   470  func (g *Generator) generateUserTypes() (err error) {
   471  	var (
   472  		utFile string
   473  		utWr   *UserTypesWriter
   474  	)
   475  	{
   476  		utFile = filepath.Join(g.OutDir, "user_types.go")
   477  		utWr, err = NewUserTypesWriter(utFile)
   478  		if err != nil {
   479  			return
   480  		}
   481  	}
   482  	defer func() {
   483  		utWr.Close()
   484  		if err == nil {
   485  			err = utWr.FormatCode()
   486  		}
   487  	}()
   488  	title := fmt.Sprintf("%s: Application User Types", g.API.Context())
   489  	imports := []*codegen.ImportSpec{
   490  		codegen.SimpleImport("fmt"),
   491  		codegen.SimpleImport("time"),
   492  		codegen.SimpleImport("unicode/utf8"),
   493  		codegen.SimpleImport("github.com/goadesign/goa"),
   494  		codegen.NewImport("uuid", "github.com/satori/go.uuid"),
   495  	}
   496  	for _, v := range g.API.Types {
   497  		imports = codegen.AttributeImports(v.AttributeDefinition, imports, nil)
   498  	}
   499  	if err = utWr.WriteHeader(title, g.Target, imports); err != nil {
   500  		return err
   501  	}
   502  	g.genfiles = append(g.genfiles, utFile)
   503  	err = g.API.IterateUserTypes(func(t *design.UserTypeDefinition) error {
   504  		return utWr.Execute(t)
   505  	})
   506  	return
   507  }