github.com/please-build/puku@v1.7.3-0.20240516143641-f7d7f4941f57/generate/generate.go (about)

     1  package generate
     2  
import (
	"errors"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"github.com/please-build/buildtools/build"
	"github.com/please-build/buildtools/labels"

	"github.com/please-build/puku/config"
	"github.com/please-build/puku/edit"
	"github.com/please-build/puku/eval"
	"github.com/please-build/puku/glob"
	"github.com/please-build/puku/graph"
	"github.com/please-build/puku/kinds"
	"github.com/please-build/puku/licences"
	"github.com/please-build/puku/logging"
	"github.com/please-build/puku/please"
	"github.com/please-build/puku/proxy"
	"github.com/please-build/puku/trie"
)
    25  
    26  var log = logging.GetLogger()
    27  
    28  type Proxy interface {
    29  	ResolveModuleForPackage(pattern string) (*proxy.Module, error)
    30  	ResolveDeps(mods, newMods []*proxy.Module) ([]*proxy.Module, error)
    31  }
    32  
    33  type updater struct {
    34  	plzConf       *please.Config
    35  	usingGoModule bool
    36  
    37  	graph *graph.Graph
    38  
    39  	newModules      []*proxy.Module
    40  	modules         []string
    41  	resolvedImports map[string]string
    42  	installs        *trie.Trie
    43  	eval            *eval.Eval
    44  
    45  	paths []string
    46  
    47  	proxy    Proxy
    48  	licences *licences.Licenses
    49  }
    50  
    51  func newUpdaterWithGraph(g *graph.Graph, conf *please.Config) *updater {
    52  	p := proxy.New(proxy.DefaultURL)
    53  	l := licences.New(p, g)
    54  	return &updater{
    55  		proxy:           p,
    56  		licences:        l,
    57  		plzConf:         conf,
    58  		graph:           g,
    59  		installs:        trie.New(),
    60  		eval:            eval.New(glob.New()),
    61  		resolvedImports: map[string]string{},
    62  	}
    63  }
    64  
    65  // newUpdater initialises a new updater struct. It's intended to be only used for testing (as is
    66  // newUpdaterWithGraph). In most instances the Update function should be called directly.
    67  func newUpdater(conf *please.Config) *updater {
    68  	g := graph.New(conf.BuildFileNames()).WithExperimentalDirs(conf.Parse.ExperimentalDir...)
    69  
    70  	return newUpdaterWithGraph(g, conf)
    71  }
    72  
    73  func Update(plzConf *please.Config, paths ...string) error {
    74  	u := newUpdater(plzConf)
    75  	if err := u.update(paths...); err != nil {
    76  		return err
    77  	}
    78  	return u.graph.FormatFiles()
    79  }
    80  
    81  func UpdateToStdout(format string, plzConf *please.Config, paths ...string) error {
    82  	u := newUpdater(plzConf)
    83  	if err := u.update(paths...); err != nil {
    84  		return err
    85  	}
    86  	return u.graph.FormatFilesWithWriter(os.Stdout, format)
    87  }
    88  
    89  func (u *updater) readAllModules(conf *config.Config) error {
    90  	return filepath.WalkDir(conf.GetThirdPartyDir(), func(path string, info fs.DirEntry, _ error) error {
    91  		for _, buildFileName := range u.plzConf.BuildFileNames() {
    92  			if info.Name() == buildFileName {
    93  				file, err := u.graph.LoadFile(filepath.Dir(path))
    94  				if err != nil {
    95  					return err
    96  				}
    97  
    98  				if err := u.readModules(file); err != nil {
    99  					return err
   100  				}
   101  			}
   102  		}
   103  		return nil
   104  	})
   105  }
   106  
   107  // readModules returns the defined third party modules in this project
   108  func (u *updater) readModules(file *build.File) error {
   109  	addInstalls := func(targetName, modName string, installs []string) {
   110  		for _, install := range installs {
   111  			path := filepath.Join(modName, install)
   112  			target := BuildTarget(targetName, file.Pkg, "")
   113  			u.installs.Add(path, target)
   114  		}
   115  	}
   116  
   117  	for _, repoRule := range file.Rules("go_repo") {
   118  		module := repoRule.AttrString("module")
   119  		u.modules = append(u.modules, module)
   120  
   121  		// we do not add installs for go_repos. We prefer to resolve deps
   122  		// to the subrepo targets since this is more efficient for please.
   123  	}
   124  
   125  	goMods := file.Rules("go_module")
   126  	u.usingGoModule = len(goMods) > 0 || u.usingGoModule
   127  
   128  	for _, mod := range goMods {
   129  		module := mod.AttrString("module")
   130  		installs := mod.AttrStrings("install")
   131  		if len(installs) == 0 {
   132  			installs = []string{"."}
   133  		}
   134  		addInstalls(mod.Name(), module, installs)
   135  	}
   136  
   137  	return nil
   138  }
   139  
   140  // update loops through the provided paths, updating and creating any build rules it finds.
   141  func (u *updater) update(paths ...string) error {
   142  	conf, err := config.ReadConfig(".")
   143  	if err != nil {
   144  		return err
   145  	}
   146  	u.paths = paths
   147  
   148  	if err := u.readAllModules(conf); err != nil {
   149  		return fmt.Errorf("failed to read third party rules: %v", err)
   150  	}
   151  
   152  	for _, path := range u.paths {
   153  		conf, err := config.ReadConfig(path)
   154  		if err != nil {
   155  			return err
   156  		}
   157  
   158  		if conf.GetStop() {
   159  			return nil
   160  		}
   161  
   162  		if err := u.updateOne(conf, path); err != nil {
   163  			return fmt.Errorf("failed to update %v: %v", path, err)
   164  		}
   165  	}
   166  
   167  	// Save any new modules we needed back to the third party file
   168  	return u.addNewModules(conf)
   169  }
   170  
   171  func (u *updater) updateOne(conf *config.Config, path string) error {
   172  	// Find all the files in the dir
   173  	sources, err := ImportDir(path)
   174  	if err != nil {
   175  		return err
   176  	}
   177  
   178  	// Parse the build file
   179  	file, err := u.graph.LoadFile(path)
   180  	if err != nil {
   181  		return err
   182  	}
   183  
   184  	if !u.plzConf.GoIsPreloaded() && conf.ShouldEnsureSubincludes() {
   185  		edit.EnsureSubinclude(file)
   186  	}
   187  
   188  	// Read existing rules from file
   189  	rules, calls := u.readRulesFromFile(conf, file, path)
   190  
   191  	// Allocate the sources to the rules, creating new rules as necessary
   192  	newRules, err := u.allocateSources(conf, path, sources, rules)
   193  	if err != nil {
   194  		return err
   195  	}
   196  
   197  	rules = append(rules, newRules...)
   198  
   199  	// Update the existing call expressions in the build file
   200  	return u.updateDeps(conf, file, calls, rules, sources)
   201  }
   202  
   203  func (u *updater) addNewModules(conf *config.Config) error {
   204  	file, err := u.graph.LoadFile(conf.GetThirdPartyDir())
   205  	if err != nil {
   206  		return err
   207  	}
   208  
   209  	if !u.plzConf.GoIsPreloaded() && conf.ShouldEnsureSubincludes() {
   210  		edit.EnsureSubinclude(file)
   211  	}
   212  
   213  	goRepos := file.Rules("go_repo")
   214  	mods := make([]*proxy.Module, 0, len(goRepos))
   215  	existingRules := make(map[string]*build.Rule)
   216  	for _, rule := range goRepos {
   217  		mod, ver := rule.AttrString("module"), rule.AttrString("version")
   218  		existingRules[rule.AttrString("module")] = rule
   219  		mods = append(mods, &proxy.Module{Module: mod, Version: ver})
   220  	}
   221  
   222  	allMods, err := u.proxy.ResolveDeps(mods, u.newModules)
   223  	if err != nil {
   224  		return err
   225  	}
   226  
   227  	for _, mod := range allMods {
   228  		if rule, ok := existingRules[mod.Module]; ok {
   229  			// Modules might be using go_mod_download, which we don't handle.
   230  			if rule.Attr("version") != nil {
   231  				rule.SetAttr("version", edit.NewStringExpr(mod.Version))
   232  			}
   233  			continue
   234  		}
   235  		ls, err := u.licences.Get(mod.Module, mod.Version)
   236  		if err != nil {
   237  			return fmt.Errorf("failed to get license for mod %v: %v", mod.Module, err)
   238  		}
   239  		file.Stmt = append(file.Stmt, edit.NewGoRepoRule(mod.Module, mod.Version, "", ls))
   240  	}
   241  	return nil
   242  }
   243  
   244  // allSources calculates the sources for a target. It will evaluate the source list resolving globs, and building any
   245  // srcs that are other build targets.
   246  //
   247  // passedSources is a slice of filepaths, which contains source files passed to the rule, after resolving globs and
   248  // building any targets. These source files can be looked up in goFiles, if they exist.
   249  //
   250  // goFiles contains a mapping of source files to their GoFile. This map might be missing entries from passedSources, if
   251  // the source doesn't actually exist. In which case, this should be removed from the rule, as the user likely deleted
   252  // the file.
   253  func (u *updater) allSources(conf *config.Config, r *rule, sourceMap map[string]*GoFile) (passedSources []string, goFiles map[string]*GoFile, err error) {
   254  	srcs, err := u.eval.BuildSources(conf.GetPlzPath(), r.dir, r.Rule, r.SrcsAttr())
   255  	if err != nil {
   256  		return nil, nil, err
   257  	}
   258  
   259  	sources := make(map[string]*GoFile, len(srcs))
   260  	for _, src := range srcs {
   261  		if file, ok := sourceMap[src]; ok {
   262  			sources[src] = file
   263  			continue
   264  		}
   265  
   266  		// These are generated sources in plz-out/gen
   267  		f, err := importFile(".", src)
   268  		if err != nil {
   269  			continue
   270  		}
   271  		sources[src] = f
   272  	}
   273  	return srcs, sources, nil
   274  }
   275  
   276  // updateRuleDeps updates the dependencies of a build rule based on the imports of its sources
   277  func (u *updater) updateRuleDeps(conf *config.Config, rule *rule, rules []*rule, packageFiles map[string]*GoFile) error {
   278  	done := map[string]struct{}{}
   279  
   280  	// If the rule operates on non-go source files (e.g. *.proto for proto_library) then we should skip updating
   281  	// it as we can't determine its deps from sources this way.
   282  	if rule.kind.NonGoSources {
   283  		return nil
   284  	}
   285  
   286  	srcs, targetFiles, err := u.allSources(conf, rule, packageFiles)
   287  	if err != nil {
   288  		return err
   289  	}
   290  
   291  	label := BuildTarget(rule.Name(), rule.dir, "")
   292  
   293  	deps := map[string]struct{}{}
   294  	for _, src := range srcs {
   295  		f := targetFiles[src]
   296  		if f == nil {
   297  			rule.removeSrc(src) // The src doesn't exist so remove it from the list of srcs
   298  			continue
   299  		}
   300  		for _, i := range f.Imports {
   301  			if _, ok := done[i]; ok {
   302  				continue
   303  			}
   304  			done[i] = struct{}{}
   305  
   306  			// If the dep is provided by the kind (i.e. the build def adds it) then skip this import
   307  
   308  			dep, err := u.resolveImport(conf, i)
   309  			if err != nil {
   310  				log.Warningf("couldn't resolve %q for %v: %v", i, rule.label(), err)
   311  				continue
   312  			}
   313  			if dep == "" {
   314  				continue
   315  			}
   316  			if rule.kind.IsProvided(dep) {
   317  				continue
   318  			}
   319  
   320  			dep = shorten(rule.dir, dep)
   321  
   322  			if _, ok := deps[dep]; !ok {
   323  				deps[dep] = struct{}{}
   324  			}
   325  		}
   326  	}
   327  
   328  	// Add any libraries for the same package as us
   329  	if rule.kind.Type == kinds.Test && !rule.isExternal() {
   330  		pkgName, err := u.rulePkg(conf, packageFiles, rule)
   331  		if err != nil {
   332  			return err
   333  		}
   334  
   335  		for _, libRule := range rules {
   336  			if libRule.kind.Type == kinds.Test {
   337  				continue
   338  			}
   339  			libPkgName, err := u.rulePkg(conf, packageFiles, libRule)
   340  			if err != nil {
   341  				return err
   342  			}
   343  
   344  			if libPkgName != pkgName {
   345  				continue
   346  			}
   347  
   348  			t := libRule.localLabel()
   349  			if _, ok := deps[t]; !ok {
   350  				deps[t] = struct{}{}
   351  			}
   352  		}
   353  	}
   354  
   355  	depSlice := make([]string, 0, len(deps))
   356  	for dep := range deps {
   357  		u.graph.EnsureVisibility(label, dep)
   358  		depSlice = append(depSlice, dep)
   359  	}
   360  
   361  	rule.setOrDeleteAttr("deps", depSlice)
   362  
   363  	return nil
   364  }
   365  
   366  // shorten will shorten lables to the local package
   367  func shorten(pkg, label string) string {
   368  	if strings.HasPrefix(label, "///") || strings.HasPrefix(label, "@") {
   369  		return label
   370  	}
   371  
   372  	return labels.Shorten(label, pkg)
   373  }
   374  
   375  // readRulesFromFile reads the existing build rules from the BUILD file
   376  func (u *updater) readRulesFromFile(conf *config.Config, file *build.File, pkgDir string) ([]*rule, map[string]*build.Rule) {
   377  	ruleExprs := file.Rules("")
   378  	rules := make([]*rule, 0, len(ruleExprs))
   379  	calls := map[string]*build.Rule{}
   380  
   381  	for _, expr := range ruleExprs {
   382  		kind := conf.GetKind(expr.Kind())
   383  		if kind == nil {
   384  			continue
   385  		}
   386  		rule := newRule(expr, kind, pkgDir)
   387  		rules = append(rules, rule)
   388  		calls[rule.Name()] = expr
   389  	}
   390  
   391  	return rules, calls
   392  }
   393  
   394  // updateDeps updates the existing rules and creates any new rules in the BUILD file
   395  func (u *updater) updateDeps(conf *config.Config, file *build.File, ruleExprs map[string]*build.Rule, rules []*rule, sources map[string]*GoFile) error {
   396  	for _, rule := range rules {
   397  		if _, ok := ruleExprs[rule.Name()]; !ok {
   398  			file.Stmt = append(file.Stmt, rule.Call)
   399  		}
   400  		if err := u.updateRuleDeps(conf, rule, rules, sources); err != nil {
   401  			return err
   402  		}
   403  	}
   404  	return nil
   405  }
   406  
   407  // allocateSources allocates sources to rules. If there's no existing rule, a new rule will be created and returned
   408  // from this function
   409  func (u *updater) allocateSources(conf *config.Config, pkgDir string, sources map[string]*GoFile, rules []*rule) ([]*rule, error) {
   410  	unallocated, err := u.unallocatedSources(sources, rules)
   411  	if err != nil {
   412  		return nil, err
   413  	}
   414  
   415  	var newRules []*rule
   416  	for _, src := range unallocated {
   417  		importedFile := sources[src]
   418  		if importedFile == nil {
   419  			continue // Something went wrong and we haven't imported the file don't try to allocate it
   420  		}
   421  		var rule *rule
   422  		for _, r := range append(rules, newRules...) {
   423  			if r.kind.Type != importedFile.kindType() {
   424  				continue
   425  			}
   426  
   427  			rulePkgName, err := u.rulePkg(conf, sources, r)
   428  			if err != nil {
   429  				return nil, fmt.Errorf("failed to determine package name for //%v:%v: %w", pkgDir, r.Name(), err)
   430  			}
   431  
   432  			// Find a rule that's for the same package and of the same kind (i.e. bin, lib, test)
   433  			// NB: we return when we find the first one so if there are multiple options, we will pick one essentially at
   434  			//     random.
   435  			if rulePkgName == "" || rulePkgName == importedFile.Name {
   436  				rule = r
   437  				break
   438  			}
   439  		}
   440  		if rule == nil {
   441  			name := filepath.Base(pkgDir)
   442  			kind := "go_library"
   443  			if importedFile.IsTest() {
   444  				name += "_test"
   445  				kind = "go_test"
   446  			}
   447  			if importedFile.IsCmd() {
   448  				kind = "go_binary"
   449  				name = "main"
   450  			}
   451  			rule = newRule(edit.NewRuleExpr(kind, name), kinds.DefaultKinds[kind], pkgDir)
   452  			if importedFile.IsExternal(filepath.Join(u.plzConf.ImportPath(), pkgDir)) {
   453  				rule.setExternal()
   454  			}
   455  			newRules = append(newRules, rule)
   456  		}
   457  
   458  		rule.addSrc(src)
   459  	}
   460  	return newRules, nil
   461  }
   462  
   463  // rulePkg checks the first source it finds for a rule and returns the name from the "package name" directive at the top
   464  // of the file
   465  func (u *updater) rulePkg(conf *config.Config, srcs map[string]*GoFile, rule *rule) (string, error) {
   466  	// This is a safe bet if we can't use the source files to figure this out.
   467  	if rule.kind.NonGoSources {
   468  		return rule.Name(), nil
   469  	}
   470  
   471  	ss, srcs, err := u.allSources(conf, rule, srcs)
   472  	if err != nil {
   473  		return "", err
   474  	}
   475  
   476  	for _, s := range ss {
   477  		if src, ok := srcs[s]; ok {
   478  			return src.Name, nil
   479  		}
   480  	}
   481  
   482  	return "", nil
   483  }
   484  
   485  // unallocatedSources returns all the sources that don't already belong to a rule
   486  func (u *updater) unallocatedSources(srcs map[string]*GoFile, rules []*rule) ([]string, error) {
   487  	var ret []string
   488  	for src := range srcs {
   489  		found := false
   490  		for _, rule := range rules {
   491  			if found {
   492  				break
   493  			}
   494  
   495  			ruleSrcs, err := u.eval.EvalGlobs(rule.dir, rule.Rule, rule.SrcsAttr())
   496  			if err != nil {
   497  				return nil, err
   498  			}
   499  			for _, s := range ruleSrcs {
   500  				if s == src {
   501  					found = true
   502  					break
   503  				}
   504  			}
   505  		}
   506  		if !found {
   507  			ret = append(ret, src)
   508  		}
   509  	}
   510  	return ret, nil
   511  }