github.com/please-build/puku@v1.7.3-0.20240516143641-f7d7f4941f57/sync/sync.go (about)

     1  package sync
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  
     7  	"github.com/please-build/buildtools/build"
     8  	"github.com/please-build/buildtools/labels"
     9  	"golang.org/x/mod/modfile"
    10  
    11  	"github.com/please-build/puku/config"
    12  	"github.com/please-build/puku/edit"
    13  	"github.com/please-build/puku/graph"
    14  	"github.com/please-build/puku/licences"
    15  	"github.com/please-build/puku/please"
    16  	"github.com/please-build/puku/proxy"
    17  )
    18  
    19  type syncer struct {
    20  	plzConf  *please.Config
    21  	graph    *graph.Graph
    22  	licences *licences.Licenses
    23  }
    24  
    25  func newSyncer(plzConf *please.Config, g *graph.Graph) *syncer {
    26  	p := proxy.New(proxy.DefaultURL)
    27  	l := licences.New(p, g)
    28  	return &syncer{
    29  		plzConf:  plzConf,
    30  		graph:    g,
    31  		licences: l,
    32  	}
    33  }
    34  
    35  // Sync constructs the syncer struct and initiates the sync.
    36  // NB. the Graph is to be constructed in the calling code because it's useful
    37  // for it to be available outside the package for testing.
    38  func Sync(plzConf *please.Config, g *graph.Graph) error {
    39  	s := newSyncer(plzConf, g)
    40  	if err := s.sync(); err != nil {
    41  		return err
    42  	}
    43  	return s.graph.FormatFiles()
    44  }
    45  
    46  // SyncToStdout constructs the syncer and outputs the synced build file to stdout.
    47  func SyncToStdout(format string, plzConf *please.Config, g *graph.Graph) error { //nolint
    48  	s := newSyncer(plzConf, g)
    49  	if err := s.sync(); err != nil {
    50  		return err
    51  	}
    52  	return s.graph.FormatFilesWithWriter(os.Stdout, format)
    53  }
    54  
    55  func (s *syncer) sync() error {
    56  	if s.plzConf.ModFile() == "" {
    57  		return nil
    58  	}
    59  
    60  	conf, err := config.ReadConfig(".")
    61  	if err != nil {
    62  		return err
    63  	}
    64  
    65  	file, err := s.graph.LoadFile(conf.GetThirdPartyDir())
    66  	if err != nil {
    67  		return err
    68  	}
    69  
    70  	existingRules, err := s.readModules(file)
    71  	if err != nil {
    72  		return fmt.Errorf("failed to read third party rules: %v", err)
    73  	}
    74  
    75  	if err := s.syncModFile(conf, file, existingRules); err != nil {
    76  		return err
    77  	}
    78  	return nil
    79  }
    80  
    81  func (s *syncer) syncModFile(conf *config.Config, file *build.File, exitingRules map[string]*build.Rule) error {
    82  	outs, err := please.Build(conf.GetPlzPath(), s.plzConf.ModFile())
    83  	if err != nil {
    84  		return err
    85  	}
    86  
    87  	if len(outs) != 1 {
    88  		return fmt.Errorf("expected exactly one out from Plugin.Go.Modfile, got %v", len(outs))
    89  	}
    90  
    91  	modFile := outs[0]
    92  	bs, err := os.ReadFile(modFile)
    93  	if err != nil {
    94  		return err
    95  	}
    96  	f, err := modfile.Parse(modFile, bs, nil)
    97  	if err != nil {
    98  		return err
    99  	}
   100  
   101  	for _, req := range f.Require {
   102  		reqVersion := req.Mod.Version
   103  		var replace *modfile.Replace
   104  		for _, r := range f.Replace {
   105  			if r.Old.Path == req.Mod.Path {
   106  				reqVersion = r.New.Version
   107  				if r.New.Path == req.Mod.Path { // we are just replacing version so don't need a replace
   108  					continue
   109  				}
   110  				replace = r
   111  			}
   112  		}
   113  
   114  		// Existing rule will point to the go_mod_download with the version on it so we should use the original path
   115  		r, ok := exitingRules[req.Mod.Path]
   116  		if ok {
   117  			if replace != nil && r.Kind() == "go_repo" {
   118  				// Looks like we've added in a replace for this module so we need to delete the old go_repo rule
   119  				// and regen with a go_mod_download and a go_repo.
   120  				edit.RemoveTarget(file, r)
   121  			} else {
   122  				// Make sure the version is up-to-date
   123  				r.SetAttr("version", edit.NewStringExpr(reqVersion))
   124  				continue
   125  			}
   126  		}
   127  
   128  		ls, err := s.licences.Get(req.Mod.Path, req.Mod.Version)
   129  		if err != nil {
   130  			return fmt.Errorf("failed to get licences for %v: %v", req.Mod.Path, err)
   131  		}
   132  
   133  		if replace == nil {
   134  			file.Stmt = append(file.Stmt, edit.NewGoRepoRule(req.Mod.Path, reqVersion, "", ls))
   135  			continue
   136  		}
   137  
   138  		dl, dlName := edit.NewModDownloadRule(replace.New.Path, replace.New.Version, ls)
   139  		file.Stmt = append(file.Stmt, dl)
   140  		file.Stmt = append(file.Stmt, edit.NewGoRepoRule(req.Mod.Path, "", dlName, nil))
   141  	}
   142  
   143  	return nil
   144  }
   145  
   146  func (s *syncer) readModules(file *build.File) (map[string]*build.Rule, error) {
   147  	// existingRules contains the rules for modules. These are synced to the go.mod's version as necessary. For modules
   148  	// that use `go_mod_download`, this map will point to that rule as that is the rule that has the version field.
   149  	existingRules := make(map[string]*build.Rule)
   150  	for _, repoRule := range append(file.Rules("go_repo"), file.Rules("go_module")...) {
   151  		if repoRule.AttrString("version") != "" {
   152  			existingRules[repoRule.AttrString("module")] = repoRule
   153  		} else {
   154  			// If we're using a go_mod_download for this module, then find the download rule instead.
   155  			t := labels.ParseRelative(repoRule.AttrString("download"), file.Pkg)
   156  			f, err := s.graph.LoadFile(t.Package)
   157  			if err != nil {
   158  				return nil, err
   159  			}
   160  			existingRules[repoRule.AttrString("module")] = edit.FindTargetByName(f, t.Target)
   161  		}
   162  	}
   163  
   164  	return existingRules, nil
   165  }