github.com/brycereitano/goa@v0.0.0-20170315073847-8ffa6c85e265/goagen/gen_client/generator.go (about)

     1  package genclient
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"os"
     7  	"path"
     8  	"path/filepath"
     9  	"sort"
    10  	"strings"
    11  	"text/template"
    12  
    13  	"github.com/goadesign/goa/design"
    14  	"github.com/goadesign/goa/dslengine"
    15  	"github.com/goadesign/goa/goagen/codegen"
    16  	"github.com/goadesign/goa/goagen/gen_app"
    17  	"github.com/goadesign/goa/goagen/utils"
    18  )
    19  
    20  // Filename used to generate all data types (without the ".go" extension)
    21  const typesFileName = "datatypes"
    22  
    23  //NewGenerator returns an initialized instance of a Go Client Generator
    24  func NewGenerator(options ...Option) *Generator {
    25  	g := &Generator{}
    26  
    27  	for _, option := range options {
    28  		option(g)
    29  	}
    30  
    31  	return g
    32  }
    33  
    34  // Generator is the application code generator.
    35  type Generator struct {
    36  	API            *design.APIDefinition // The API definition
    37  	OutDir         string                // Path to output directory
    38  	Target         string                // Name of generated package
    39  	ToolDirName    string                // Name of tool directory where CLI main is generated once
    40  	Tool           string                // Name of CLI tool
    41  	NoTool         bool                  // Whether to skip tool generation
    42  	genfiles       []string
    43  	encoders       []*genapp.EncoderTemplateData
    44  	decoders       []*genapp.EncoderTemplateData
    45  	encoderImports []string
    46  }
    47  
    48  // Generate is the generator entry point called by the meta generator.
    49  func Generate() (files []string, err error) {
    50  	var (
    51  		outDir, target, toolDir, tool, ver string
    52  		notool                             bool
    53  	)
    54  	dtool := defaultToolName(design.Design)
    55  
    56  	set := flag.NewFlagSet("client", flag.PanicOnError)
    57  	set.StringVar(&outDir, "out", "", "")
    58  	set.StringVar(&target, "pkg", "client", "")
    59  	set.StringVar(&toolDir, "tooldir", "tool", "")
    60  	set.StringVar(&tool, "tool", dtool, "")
    61  	set.StringVar(&ver, "version", "", "")
    62  	set.BoolVar(&notool, "notool", false, "")
    63  	set.String("design", "", "")
    64  	set.Bool("force", false, "")
    65  	set.Bool("notest", false, "")
    66  	set.Parse(os.Args[1:])
    67  
    68  	// First check compatibility
    69  	if err := codegen.CheckVersion(ver); err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	// Now proceed
    74  	target = codegen.Goify(target, false)
    75  	g := &Generator{OutDir: outDir, Target: target, ToolDirName: toolDir, Tool: tool, NoTool: notool, API: design.Design}
    76  
    77  	return g.Generate()
    78  }
    79  
// Generate generates the client package and, unless g.NoTool is set, the CLI
// tool sources. It returns the paths recorded in g.genfiles (generated
// directories and files).
//
// Empty Target, ToolDirName and Tool fields are defaulted to "client", "tool"
// and defaultToolName(g.API). When any step returns an error, the deferred
// function below calls Cleanup to remove whatever was generated so far.
func (g *Generator) Generate() (_ []string, err error) {
	if g.API == nil {
		return nil, fmt.Errorf("missing API definition, make sure design is properly initialized")
	}

	// Hand Cleanup to utils.Catch in the background; presumably it runs
	// Cleanup when the process is interrupted — TODO confirm against
	// utils.Catch.
	go utils.Catch(nil, func() { g.Cleanup() })

	// err is the named result, so this observes the error of any return below.
	defer func() {
		if err != nil {
			g.Cleanup()
		}
	}()

	// firstNonEmpty returns the first non-empty argument, or "" if all are empty.
	firstNonEmpty := func(args ...string) string {
		for _, value := range args {
			if len(value) > 0 {
				return value
			}
		}
		return ""
	}

	g.Target = firstNonEmpty(g.Target, "client")
	g.ToolDirName = firstNonEmpty(g.ToolDirName, "tool")
	g.Tool = firstNonEmpty(g.Tool, defaultToolName(g.API))

	// Register the target package name in codegen.Reserved (presumably so
	// generated identifiers do not collide with it).
	codegen.Reserved[g.Target] = true

	// Setup output directories as needed.
	// The tool directory (<OutDir>/<ToolDirName>/<Tool>) is only created when
	// missing so that an existing main.go survives; the cli and client package
	// directories are wiped and recreated on every run.
	var pkgDir, toolDir, cliDir string
	{
		if !g.NoTool {
			toolDir = filepath.Join(g.OutDir, g.ToolDirName, g.Tool)
			if _, err = os.Stat(toolDir); err != nil {
				if err = os.MkdirAll(toolDir, 0755); err != nil {
					return
				}
			}

			cliDir = filepath.Join(g.OutDir, g.ToolDirName, "cli")
			if err = os.RemoveAll(cliDir); err != nil {
				return
			}
			if err = os.MkdirAll(cliDir, 0755); err != nil {
				return
			}
		}

		pkgDir = filepath.Join(g.OutDir, g.Target)
		if err = os.RemoveAll(pkgDir); err != nil {
			return
		}
		if err = os.MkdirAll(pkgDir, 0755); err != nil {
			return
		}
	}

	// Setup generation: the helper functions shared by all client templates
	// and the import path of the generated client package.
	var funcs template.FuncMap
	var clientPkg string
	{
		funcs = template.FuncMap{
			"add":                func(a, b int) int { return a + b },
			"cmdFieldType":       cmdFieldType,
			"defaultPath":        defaultPath,
			"escapeBackticks":    escapeBackticks,
			"goify":              codegen.Goify,
			"gotypedef":          codegen.GoTypeDef,
			"gotypedesc":         codegen.GoTypeDesc,
			"gotypename":         codegen.GoTypeName,
			"gotyperef":          codegen.GoTypeRef,
			"gotyperefext":       goTypeRefExt,
			"join":               join,
			"joinStrings":        strings.Join,
			"multiComment":       multiComment,
			"pathParams":         pathParams,
			"pathTemplate":       pathTemplate,
			"signerType":         signerType,
			"tempvar":            codegen.Tempvar,
			"title":              strings.Title,
			"toString":           toString,
			"typeName":           typeName,
			"format":             format,
			"handleSpecialTypes": handleSpecialTypes,
		}
		clientPkg, err = codegen.PackagePath(pkgDir)
		if err != nil {
			return
		}
		// arrayToStringTmpl is package-level state read by toString; it is
		// parsed here because it depends on funcs.
		arrayToStringTmpl = template.Must(template.New("client").Funcs(funcs).Parse(arrayToStringT))
	}

	if !g.NoTool {
		var cliPkg string
		cliPkg, err = codegen.PackagePath(cliDir)
		if err != nil {
			return
		}

		// Generate tool/main.go (only once): skipped when the file already exists.
		mainFile := filepath.Join(toolDir, "main.go")
		// This err shadows the named result; the explicit "return nil, err"
		// below still hands the error to the deferred cleanup.
		if _, err := os.Stat(mainFile); err != nil {
			g.genfiles = append(g.genfiles, toolDir)
			if err = g.generateMain(mainFile, clientPkg, cliPkg, funcs); err != nil {
				return nil, err
			}
		}

		// Generate tool/cli/commands.go
		g.genfiles = append(g.genfiles, cliDir)
		if err = g.generateCommands(filepath.Join(cliDir, "commands.go"), clientPkg, funcs); err != nil {
			return
		}
	}

	// Generate client/client.go
	g.genfiles = append(g.genfiles, pkgDir)
	if err = g.generateClient(filepath.Join(pkgDir, "client.go"), clientPkg, funcs); err != nil {
		return
	}

	// Generate client/$res.go and types.go
	if err = g.generateClientResources(pkgDir, clientPkg, funcs); err != nil {
		return
	}

	return g.genfiles, nil
}
   209  
   210  func defaultToolName(api *design.APIDefinition) string {
   211  	if api == nil {
   212  		return ""
   213  	}
   214  	return strings.Replace(strings.ToLower(api.Name), " ", "-", -1) + "-cli"
   215  }
   216  
   217  // Cleanup removes all the files generated by this generator during the last invokation of Generate.
   218  func (g *Generator) Cleanup() {
   219  	for _, f := range g.genfiles {
   220  		os.Remove(f)
   221  	}
   222  	g.genfiles = nil
   223  }
   224  
   225  func (g *Generator) generateClient(clientFile string, clientPkg string, funcs template.FuncMap) error {
   226  	file, err := codegen.SourceFileFor(clientFile)
   227  	if err != nil {
   228  		return err
   229  	}
   230  	clientTmpl := template.Must(template.New("client").Funcs(funcs).Parse(clientTmpl))
   231  
   232  	// Compute list of encoders and decoders
   233  	encoders, err := genapp.BuildEncoders(g.API.Produces, true)
   234  	if err != nil {
   235  		return err
   236  	}
   237  	decoders, err := genapp.BuildEncoders(g.API.Consumes, false)
   238  	if err != nil {
   239  		return err
   240  	}
   241  	im := make(map[string]bool)
   242  	for _, data := range encoders {
   243  		im[data.PackagePath] = true
   244  	}
   245  	for _, data := range decoders {
   246  		im[data.PackagePath] = true
   247  	}
   248  	var packagePaths []string
   249  	for packagePath := range im {
   250  		if packagePath != "github.com/goadesign/goa" {
   251  			packagePaths = append(packagePaths, packagePath)
   252  		}
   253  	}
   254  	sort.Strings(packagePaths)
   255  
   256  	// Setup codegen
   257  	imports := []*codegen.ImportSpec{
   258  		codegen.SimpleImport("net/http"),
   259  		codegen.SimpleImport("github.com/goadesign/goa"),
   260  		codegen.NewImport("goaclient", "github.com/goadesign/goa/client"),
   261  		codegen.NewImport("uuid", "github.com/goadesign/goa/uuid"),
   262  	}
   263  	for _, packagePath := range packagePaths {
   264  		imports = append(imports, codegen.SimpleImport(packagePath))
   265  	}
   266  	title := fmt.Sprintf("%s: Client", g.API.Context())
   267  	if err := file.WriteHeader(title, g.Target, imports); err != nil {
   268  		return err
   269  	}
   270  	g.genfiles = append(g.genfiles, clientFile)
   271  
   272  	// Generate
   273  	data := struct {
   274  		API      *design.APIDefinition
   275  		Encoders []*genapp.EncoderTemplateData
   276  		Decoders []*genapp.EncoderTemplateData
   277  	}{
   278  		API:      g.API,
   279  		Encoders: encoders,
   280  		Decoders: decoders,
   281  	}
   282  	if err := clientTmpl.Execute(file, data); err != nil {
   283  		return err
   284  	}
   285  
   286  	return file.FormatCode()
   287  }
   288  
   289  func (g *Generator) generateClientResources(pkgDir, clientPkg string, funcs template.FuncMap) error {
   290  	err := g.API.IterateResources(func(res *design.ResourceDefinition) error {
   291  		return g.generateResourceClient(pkgDir, res, funcs)
   292  	})
   293  	if err != nil {
   294  		return err
   295  	}
   296  	if err := g.generateUserTypes(pkgDir); err != nil {
   297  		return err
   298  	}
   299  
   300  	return g.generateMediaTypes(pkgDir, funcs)
   301  }
   302  
   303  func (g *Generator) generateResourceClient(pkgDir string, res *design.ResourceDefinition, funcs template.FuncMap) error {
   304  	payloadTmpl := template.Must(template.New("payload").Funcs(funcs).Parse(payloadTmpl))
   305  	pathTmpl := template.Must(template.New("pathTemplate").Funcs(funcs).Parse(pathTmpl))
   306  
   307  	resFilename := codegen.SnakeCase(res.Name)
   308  	if resFilename == typesFileName {
   309  		// Avoid clash with datatypes.go
   310  		resFilename += "_client"
   311  	}
   312  	filename := filepath.Join(pkgDir, resFilename+".go")
   313  	file, err := codegen.SourceFileFor(filename)
   314  	if err != nil {
   315  		return err
   316  	}
   317  	imports := []*codegen.ImportSpec{
   318  		codegen.SimpleImport("bytes"),
   319  		codegen.SimpleImport("encoding/json"),
   320  		codegen.SimpleImport("fmt"),
   321  		codegen.SimpleImport("io"),
   322  		codegen.SimpleImport("io/ioutil"),
   323  		codegen.SimpleImport("net/http"),
   324  		codegen.SimpleImport("net/url"),
   325  		codegen.SimpleImport("os"),
   326  		codegen.SimpleImport("path"),
   327  		codegen.SimpleImport("strconv"),
   328  		codegen.SimpleImport("strings"),
   329  		codegen.SimpleImport("time"),
   330  		codegen.SimpleImport("golang.org/x/net/context"),
   331  		codegen.SimpleImport("golang.org/x/net/websocket"),
   332  		codegen.NewImport("uuid", "github.com/goadesign/goa/uuid"),
   333  	}
   334  	title := fmt.Sprintf("%s: %s Resource Client", g.API.Context(), res.Name)
   335  	if err := file.WriteHeader(title, g.Target, imports); err != nil {
   336  		return err
   337  	}
   338  	g.genfiles = append(g.genfiles, filename)
   339  
   340  	err = res.IterateFileServers(func(fs *design.FileServerDefinition) error {
   341  		return g.generateFileServer(file, fs, funcs)
   342  	})
   343  
   344  	err = res.IterateActions(func(action *design.ActionDefinition) error {
   345  		if action.Payload != nil {
   346  			found := false
   347  			typeName := action.Payload.TypeName
   348  			for _, t := range design.Design.Types {
   349  				if t.TypeName == typeName {
   350  					found = true
   351  					break
   352  				}
   353  			}
   354  			if !found {
   355  				if err := payloadTmpl.Execute(file, action); err != nil {
   356  					return err
   357  				}
   358  			}
   359  		}
   360  		for i, r := range action.Routes {
   361  			routeParams := r.Params()
   362  			var pd []*paramData
   363  
   364  			for _, p := range routeParams {
   365  				requiredParams, _ := initParams(&design.AttributeDefinition{
   366  					Type: &design.Object{
   367  						p: action.Params.Type.ToObject()[p],
   368  					},
   369  					Validation: &dslengine.ValidationDefinition{
   370  						Required: routeParams,
   371  					},
   372  				})
   373  				pd = append(pd, requiredParams...)
   374  			}
   375  
   376  			data := struct {
   377  				Route  *design.RouteDefinition
   378  				Index  int
   379  				Params []*paramData
   380  			}{
   381  				Route:  r,
   382  				Index:  i,
   383  				Params: pd,
   384  			}
   385  			if err := pathTmpl.Execute(file, data); err != nil {
   386  				return err
   387  			}
   388  		}
   389  		return g.generateActionClient(action, file, funcs)
   390  	})
   391  	if err != nil {
   392  		return err
   393  	}
   394  
   395  	return file.FormatCode()
   396  }
   397  
   398  func (g *Generator) generateFileServer(file *codegen.SourceFile, fs *design.FileServerDefinition, funcs template.FuncMap) error {
   399  	var (
   400  		dir string
   401  
   402  		fsTmpl = template.Must(template.New("fileserver").Funcs(funcs).Parse(fsTmpl))
   403  		name   = g.fileServerMethod(fs)
   404  		wcs    = design.ExtractWildcards(fs.RequestPath)
   405  		scheme = "http"
   406  	)
   407  
   408  	if len(wcs) > 0 {
   409  		dir = "/"
   410  		fileElems := filepath.SplitList(fs.FilePath)
   411  		if len(fileElems) > 1 {
   412  			dir = fileElems[len(fileElems)-2]
   413  		}
   414  	}
   415  	if len(design.Design.Schemes) > 0 {
   416  		scheme = design.Design.Schemes[0]
   417  	}
   418  	requestDir, _ := path.Split(fs.RequestPath)
   419  
   420  	data := struct {
   421  		Name            string // Download functionn name
   422  		RequestPath     string // File server request path
   423  		FilePath        string // File server file path
   424  		FileName        string // Filename being download if request path has no wildcard
   425  		DirName         string // Parent directory name if request path has wildcard
   426  		RequestDir      string // Request path without wildcard suffix
   427  		CanonicalScheme string // HTTP scheme
   428  	}{
   429  		Name:            name,
   430  		RequestPath:     fs.RequestPath,
   431  		FilePath:        fs.FilePath,
   432  		FileName:        filepath.Base(fs.FilePath),
   433  		DirName:         dir,
   434  		RequestDir:      requestDir,
   435  		CanonicalScheme: scheme,
   436  	}
   437  	return fsTmpl.Execute(file, data)
   438  }
   439  
   440  func (g *Generator) generateActionClient(action *design.ActionDefinition, file *codegen.SourceFile, funcs template.FuncMap) error {
   441  	var (
   442  		params        []string
   443  		names         []string
   444  		queryParams   []*paramData
   445  		headers       []*paramData
   446  		signer        string
   447  		clientsTmpl   = template.Must(template.New("clients").Funcs(funcs).Parse(clientsTmpl))
   448  		requestsTmpl  = template.Must(template.New("requests").Funcs(funcs).Parse(requestsTmpl))
   449  		clientsWSTmpl = template.Must(template.New("clientsws").Funcs(funcs).Parse(clientsWSTmpl))
   450  	)
   451  	if action.Payload != nil {
   452  		params = append(params, "payload "+codegen.GoTypeRef(action.Payload, action.Payload.AllRequired(), 1, false))
   453  		names = append(names, "payload")
   454  	}
   455  
   456  	initParamsScoped := func(att *design.AttributeDefinition) []*paramData {
   457  		reqData, optData := initParams(att)
   458  
   459  		sort.Sort(byParamName(reqData))
   460  		sort.Sort(byParamName(optData))
   461  
   462  		// Update closure
   463  		for _, p := range reqData {
   464  			names = append(names, p.VarName)
   465  			params = append(params, p.VarName+" "+cmdFieldType(p.Attribute.Type, false))
   466  		}
   467  		for _, p := range optData {
   468  			names = append(names, p.VarName)
   469  			params = append(params, p.VarName+" "+cmdFieldType(p.Attribute.Type, p.Attribute.Type.IsPrimitive()))
   470  		}
   471  		return append(reqData, optData...)
   472  	}
   473  	queryParams = initParamsScoped(action.QueryParams)
   474  	headers = initParamsScoped(action.Headers)
   475  
   476  	if action.Security != nil {
   477  		signer = codegen.Goify(action.Security.Scheme.SchemeName, true)
   478  	}
   479  	data := struct {
   480  		Name            string
   481  		ResourceName    string
   482  		Description     string
   483  		Routes          []*design.RouteDefinition
   484  		HasPayload      bool
   485  		HasMultiContent bool
   486  		Params          string
   487  		ParamNames      string
   488  		CanonicalScheme string
   489  		Signer          string
   490  		QueryParams     []*paramData
   491  		Headers         []*paramData
   492  	}{
   493  		Name:            action.Name,
   494  		ResourceName:    action.Parent.Name,
   495  		Description:     action.Description,
   496  		Routes:          action.Routes,
   497  		HasPayload:      action.Payload != nil,
   498  		HasMultiContent: len(design.Design.Consumes) > 1,
   499  		Params:          strings.Join(params, ", "),
   500  		ParamNames:      strings.Join(names, ", "),
   501  		CanonicalScheme: action.CanonicalScheme(),
   502  		Signer:          signer,
   503  		QueryParams:     queryParams,
   504  		Headers:         headers,
   505  	}
   506  	if action.WebSocket() {
   507  		return clientsWSTmpl.Execute(file, data)
   508  	}
   509  	if err := clientsTmpl.Execute(file, data); err != nil {
   510  		return err
   511  	}
   512  	return requestsTmpl.Execute(file, data)
   513  }
   514  
   515  // fileServerMethod returns the name of the client method for downloading assets served by the given
   516  // file server.
   517  // Note: the implementation opts for generating good names rather than names that are guaranteed to
   518  // be unique. This means that the generated code could be potentially incorrect in the rare cases
   519  // where it produces the same names for two different file servers. This should be addressed later
   520  // (when it comes up?) using metadata to let users override the default.
   521  func (g *Generator) fileServerMethod(fs *design.FileServerDefinition) string {
   522  	var (
   523  		suffix string
   524  
   525  		wcs      = design.ExtractWildcards(fs.RequestPath)
   526  		reqElems = strings.Split(fs.RequestPath, "/")
   527  	)
   528  
   529  	if len(wcs) == 0 {
   530  		suffix = path.Base(fs.RequestPath)
   531  		ext := filepath.Ext(suffix)
   532  		suffix = strings.TrimSuffix(suffix, ext)
   533  		suffix += codegen.Goify(ext, true)
   534  	} else {
   535  		if len(reqElems) == 1 {
   536  			suffix = filepath.Base(fs.RequestPath)
   537  			suffix = suffix[1:] // remove "*" prefix
   538  		} else {
   539  			suffix = reqElems[len(reqElems)-2] // should work most of the time
   540  		}
   541  	}
   542  	return "Download" + codegen.Goify(suffix, true)
   543  }
   544  
   545  // generateMediaTypes iterates through the media types and generate the data structures and
   546  // marshaling code.
   547  func (g *Generator) generateMediaTypes(pkgDir string, funcs template.FuncMap) error {
   548  	funcs["decodegotyperef"] = decodeGoTypeRef
   549  	funcs["decodegotypename"] = decodeGoTypeName
   550  	typeDecodeTmpl := template.Must(template.New("typeDecode").Funcs(funcs).Parse(typeDecodeTmpl))
   551  	mtFile := filepath.Join(pkgDir, "media_types.go")
   552  	mtWr, err := genapp.NewMediaTypesWriter(mtFile)
   553  	if err != nil {
   554  		panic(err) // bug
   555  	}
   556  	title := fmt.Sprintf("%s: Application Media Types", g.API.Context())
   557  	imports := []*codegen.ImportSpec{
   558  		codegen.SimpleImport("github.com/goadesign/goa"),
   559  		codegen.SimpleImport("fmt"),
   560  		codegen.SimpleImport("net/http"),
   561  		codegen.SimpleImport("time"),
   562  		codegen.SimpleImport("unicode/utf8"),
   563  		codegen.NewImport("uuid", "github.com/goadesign/goa/uuid"),
   564  	}
   565  	for _, v := range g.API.MediaTypes {
   566  		imports = codegen.AttributeImports(v.AttributeDefinition, imports, nil)
   567  	}
   568  	mtWr.WriteHeader(title, g.Target, imports)
   569  	err = g.API.IterateMediaTypes(func(mt *design.MediaTypeDefinition) error {
   570  		if (mt.Type.IsObject() || mt.Type.IsArray()) && !mt.IsError() {
   571  			if err := mtWr.Execute(mt); err != nil {
   572  				return err
   573  			}
   574  		}
   575  		err := mt.IterateViews(func(view *design.ViewDefinition) error {
   576  			p, _, err := mt.Project(view.Name)
   577  			if err != nil {
   578  				return err
   579  			}
   580  			if err := typeDecodeTmpl.Execute(mtWr.SourceFile, p); err != nil {
   581  				return err
   582  			}
   583  			return nil
   584  		})
   585  		return err
   586  	})
   587  	g.genfiles = append(g.genfiles, mtFile)
   588  	if err != nil {
   589  		return err
   590  	}
   591  	return mtWr.FormatCode()
   592  }
   593  
   594  // generateUserTypes iterates through the user types and generates the data structures and
   595  // marshaling code.
   596  func (g *Generator) generateUserTypes(pkgDir string) error {
   597  	utFile := filepath.Join(pkgDir, "user_types.go")
   598  	utWr, err := genapp.NewUserTypesWriter(utFile)
   599  	if err != nil {
   600  		panic(err) // bug
   601  	}
   602  	title := fmt.Sprintf("%s: Application User Types", g.API.Context())
   603  	imports := []*codegen.ImportSpec{
   604  		codegen.SimpleImport("github.com/goadesign/goa"),
   605  		codegen.SimpleImport("fmt"),
   606  		codegen.SimpleImport("time"),
   607  		codegen.SimpleImport("unicode/utf8"),
   608  		codegen.NewImport("uuid", "github.com/goadesign/goa/uuid"),
   609  	}
   610  	for _, v := range g.API.Types {
   611  		imports = codegen.AttributeImports(v.AttributeDefinition, imports, nil)
   612  	}
   613  	utWr.WriteHeader(title, g.Target, imports)
   614  	err = g.API.IterateUserTypes(func(t *design.UserTypeDefinition) error {
   615  		return utWr.Execute(t)
   616  	})
   617  	g.genfiles = append(g.genfiles, utFile)
   618  	if err != nil {
   619  		return err
   620  	}
   621  	return utWr.FormatCode()
   622  }
   623  
   624  // join is a code generation helper function that generates a function signature built from
   625  // concatenating the properties (name type) of the given attribute type (assuming it's an object).
   626  // join accepts an optional slice of strings which indicates the order in which the parameters
   627  // should appear in the signature. If pos is specified then it must list all the parameters. If
   628  // it's not specified then parameters are sorted alphabetically.
   629  func join(att *design.AttributeDefinition, usePointers bool, pos ...[]string) string {
   630  	if att == nil {
   631  		return ""
   632  	}
   633  	obj := att.Type.ToObject()
   634  	elems := make([]string, len(obj))
   635  	var keys []string
   636  	if len(pos) > 0 {
   637  		keys = pos[0]
   638  		if len(keys) != len(obj) {
   639  			panic("invalid position slice, lenght does not match attribute field count") // bug
   640  		}
   641  	} else {
   642  		keys = make([]string, len(obj))
   643  		i := 0
   644  		for n := range obj {
   645  			keys[i] = n
   646  			i++
   647  		}
   648  		sort.Strings(keys)
   649  	}
   650  	for i, n := range keys {
   651  		a := obj[n]
   652  		elems[i] = fmt.Sprintf("%s %s", codegen.Goify(n, false),
   653  			cmdFieldType(a.Type, usePointers && !a.IsRequired(n)))
   654  	}
   655  	return strings.Join(elems, ", ")
   656  }
   657  
   658  // escapeBackticks is a code generation helper that escapes backticks in a string.
   659  func escapeBackticks(text string) string {
   660  	return strings.Replace(text, "`", "`+\"`\"+`", -1)
   661  }
   662  
   663  // multiComment produces a Go comment containing the given string taking into account newlines.
   664  func multiComment(text string) string {
   665  	lines := strings.Split(text, "\n")
   666  	nl := make([]string, len(lines))
   667  	for i, l := range lines {
   668  		nl[i] = "// " + strings.TrimSpace(l)
   669  	}
   670  	return strings.Join(nl, "\n")
   671  }
   672  
   673  // gotTypeRefExt computes the type reference for a type in a different package.
   674  func goTypeRefExt(t design.DataType, tabs int, pkg string) string {
   675  	ref := codegen.GoTypeRef(t, nil, tabs, false)
   676  	if strings.HasPrefix(ref, "*") {
   677  		return fmt.Sprintf("%s.%s", pkg, ref[1:])
   678  	}
   679  	return fmt.Sprintf("%s.%s", pkg, ref)
   680  }
   681  
   682  // decodeGoTypeRef handles the case where the type being decoded is a error response media type.
   683  func decodeGoTypeRef(t design.DataType, required []string, tabs int, private bool) string {
   684  	mt, ok := t.(*design.MediaTypeDefinition)
   685  	if ok && mt.IsError() {
   686  		return "*goa.ErrorResponse"
   687  	}
   688  	return codegen.GoTypeRef(t, required, tabs, private)
   689  }
   690  
   691  // decodeGoTypeName handles the case where the type being decoded is a error response media type.
   692  func decodeGoTypeName(t design.DataType, required []string, tabs int, private bool) string {
   693  	mt, ok := t.(*design.MediaTypeDefinition)
   694  	if ok && mt.IsError() {
   695  		return "goa.ErrorResponse"
   696  	}
   697  	return codegen.GoTypeName(t, required, tabs, private)
   698  }
   699  
   700  // cmdFieldType computes the Go type name used to store command flags of the given design type.
   701  func cmdFieldType(t design.DataType, point bool) string {
   702  	var pointer, suffix string
   703  	if point && !t.IsArray() {
   704  		pointer = "*"
   705  	}
   706  	suffix = codegen.GoNativeType(t)
   707  	return pointer + suffix
   708  }
   709  
   710  // cmdFieldTypeString computes the Go type name used to store command flags of the given design type. Complex types are String
   711  func cmdFieldTypeString(t design.DataType, point bool) string {
   712  	var pointer, suffix string
   713  	if point && !t.IsArray() {
   714  		pointer = "*"
   715  	}
   716  	if t.Kind() == design.UUIDKind || t.Kind() == design.DateTimeKind || t.Kind() == design.AnyKind || t.Kind() == design.NumberKind || t.Kind() == design.BooleanKind {
   717  		suffix = "string"
   718  	} else if isArrayOfType(t, design.UUIDKind, design.DateTimeKind, design.AnyKind, design.NumberKind, design.BooleanKind) {
   719  		suffix = "[]string"
   720  	} else {
   721  		suffix = codegen.GoNativeType(t)
   722  	}
   723  	return pointer + suffix
   724  }
   725  
   726  func isArrayOfType(array design.DataType, kinds ...design.Kind) bool {
   727  	if !array.IsArray() {
   728  		return false
   729  	}
   730  	kind := array.ToArray().ElemType.Type.Kind()
   731  	for _, t := range kinds {
   732  		if t == kind {
   733  			return true
   734  		}
   735  	}
   736  	return false
   737  }
   738  
   739  // template used to produce code that serializes arrays of simple values into comma separated
   740  // strings.
   741  var arrayToStringTmpl *template.Template
   742  
   743  // toString generates Go code that converts the given simple type attribute into a string.
   744  func toString(name, target string, att *design.AttributeDefinition) string {
   745  	switch actual := att.Type.(type) {
   746  	case design.Primitive:
   747  		switch actual.Kind() {
   748  		case design.IntegerKind:
   749  			return fmt.Sprintf("%s := strconv.Itoa(%s)", target, name)
   750  		case design.BooleanKind:
   751  			return fmt.Sprintf("%s := strconv.FormatBool(%s)", target, name)
   752  		case design.NumberKind:
   753  			return fmt.Sprintf("%s := strconv.FormatFloat(%s, 'f', -1, 64)", target, name)
   754  		case design.StringKind:
   755  			return fmt.Sprintf("%s := %s", target, name)
   756  		case design.DateTimeKind:
   757  			return fmt.Sprintf("%s := %s.Format(time.RFC3339)", target, strings.Replace(name, "*", "", -1)) // remove pointer if present
   758  		case design.UUIDKind:
   759  			return fmt.Sprintf("%s := %s.String()", target, strings.Replace(name, "*", "", -1)) // remove pointer if present
   760  		case design.AnyKind:
   761  			return fmt.Sprintf("%s := fmt.Sprintf(\"%%v\", %s)", target, name)
   762  		default:
   763  			panic("unknown primitive type")
   764  		}
   765  	case *design.Array:
   766  		data := map[string]interface{}{
   767  			"Name":     name,
   768  			"Target":   target,
   769  			"ElemType": actual.ElemType,
   770  		}
   771  		return codegen.RunTemplate(arrayToStringTmpl, data)
   772  	default:
   773  		panic("cannot convert non simple type " + att.Type.Name() + " to string") // bug
   774  	}
   775  }
   776  
   777  // defaultPath returns the first route path for the given action that does not take any wildcard,
   778  // empty string if none.
   779  func defaultPath(action *design.ActionDefinition) string {
   780  	for _, r := range action.Routes {
   781  		candidate := r.FullPath()
   782  		if !strings.ContainsRune(candidate, ':') {
   783  			return candidate
   784  		}
   785  	}
   786  	return ""
   787  }
   788  
   789  // signerType returns the name of the client signer used for the defined security model on the Action
   790  func signerType(scheme *design.SecuritySchemeDefinition) string {
   791  	switch scheme.Kind {
   792  	case design.JWTSecurityKind:
   793  		return "goaclient.JWTSigner" // goa client package imported under goaclient
   794  	case design.OAuth2SecurityKind:
   795  		return "goaclient.OAuth2Signer"
   796  	case design.APIKeySecurityKind:
   797  		return "goaclient.APIKeySigner"
   798  	case design.BasicAuthSecurityKind:
   799  		return "goaclient.BasicSigner"
   800  	}
   801  	return ""
   802  }
   803  
   804  // pathTemplate returns a fmt format suitable to build a request path to the route.
   805  func pathTemplate(r *design.RouteDefinition) string {
   806  	return design.WildcardRegex.ReplaceAllLiteralString(r.FullPath(), "/%s")
   807  }
   808  
   809  // pathParams return the function signature of the path factory function for the given route.
   810  func pathParams(r *design.RouteDefinition) string {
   811  	pnames := r.Params()
   812  	params := make(design.Object, len(pnames))
   813  	for _, p := range pnames {
   814  		params[p] = r.Parent.Params.Type.ToObject()[p]
   815  	}
   816  	return join(&design.AttributeDefinition{Type: params}, false, pnames)
   817  }
   818  
   819  // typeName returns Go type name of given MediaType definition.
   820  func typeName(mt *design.MediaTypeDefinition) string {
   821  	if mt.IsError() {
   822  		return "ErrorResponse"
   823  	}
   824  	return codegen.GoTypeName(mt, mt.AllRequired(), 1, false)
   825  }
   826  
   827  // initParams returns required and optional paramData extracted from given attribute definition.
   828  func initParams(att *design.AttributeDefinition) ([]*paramData, []*paramData) {
   829  	if att == nil {
   830  		return nil, nil
   831  	}
   832  	obj := att.Type.ToObject()
   833  	var reqParamData []*paramData
   834  	var optParamData []*paramData
   835  	for n, q := range obj {
   836  		varName := codegen.Goify(n, false)
   837  		param := &paramData{
   838  			Name:      n,
   839  			VarName:   varName,
   840  			Attribute: q,
   841  		}
   842  		if q.Type.IsPrimitive() {
   843  			param.MustToString = q.Type.Kind() != design.StringKind
   844  			if att.IsRequired(n) {
   845  				param.ValueName = varName
   846  				reqParamData = append(reqParamData, param)
   847  			} else {
   848  				param.ValueName = "*" + varName
   849  				param.CheckNil = true
   850  				optParamData = append(optParamData, param)
   851  			}
   852  		} else {
   853  			if q.Type.IsArray() {
   854  				param.IsArray = true
   855  				param.ElemAttribute = q.Type.ToArray().ElemType
   856  			}
   857  			param.MustToString = true
   858  			param.ValueName = varName
   859  			param.CheckNil = true
   860  			if att.IsRequired(n) {
   861  				reqParamData = append(reqParamData, param)
   862  			} else {
   863  				optParamData = append(optParamData, param)
   864  			}
   865  		}
   866  	}
   867  
   868  	return reqParamData, optParamData
   869  }
   870  
// paramData is the data structure holding the information needed to generate query params and
// headers handling code. Instances are built by initParams and consumed by the request and
// websocket templates.
type paramData struct {
	// Name is the parameter name from the design; the templates use it as the query string
	// key or header name.
	Name          string
	// VarName is the Go identifier holding the parameter in generated code
	// (codegen.Goify(Name, false) when built by initParams).
	VarName       string
	// ValueName is the Go expression yielding the parameter value: "*"+VarName for optional
	// primitives (dereference), VarName otherwise.
	ValueName     string
	// Attribute is the design attribute describing the parameter.
	Attribute     *design.AttributeDefinition
	// ElemAttribute is the element attribute of an array parameter; only set when IsArray.
	ElemAttribute *design.AttributeDefinition
	// MustToString reports whether the value must be converted with the toString template
	// helper before being set as a query value or header.
	MustToString  bool
	// IsArray reports whether the parameter type is an array.
	IsArray       bool
	// CheckNil reports whether generated code must guard VarName with a nil check.
	CheckNil      bool
}
   883  
   884  type byParamName []*paramData
   885  
   886  func (b byParamName) Swap(i, j int)      { b[i], b[j] = b[j], b[i] }
   887  func (b byParamName) Less(i, j int) bool { return b[i].Name < b[j].Name }
   888  func (b byParamName) Len() int           { return len(b) }
   889  
   890  const (
   891  	arrayToStringT = `	{{ $tmp := tempvar }}{{ $tmp }} := make([]string, len({{ .Name }}))
   892  	for i, e := range {{ .Name }} {
   893  		{{ $tmp2 := tempvar }}{{ toString "e" $tmp2 .ElemType }}
   894  		{{ $tmp }}[i] = {{ $tmp2 }}
   895  	}
   896  	{{ .Target }} := strings.Join({{ $tmp }}, ",")`
   897  
   898  	payloadTmpl = `// {{ gotypename .Payload nil 0 false }} is the {{ .Parent.Name }} {{ .Name }} action payload.
   899  type {{ gotypename .Payload nil 1 false }} {{ gotypedef .Payload 0 true false }}
   900  `
   901  
   902  	typeDecodeTmpl = `{{ $typeName := typeName . }}{{ $funcName := printf "Decode%s" $typeName }}// {{ $funcName }} decodes the {{ $typeName }} instance encoded in resp body.
   903  func (c *Client) {{ $funcName }}(resp *http.Response) ({{ decodegotyperef . .AllRequired 0 false }}, error) {
   904  	var decoded {{ decodegotypename . .AllRequired 0 false }}
   905  	err := c.Decoder.Decode(&decoded, resp.Body, resp.Header.Get("Content-Type"))
   906  	return {{ if .IsObject }}&{{ end }}decoded, err
   907  }
   908  `
   909  
   910  	pathTmpl = `{{ $funcName := printf "%sPath%s" (goify (printf "%s%s" .Route.Parent.Name (title .Route.Parent.Parent.Name)) true) ((or (and .Index (add .Index 1)) "") | printf "%v") }}{{/*
   911  */}}// {{ $funcName }} computes a request path to the {{ .Route.Parent.Name }} action of {{ .Route.Parent.Parent.Name }}.
   912  func {{ $funcName }}({{ pathParams .Route }}) string {
   913  	{{ range $i, $param := .Params }}{{/*
   914  */}}{{ toString $param.VarName (printf "param%d" $i) $param.Attribute }}
   915  	{{ end }}
   916  	return fmt.Sprintf({{ printf "%q" (pathTemplate .Route) }}{{ range $i, $param := .Params }}, {{ printf "param%d" $i }}{{ end }})
   917  }
   918  `
   919  
   920  	clientsTmpl = `{{ $funcName := goify (printf "%s%s" .Name (title .ResourceName)) true }}{{ $desc := .Description }}{{/*
   921  */}}{{ if $desc }}{{ multiComment $desc }}{{ else }}{{/*
   922  */}}// {{ $funcName }} makes a request to the {{ .Name }} action endpoint of the {{ .ResourceName }} resource{{ end }}
   923  func (c *Client) {{ $funcName }}(ctx context.Context, path string{{ if .Params }}, {{ .Params }}{{ end }}{{ if and .HasPayload .HasMultiContent }}, contentType string{{ end }}) (*http.Response, error) {
   924  	req, err := c.New{{ $funcName }}Request(ctx, path{{ if .ParamNames }}, {{ .ParamNames }}{{ end }}{{ if and .HasPayload .HasMultiContent }}, contentType{{ end }})
   925  	if err != nil {
   926  		return nil, err
   927  	}
   928  	return c.Client.Do(ctx, req)
   929  }
   930  `
   931  
   932  	clientsWSTmpl = `{{ $funcName := goify (printf "%s%s" .Name (title .ResourceName)) true }}{{ $desc := .Description }}{{/*
   933  */}}{{ if $desc }}{{ multiComment $desc }}{{ else }}// {{ $funcName }} establishes a websocket connection to the {{ .Name }} action endpoint of the {{ .ResourceName }} resource{{ end }}
   934  func (c *Client) {{ $funcName }}(ctx context.Context, path string{{ if .Params }}, {{ .Params }}{{ end }}) (*websocket.Conn, error) {
   935  	scheme := c.Scheme
   936  	if scheme == "" {
   937  		scheme = "{{ .CanonicalScheme }}"
   938  	}
   939  	u := url.URL{Host: c.Host, Scheme: scheme, Path: path}
   940  {{ if .QueryParams }}	values := u.Query()
   941  {{ range .QueryParams }}{{ if .CheckNil }}	if {{ .VarName }} != nil {
   942  	{{ end }}{{/*
   943  
   944  // ARRAY
   945  */}}{{ if .IsArray }}		for _, p := range {{ .VarName }} {
   946  {{ if .MustToString }}{{ $tmp := tempvar }}			{{ toString "p" $tmp .ElemAttribute }}
   947  			values.Add("{{ .Name }}", {{ $tmp }})
   948  {{ else }}			values.Add("{{ .Name }}", {{ .ValueName }})
   949  {{ end }}}{{/*
   950  
   951  // NON STRING
   952  */}}{{ else if .MustToString }}{{ $tmp := tempvar }}	{{ toString .ValueName $tmp .Attribute }}
   953  	values.Set("{{ .Name }}", {{ $tmp }}){{/*
   954  
   955  // STRING
   956  */}}{{ else }}	values.Set("{{ .Name }}", {{ .ValueName }})
   957  {{ end }}{{ if .CheckNil }}	}
   958  {{ end }}{{ end }}	u.RawQuery = values.Encode()
   959  {{ end }}	url_ := u.String()
   960  	cfg, err := websocket.NewConfig(url_, url_)
   961  	if err != nil {
   962  		return nil, err
   963  	}
   964  {{ range $header := .Headers }}{{ $tmp := tempvar }}	{{ toString $header.VarName $tmp $header.Attribute }}
   965  	cfg.Header["{{ $header.Name }}"] = []string{ {{ $tmp }} }
   966  {{ end }}	return websocket.DialConfig(cfg)
   967  }
   968  `
   969  
   970  	fsTmpl = `// {{ .Name }} downloads {{ if .DirName }}{{ .DirName }}files with the given filename{{ else }}{{ .FileName }}{{ end }} and writes it to the file dest.
   971  // It returns the number of bytes downloaded in case of success.
   972  func (c * Client) {{ .Name }}(ctx context.Context, {{ if .DirName }}filename, {{ end }}dest string) (int64, error) {
   973  	scheme := c.Scheme
   974  	if scheme == "" {
   975  		scheme = "{{ .CanonicalScheme }}"
   976  	}
   977  {{ if .DirName }}	p := path.Join("{{ .RequestDir }}", filename)
   978  {{ end }}	u := url.URL{Host: c.Host, Scheme: scheme, Path: {{ if .DirName }}p{{ else }}"{{ .RequestPath }}"{{ end }}}
   979  	req, err := http.NewRequest("GET", u.String(), nil)
   980  	if err != nil {
   981  		return 0, err
   982  	}
   983  	resp, err := c.Client.Do(ctx, req)
   984  	if err != nil {
   985  		return 0, err
   986  	}
   987  	if resp.StatusCode != 200 {
   988  		var body string
   989  		if b, err := ioutil.ReadAll(resp.Body); err != nil {
   990  			if len(b) > 0 {
   991  				body = ": "+ string(b)
   992  			}
   993  		}
   994  		return 0, fmt.Errorf("%s%s", resp.Status, body)
   995  	}
   996  	defer resp.Body.Close()
   997  	out, err := os.Create(dest)
   998  	if err != nil {
   999  		return 0, err
  1000  	}
  1001  	defer out.Close()
  1002  	return io.Copy(out, resp.Body)
  1003  }
  1004  `
  1005  
  1006  	requestsTmpl = `{{ $funcName := goify (printf "New%s%sRequest" (title .Name) (title .ResourceName)) true }}{{/*
  1007  */}}// {{ $funcName }} create the request corresponding to the {{ .Name }} action endpoint of the {{ .ResourceName }} resource.
  1008  func (c *Client) {{ $funcName }}(ctx context.Context, path string{{ if .Params }}, {{ .Params }}{{ end }}{{ if .HasPayload }}{{ if .HasMultiContent }}, contentType string{{ end }}{{ end }}) (*http.Request, error) {
  1009  {{ if .HasPayload }}	var body bytes.Buffer
  1010  {{ if .HasMultiContent }}	if contentType == "" {
  1011  		contentType = "*/*" // Use default encoder
  1012  	}
  1013  {{ end }}	err := c.Encoder.Encode(payload, &body, {{ if .HasMultiContent }}contentType{{ else }}"*/*"{{ end }})
  1014  	if err != nil {
  1015  		return nil, fmt.Errorf("failed to encode body: %s", err)
  1016  	}
  1017  {{ end }}	scheme := c.Scheme
  1018  	if scheme == "" {
  1019  		scheme = "{{ .CanonicalScheme }}"
  1020  	}
  1021  	u := url.URL{Host: c.Host, Scheme: scheme, Path: path}
  1022  {{ if .QueryParams }}	values := u.Query()
  1023  {{ range .QueryParams }}{{/*
  1024  
  1025  // ARRAY
  1026  */}}{{ if .IsArray }}		for _, p := range {{ .VarName }} {
  1027  {{ if .MustToString }}{{ $tmp := tempvar }}			{{ toString "p" $tmp .ElemAttribute }}
  1028  			values.Add("{{ .Name }}", {{ $tmp }})
  1029  {{ else }}			values.Add("{{ .Name }}", {{ .ValueName }})
  1030  {{ end }}	 }
  1031  {{/*
  1032  
  1033  // NON STRING
  1034  */}}{{ else if .MustToString }}{{ if .CheckNil }}	if {{ .VarName }} != nil {
  1035  	{{ end }}{{ $tmp := tempvar }}	{{ toString .ValueName $tmp .Attribute }}
  1036  	values.Set("{{ .Name }}", {{ $tmp }})
  1037  {{ if .CheckNil }}	}
  1038  {{ end }}{{/*
  1039  
  1040  // STRING
  1041  */}}{{ else }}{{ if .CheckNil }}	if {{ .VarName }} != nil {
  1042  	{{ end }}	values.Set("{{ .Name }}", {{ .ValueName }})
  1043  {{ if .CheckNil }}	}
  1044  {{ end }}{{ end }}{{ end }}	u.RawQuery = values.Encode()
  1045  {{ end }}{{ if .HasPayload }}	req, err := http.NewRequest({{ $route := index .Routes 0 }}"{{ $route.Verb }}", u.String(), &body)
  1046  {{ else }}	req, err := http.NewRequest({{ $route := index .Routes 0 }}"{{ $route.Verb }}", u.String(), nil)
  1047  {{ end }}	if err != nil {
  1048  		return nil, err
  1049  	}
  1050  {{ if or .Headers (and .HasPayload .HasMultiContent) }}	header := req.Header
  1051  {{ if .HasPayload }}{{ if .HasMultiContent }}	if contentType != "*/*" {
  1052  		header.Set("Content-Type", contentType)
  1053  	}
  1054  {{ end }}{{ end }}{{ range .Headers }}{{ if .CheckNil }}	if {{ .VarName }} != nil {
  1055  {{ end }}{{ if .MustToString }}{{ $tmp := tempvar }}	{{ toString .ValueName $tmp .Attribute }}
  1056  	header.Set("{{ .Name }}", {{ $tmp }}){{ else }}
  1057  	header.Set("{{ .Name }}", {{ .ValueName }})
  1058  {{ end }}{{ if .CheckNil }}	}{{ end }}
  1059  {{ end }}{{ end }}{{ if .Signer }}	if c.{{ .Signer }}Signer != nil {
  1060  		c.{{ .Signer }}Signer.Sign(req)
  1061  	}
  1062  {{ end }}	return req, nil
  1063  }
  1064  `
  1065  
  1066  	clientTmpl = `// Client is the {{ .API.Name }} service client.
  1067  type Client struct {
  1068  	*goaclient.Client{{range $security := .API.SecuritySchemes }}{{ $signer := signerType $security }}{{ if $signer }}
  1069  	{{ goify $security.SchemeName true }}Signer goaclient.Signer{{ end }}{{ end }}
  1070  	Encoder *goa.HTTPEncoder
  1071  	Decoder *goa.HTTPDecoder
  1072  }
  1073  
  1074  // New instantiates the client.
  1075  func New(c goaclient.Doer) *Client {
  1076  	client := &Client{
  1077  		Client: goaclient.New(c),
  1078  		Encoder: goa.NewHTTPEncoder(),
  1079  		Decoder: goa.NewHTTPDecoder(),
  1080  	}
  1081  
  1082  {{ if .Encoders }}	// Setup encoders and decoders
  1083  {{ range .Encoders }}{{/*
  1084  */}}	client.Encoder.Register({{ .PackageName }}.{{ .Function }}, "{{ joinStrings .MIMETypes "\", \"" }}")
  1085  {{ end }}{{ range .Decoders }}{{/*
  1086  */}}	client.Decoder.Register({{ .PackageName }}.{{ .Function }}, "{{ joinStrings .MIMETypes "\", \"" }}")
  1087  {{ end }}
  1088  
  1089  	// Setup default encoder and decoder
  1090  {{ range .Encoders }}{{ if .Default }}{{/*
  1091  */}}	client.Encoder.Register({{ .PackageName }}.{{ .Function }}, "*/*")
  1092  {{ end }}{{ end }}{{ range .Decoders }}{{ if .Default }}{{/*
  1093  */}}	client.Decoder.Register({{ .PackageName }}.{{ .Function }}, "*/*")
  1094  {{ end }}{{ end }}
  1095  {{ end }}	return client
  1096  }
  1097  
  1098  {{range $security := .API.SecuritySchemes }}{{ $signer := signerType $security }}{{ if $signer }}{{/*
  1099  */}}{{ $name := printf "%sSigner" (goify $security.SchemeName true) }}{{/*
  1100  */}}// Set{{ $name }} sets the request signer for the {{ $security.SchemeName }} security scheme.
  1101  func (c *Client) Set{{ $name }}(signer goaclient.Signer) {
  1102  	c.{{ $name }} = signer
  1103  }
  1104  {{ end }}{{ end }}
  1105  `
  1106  )