
     1  package main
     3  import (
     4  	"bytes"
     5  	"container/list"
     6  	"errors"
     7  	"flag"
     8  	"fmt"
     9  	"go/build"
    10  	"hash/fnv"
    11  	"html/template"
    12  	"io"
    13  	"io/ioutil"
    14  	"os"
    15  	"path"
    16  	"path/filepath"
    17  	"runtime/debug"
    18  	"strings"
    20  	""
    21  	""
    23  	""
    24  	""
    25  	""
    26  	""
    27  )
// PackageSpec defines the data available to the --output option's template.
// Information is recomputed for each package generated. Only the exported
// fields (Dir and ImportPath) can be referenced from the template; the
// unexported fields carry per-package state between the steps of the command.
type PackageSpec struct {
	// Dir holds the local path where the package is located. If the package is
	// a remote package, this will always be ".".
	Dir string

	// ImportPath holds a representation of the package that should be unique
	// for most purposes. If a package is on the filesystem, this is equivalent
	// to the value of Dir. For remote packages, this holds the string used to
	// import that package in code (e.g. "encoding/json").
	ImportPath string
	// isWildcard is true for specs produced by expanding a recursive ("/...")
	// path; loadPackages skips such specs instead of failing when no package
	// can be loaded from them.
	isWildcard bool
	// isLocal reports whether the spec refers to a filesystem path (as decided
	// by isLocalPath) rather than an import path.
	isLocal    bool
	// outputFile is the result of executing the --output template for this
	// spec, cleaned with filepath.Clean when non-empty. Empty means no output
	// file was configured (the --output help text says stdout is used then).
	outputFile string
	// pkg is the loaded package, populated by loadPackages.
	pkg        *lang.Package
}
// commandOptions holds the settings for a single gomarkdoc run. Its fields are
// bound to command-line flags in buildCommand and, except for verbosity,
// version and file, are overwritten at the start of the command's RunE with
// the values viper resolves from those flags, the config file and the
// environment.
type commandOptions struct {
	// repository overrides the git repository details (URL, default branch,
	// path from the repository root) used in place of automatic detection.
	repository            lang.Repo
	// output is a template for the output file path; empty means stdout.
	output                string
	// header/footer hold literal content; headerFile/footerFile name files to
	// read it from. The literal content wins when both are set.
	header                string
	headerFile            string
	footer                string
	footerFile            string
	// format selects the markdown flavor: github, azure-devops or plain.
	format                string
	tags                  []string
	excludeDirs           []string
	// templateOverrides maps template names to inline template content and
	// takes precedence over templateFileOverrides for the same name.
	templateOverrides     map[string]string
	templateFileOverrides map[string]string
	// verbosity counts repeated -v flags; it is read only from the flag.
	verbosity             int
	includeUnexported     bool
	check                 bool
	embed                 bool
	version               bool
	fileOnly              bool
	// file is the absolute path of the single .go file to document; it is set
	// only when fileOnly is true.
	file                  string
	overrideImportPath    string
}
    69  var version = "v1.0.1"
    71  const configFilePrefix = ".gomarkdoc"
    73  func buildCommand() *cobra.Command {
    74  	var opts commandOptions
    75  	var configFile string
    77  	// cobra.OnInitialize(func() { buildConfig(configFile) })
    79  	var command = &cobra.Command{
    80  		Use:   "gomarkdoc [package ...]",
    81  		Short: "generate markdown documentation for golang code",
    82  		RunE: func(cmd *cobra.Command, args []string) error {
    83  			if opts.version {
    84  				printVersion()
    85  				return nil
    86  			}
    88  			buildConfig(configFile)
    90  			// Load configuration from viper
    91  			opts.includeUnexported = viper.GetBool("includeUnexported")
    92  			opts.output = viper.GetString("output")
    93  			opts.check = viper.GetBool("check")
    94  			opts.embed = viper.GetBool("embed")
    95  			opts.format = viper.GetString("format")
    96  			opts.templateOverrides = viper.GetStringMapString("template")
    97  			opts.templateFileOverrides = viper.GetStringMapString("templateFile")
    98  			opts.header = viper.GetString("header")
    99  			opts.headerFile = viper.GetString("headerFile")
   100  			opts.footer = viper.GetString("footer")
   101  			opts.footerFile = viper.GetString("footerFile")
   102  			opts.tags = viper.GetStringSlice("tags")
   103  			opts.excludeDirs = viper.GetStringSlice("excludeDirs")
   104  			opts.repository.Remote = viper.GetString("repository.url")
   105  			opts.repository.DefaultBranch = viper.GetString("repository.defaultBranch")
   106  			opts.repository.PathFromRoot = viper.GetString("repository.path")
   107  			opts.fileOnly = viper.GetBool("fileOnly")
   108  			opts.overrideImportPath = viper.GetString("overrideImportPath")
   110  			if opts.check && opts.output == "" {
   111  				return errors.New("gomarkdoc: check mode cannot be run without an output set")
   112  			}
   114  			if opts.fileOnly {
   115  				if len(args) == 0 {
   116  					return errors.New("gomarkdoc: if file-only flag is set, then a valid file must be passed")
   117  				} else if len(args) > 1 {
   118  					return errors.New("gomarkdoc: if file-only flag is set, then only one file can be passed")
   119  				} else if path.Ext(args[0]) != ".go" {
   120  					return errors.New("gomarkdoc: if file-only flag is set, then the file passed must be a go file")
   121  				}
   123  				filePath := args[0]
   124  				if !filepath.IsAbs(filePath) {
   125  					var err error
   126  					filePath, err = filepath.Abs(filePath)
   127  					if err != nil {
   128  						return fmt.Errorf("gomarkdoc: failed to get absolute path for file: %w", err)
   129  					}
   130  				}
   132  				opts.file = filePath
   133  			}
   135  			if len(args) == 0 {
   136  				// Default to current directory
   137  				args = []string{"."}
   138  			}
   140  			return runCommand(args, opts)
   141  		},
   142  	}
   143  	command.Flags().StringVar(
   144  		&configFile,
   145  		"config",
   146  		"",
   147  		fmt.Sprintf("File from which to load configuration (default: %s.yml)", configFilePrefix),
   148  	)
   149  	command.Flags().BoolVarP(
   150  		&opts.includeUnexported,
   151  		"include-unexported",
   152  		"u",
   153  		false,
   154  		"Output documentation for unexported symbols, methods and fields in addition to exported ones.",
   155  	)
   156  	command.Flags().StringVarP(
   157  		&opts.output,
   158  		"output",
   159  		"o",
   160  		"",
   161  		"File or pattern specifying where to write documentation output. Defaults to printing to stdout.",
   162  	)
   163  	command.Flags().BoolVarP(
   164  		&opts.check,
   165  		"check",
   166  		"c",
   167  		false,
   168  		"Check the output to see if it matches the generated documentation. --output must be specified to use this.",
   169  	)
   170  	command.Flags().BoolVarP(
   171  		&opts.embed,
   172  		"embed",
   173  		"e",
   174  		false,
   175  		"Embed documentation into existing markdown files if available, otherwise append to file.",
   176  	)
   177  	command.Flags().StringVarP(
   178  		&opts.format,
   179  		"format",
   180  		"f",
   181  		"github",
   182  		"Format to use for writing output data. Valid options: github (default), azure-devops, plain",
   183  	)
   184  	command.Flags().StringToStringVarP(
   185  		&opts.templateOverrides,
   186  		"template",
   187  		"t",
   188  		map[string]string{},
   189  		"Custom template string to use for the provided template name instead of the default template.",
   190  	)
   191  	command.Flags().StringToStringVar(
   192  		&opts.templateFileOverrides,
   193  		"template-file",
   194  		map[string]string{},
   195  		"Custom template file to use for the provided template name instead of the default template.",
   196  	)
   197  	command.Flags().StringVar(
   198  		&opts.header,
   199  		"header",
   200  		"",
   201  		"Additional content to inject at the beginning of each output file.",
   202  	)
   203  	command.Flags().StringVar(
   204  		&opts.headerFile,
   205  		"header-file",
   206  		"",
   207  		"File containing additional content to inject at the beginning of each output file.",
   208  	)
   209  	command.Flags().StringVar(
   210  		&opts.footer,
   211  		"footer",
   212  		"",
   213  		"Additional content to inject at the end of each output file.",
   214  	)
   215  	command.Flags().StringVar(
   216  		&opts.footerFile,
   217  		"footer-file",
   218  		"",
   219  		"File containing additional content to inject at the end of each output file.",
   220  	)
   221  	command.Flags().StringSliceVar(
   222  		&opts.tags,
   223  		"tags",
   224  		defaultTags(),
   225  		"Set of build tags to apply when choosing which files to include for documentation generation.",
   226  	)
   227  	command.Flags().StringSliceVar(
   228  		&opts.excludeDirs,
   229  		"exclude-dirs",
   230  		nil,
   231  		"List of package directories to ignore when producing documentation.",
   232  	)
   233  	command.Flags().CountVarP(
   234  		&opts.verbosity,
   235  		"verbose",
   236  		"v",
   237  		"Log additional output from the execution of the command. Can be chained for additional verbosity.",
   238  	)
   239  	command.Flags().StringVar(
   240  		&opts.repository.Remote,
   241  		"repository.url",
   242  		"",
   243  		"Manual override for the git repository URL used in place of automatic detection.",
   244  	)
   245  	command.Flags().StringVar(
   246  		&opts.repository.DefaultBranch,
   247  		"repository.default-branch",
   248  		"",
   249  		"Manual override for the git repository URL used in place of automatic detection.",
   250  	)
   251  	command.Flags().StringVar(
   252  		&opts.repository.PathFromRoot,
   253  		"repository.path",
   254  		"",
   255  		"Manual override for the path from the root of the git repository used in place of automatic detection.",
   256  	)
   257  	command.Flags().BoolVar(
   258  		&opts.version,
   259  		"version",
   260  		false,
   261  		"Print the version.",
   262  	)
   263  	command.Flags().BoolVar(
   264  		&opts.fileOnly,
   265  		"file-only",
   266  		false,
   267  		"Only includes definition inside the defined files",
   268  	)
   269  	command.Flags().StringVar(
   270  		&opts.overrideImportPath,
   271  		"override-import-path",
   272  		"",
   273  		"Override the import path of the package. This is useful when the package is not in the GOPATH.",
   274  	)
   276  	// We ignore the errors here because they only happen if the specified flag doesn't exist
   277  	_ = viper.BindPFlag("includeUnexported", command.Flags().Lookup("include-unexported"))
   278  	_ = viper.BindPFlag("output", command.Flags().Lookup("output"))
   279  	_ = viper.BindPFlag("check", command.Flags().Lookup("check"))
   280  	_ = viper.BindPFlag("embed", command.Flags().Lookup("embed"))
   281  	_ = viper.BindPFlag("format", command.Flags().Lookup("format"))
   282  	_ = viper.BindPFlag("template", command.Flags().Lookup("template"))
   283  	_ = viper.BindPFlag("templateFile", command.Flags().Lookup("template-file"))
   284  	_ = viper.BindPFlag("header", command.Flags().Lookup("header"))
   285  	_ = viper.BindPFlag("headerFile", command.Flags().Lookup("header-file"))
   286  	_ = viper.BindPFlag("footer", command.Flags().Lookup("footer"))
   287  	_ = viper.BindPFlag("footerFile", command.Flags().Lookup("footer-file"))
   288  	_ = viper.BindPFlag("tags", command.Flags().Lookup("tags"))
   289  	_ = viper.BindPFlag("excludeDirs", command.Flags().Lookup("exclude-dirs"))
   290  	_ = viper.BindPFlag("repository.url", command.Flags().Lookup("repository.url"))
   291  	_ = viper.BindPFlag("repository.defaultBranch", command.Flags().Lookup("repository.default-branch"))
   292  	_ = viper.BindPFlag("repository.path", command.Flags().Lookup("repository.path"))
   293  	_ = viper.BindPFlag("fileOnly", command.Flags().Lookup("file-only"))
   294  	_ = viper.BindPFlag("overrideImportPath", command.Flags().Lookup("override-import-path"))
   296  	return command
   297  }
   299  func defaultTags() []string {
   300  	f, ok := os.LookupEnv("GOFLAGS")
   301  	if !ok {
   302  		return nil
   303  	}
   305  	fs := flag.NewFlagSet("goflags", flag.ContinueOnError)
   306  	tags := fs.String("tags", "", "")
   308  	if err := fs.Parse(strings.Fields(f)); err != nil {
   309  		return nil
   310  	}
   312  	if tags == nil {
   313  		return nil
   314  	}
   316  	return strings.Split(*tags, ",")
   317  }
   319  func buildConfig(configFile string) {
   320  	if configFile != "" {
   321  		viper.SetConfigFile(configFile)
   322  	} else {
   323  		viper.AddConfigPath(".")
   324  		viper.SetConfigName(configFilePrefix)
   325  	}
   327  	viper.AutomaticEnv()
   329  	if err := viper.ReadInConfig(); err != nil {
   330  		if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
   331  			// TODO: better handling
   332  			fmt.Println(err)
   333  		}
   334  	}
   335  }
   337  func runCommand(paths []string, opts commandOptions) error {
   338  	outputTmpl, err := template.New("output").Parse(opts.output)
   339  	if err != nil {
   340  		return fmt.Errorf("gomarkdoc: invalid output template: %w", err)
   341  	}
   343  	specs := getSpecs(paths...)
   345  	excluded := getSpecs(opts.excludeDirs...)
   346  	if err := validateExcludes(excluded); err != nil {
   347  		return err
   348  	}
   350  	specs = removeExcludes(specs, excluded)
   352  	if err := resolveOutput(specs, outputTmpl); err != nil {
   353  		return err
   354  	}
   356  	if err := loadPackages(specs, opts); err != nil {
   357  		return err
   358  	}
   360  	return writeOutput(specs, opts)
   361  }
   363  func resolveOutput(specs []*PackageSpec, outputTmpl *template.Template) error {
   364  	for _, spec := range specs {
   365  		var outputFile strings.Builder
   366  		if err := outputTmpl.Execute(&outputFile, spec); err != nil {
   367  			return err
   368  		}
   370  		outputStr := outputFile.String()
   371  		if outputStr == "" {
   372  			// Preserve empty values
   373  			spec.outputFile = ""
   374  		} else {
   375  			// Clean up other values
   376  			spec.outputFile = filepath.Clean(outputFile.String())
   377  		}
   378  	}
   380  	return nil
   381  }
   383  func resolveOverrides(opts commandOptions) ([]gomarkdoc.RendererOption, error) {
   384  	var overrides []gomarkdoc.RendererOption
   386  	// Content overrides take precedence over file overrides
   387  	for name, s := range opts.templateOverrides {
   388  		overrides = append(overrides, gomarkdoc.WithTemplateOverride(name, s))
   389  	}
   391  	for name, f := range opts.templateFileOverrides {
   392  		// File overrides get applied only if there isn't already a content
   393  		// override.
   394  		if _, ok := opts.templateOverrides[name]; ok {
   395  			continue
   396  		}
   398  		b, err := ioutil.ReadFile(f)
   399  		if err != nil {
   400  			return nil, fmt.Errorf("gomarkdoc: couldn't resolve template for %s: %w", name, err)
   401  		}
   403  		overrides = append(overrides, gomarkdoc.WithTemplateOverride(name, string(b)))
   404  	}
   406  	var f format.Format
   407  	switch opts.format {
   408  	case "github":
   409  		f = &format.GitHubFlavoredMarkdown{}
   410  	case "azure-devops":
   411  		f = &format.AzureDevOpsMarkdown{}
   412  	case "plain":
   413  		f = &format.PlainMarkdown{}
   414  	default:
   415  		return nil, fmt.Errorf("gomarkdoc: invalid format: %s", opts.format)
   416  	}
   418  	overrides = append(overrides, gomarkdoc.WithFormat(f))
   420  	return overrides, nil
   421  }
   423  func resolveHeader(opts commandOptions) (string, error) {
   424  	if opts.header != "" {
   425  		return opts.header, nil
   426  	}
   428  	if opts.headerFile != "" {
   429  		b, err := ioutil.ReadFile(opts.headerFile)
   430  		if err != nil {
   431  			return "", fmt.Errorf("gomarkdoc: couldn't resolve header file: %w", err)
   432  		}
   434  		return string(b), nil
   435  	}
   437  	return "", nil
   438  }
   440  func resolveFooter(opts commandOptions) (string, error) {
   441  	if opts.footer != "" {
   442  		return opts.footer, nil
   443  	}
   445  	if opts.footerFile != "" {
   446  		b, err := ioutil.ReadFile(opts.footerFile)
   447  		if err != nil {
   448  			return "", fmt.Errorf("gomarkdoc: couldn't resolve footer file: %w", err)
   449  		}
   451  		return string(b), nil
   452  	}
   454  	return "", nil
   455  }
   457  func loadPackages(specs []*PackageSpec, opts commandOptions) error {
   458  	for _, spec := range specs {
   459  		log := logger.New(getLogLevel(opts.verbosity), logger.WithField("dir", spec.Dir))
   461  		buildPkg, err := getBuildPackage(spec.ImportPath, opts.tags)
   462  		if err != nil {
   463  			log.Debugf("unable to load package in directory: %s", err)
   464  			// We don't care if a wildcard path produces nothing
   465  			if spec.isWildcard {
   466  				continue
   467  			}
   469  			return err
   470  		}
   472  		var pkgOpts []lang.PackageOption
   473  		pkgOpts = append(pkgOpts, lang.PackageWithRepositoryOverrides(&opts.repository))
   475  		if opts.includeUnexported {
   476  			pkgOpts = append(pkgOpts, lang.PackageWithUnexportedIncluded())
   477  		}
   479  		if opts.fileOnly {
   480  			pkgOpts = append(pkgOpts, lang.PackageWithFileFilter(opts.file))
   481  		}
   483  		if opts.overrideImportPath != "" {
   484  			pkgOpts = append(pkgOpts, lang.PackageWithOverrideImport(opts.overrideImportPath))
   485  		}
   487  		pkg, err := lang.NewPackageFromBuild(log, buildPkg, pkgOpts...)
   488  		if err != nil {
   489  			return err
   490  		}
   492  		spec.pkg = pkg
   493  	}
   495  	return nil
   496  }
   498  func getBuildPackage(path string, tags []string) (*build.Package, error) {
   499  	ctx := build.Default
   500  	ctx.BuildTags = tags
   502  	if isLocalPath(path) {
   503  		pkg, err := ctx.ImportDir(path, build.ImportComment)
   504  		if err != nil {
   505  			return nil, fmt.Errorf("gomarkdoc: invalid package in directory: %s", path)
   506  		}
   508  		return pkg, nil
   509  	}
   511  	wd, err := os.Getwd()
   512  	if err != nil {
   513  		return nil, err
   514  	}
   516  	pkg, err := ctx.Import(path, wd, build.ImportComment)
   517  	if err != nil {
   518  		return nil, fmt.Errorf("gomarkdoc: invalid package at import path: %s", path)
   519  	}
   521  	return pkg, nil
   522  }
   524  func getSpecs(paths ...string) []*PackageSpec {
   525  	var expanded []*PackageSpec
   526  	for _, path := range paths {
   527  		// Ensure that the path we're working with is normalized for the OS
   528  		// we're using (i.e. "\" for windows, "/" for everything else)
   529  		path = filepath.FromSlash(path)
   531  		// Not a recursive path
   532  		if !strings.HasSuffix(path, fmt.Sprintf("%s...", string(os.PathSeparator))) {
   533  			isLocal := isLocalPath(path)
   534  			var dir string
   535  			if isLocal {
   536  				dir = path
   537  			} else {
   538  				dir = "."
   539  			}
   540  			expanded = append(expanded, &PackageSpec{
   541  				Dir:        dir,
   542  				ImportPath: path,
   543  				isWildcard: false,
   544  				isLocal:    isLocal,
   545  			})
   546  			continue
   547  		}
   549  		// Remove the recursive marker so we can work with the path
   550  		trimmedPath := path[0 : len(path)-3]
   552  		// Not a file path. Add the original path back to the list so as to not
   553  		// mislead someone into thinking we're processing the recursive path
   554  		if !isLocalPath(trimmedPath) {
   555  			expanded = append(expanded, &PackageSpec{
   556  				Dir:        ".",
   557  				ImportPath: path,
   558  				isWildcard: false,
   559  				isLocal:    false,
   560  			})
   561  			continue
   562  		}
   564  		expanded = append(expanded, &PackageSpec{
   565  			Dir:        trimmedPath,
   566  			ImportPath: trimmedPath,
   567  			isWildcard: true,
   568  			isLocal:    true,
   569  		})
   571  		queue := list.New()
   572  		queue.PushBack(trimmedPath)
   573  		for e := queue.Front(); e != nil; e = e.Next() {
   574  			prev := e.Prev()
   575  			if prev != nil {
   576  				queue.Remove(prev)
   577  			}
   579  			p := e.Value.(string)
   581  			files, err := ioutil.ReadDir(p)
   582  			if err != nil {
   583  				// If we couldn't read the folder, there are no directories that
   584  				// we're going to find beneath it
   585  				continue
   586  			}
   588  			for _, f := range files {
   589  				if isIgnoredDir(f.Name()) {
   590  					continue
   591  				}
   593  				if f.IsDir() {
   594  					subPath := filepath.Join(p, f.Name())
   596  					// Some local paths have their prefixes stripped by Join().
   597  					// If the path is no longer a local path, add the current
   598  					// working directory.
   599  					if !isLocalPath(subPath) {
   600  						subPath = fmt.Sprintf("%s%s", cwdPathPrefix, subPath)
   601  					}
   603  					expanded = append(expanded, &PackageSpec{
   604  						Dir:        subPath,
   605  						ImportPath: subPath,
   606  						isWildcard: true,
   607  						isLocal:    true,
   608  					})
   609  					queue.PushBack(subPath)
   610  				}
   611  			}
   612  		}
   613  	}
   615  	return expanded
   616  }
   618  var ignoredDirs = []string{".git"}
   620  // isIgnoredDir identifies if the dir is one we want to intentionally ignore.
   621  func isIgnoredDir(dirname string) bool {
   622  	for _, ignored := range ignoredDirs {
   623  		if ignored == dirname {
   624  			return true
   625  		}
   626  	}
   628  	return false
   629  }
   631  // validateExcludes checks that the exclude dirs are all directories, not
   632  // packages.
   633  func validateExcludes(specs []*PackageSpec) error {
   634  	for _, s := range specs {
   635  		if !s.isLocal {
   636  			return fmt.Errorf("gomarkdoc: invalid directory specified as an exclude directory: %s", s.ImportPath)
   637  		}
   638  	}
   640  	return nil
   641  }
   643  // removeExcludes removes any package specs that were specified as excluded.
   644  func removeExcludes(specs []*PackageSpec, excludes []*PackageSpec) []*PackageSpec {
   645  	out := make([]*PackageSpec, 0, len(specs))
   646  	for _, s := range specs {
   647  		var exclude bool
   648  		for _, e := range excludes {
   649  			if !s.isLocal || !e.isLocal {
   650  				continue
   651  			}
   653  			if r, err := filepath.Rel(s.Dir, e.Dir); err == nil && r == "." {
   654  				exclude = true
   655  				break
   656  			}
   657  		}
   659  		if !exclude {
   660  			out = append(out, s)
   661  		}
   662  	}
   664  	return out
   665  }
   667  const (
   668  	cwdPathPrefix    = "." + string(os.PathSeparator)
   669  	parentPathPrefix = ".." + string(os.PathSeparator)
   670  )
   672  func isLocalPath(path string) bool {
   673  	return strings.HasPrefix(path, ".") || strings.HasPrefix(path, parentPathPrefix) || filepath.IsAbs(path)
   674  }
   676  func compare(r1, r2 io.Reader) (bool, error) {
   677  	r1Hash := fnv.New128()
   678  	if _, err := io.Copy(r1Hash, r1); err != nil {
   679  		return false, fmt.Errorf("gomarkdoc: failed when checking documentation: %w", err)
   680  	}
   682  	r2Hash := fnv.New128()
   683  	if _, err := io.Copy(r2Hash, r2); err != nil {
   684  		return false, fmt.Errorf("gomarkdoc: failed when checking documentation: %w", err)
   685  	}
   687  	return bytes.Equal(r1Hash.Sum(nil), r2Hash.Sum(nil)), nil
   688  }
   690  func getLogLevel(verbosity int) logger.Level {
   691  	switch verbosity {
   692  	case 0:
   693  		return logger.WarnLevel
   694  	case 1:
   695  		return logger.InfoLevel
   696  	case 2:
   697  		return logger.DebugLevel
   698  	default:
   699  		return logger.DebugLevel
   700  	}
   701  }
   703  func printVersion() {
   704  	if version != "" {
   705  		fmt.Println(version)
   706  		return
   707  	}
   709  	if info, ok := debug.ReadBuildInfo(); ok {
   710  		fmt.Println(info.Main.Version)
   711  	} else {
   712  		fmt.Println("<unknown>")
   713  	}
   714  }