github.com/please-build/go-rules/tools/please_go@v0.0.0-20240319165128-ea27d6f5caba/goget/go_get.go (about)

     1  package goget
     2  
import (
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"sort"
	"strings"
	"time"

	"golang.org/x/mod/modfile"
	"golang.org/x/mod/semver"
)
    15  
    16  var client = http.DefaultClient
    17  
    18  type moduleVersion struct {
    19  	mod, ver string
    20  }
    21  
    22  type getter struct {
    23  	queryResults map[moduleVersion]*modfile.File
    24  	proxyUrl     string
    25  }
    26  
    27  func newGetter() *getter {
    28  	return &getter{
    29  		queryResults: map[moduleVersion]*modfile.File{},
    30  		proxyUrl:     "https://proxy.golang.org",
    31  	}
    32  }
    33  
    34  func (g *getter) getGoMod(mod, ver string) (*modfile.File, error) {
    35  	modVer := moduleVersion{mod, ver}
    36  	if modFile, ok := g.queryResults[modVer]; ok {
    37  		return modFile, nil
    38  	}
    39  
    40  	file := fmt.Sprintf("%s/%s/@v/%s.mod", g.proxyUrl, mod, ver)
    41  	resp, err := client.Get(file)
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  
    46  	body, err := io.ReadAll(resp.Body)
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  
    51  	if resp.StatusCode != 200 {
    52  		return nil, fmt.Errorf("%v %v: \n%v", file, resp.StatusCode, string(body))
    53  	}
    54  
    55  	modFile, err := modfile.Parse(file, body, nil)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	g.queryResults[modVer] = modFile
    61  	return modFile, nil
    62  }
    63  
    64  // getGoModWithFallback attempts to get a go.mod for the given module and
    65  // version with fallback for supporting modules with case insensitivity.
    66  func (g *getter) getGoModWithFallback(mod, version string) (*modfile.File, error) {
    67  	modVersionsToAttempt := map[string]string{
    68  		mod: version,
    69  	}
    70  
    71  	// attempt lowercasing entire mod string for packages like:
    72  	// - `github.com/Sirupsen/logrus` -> `github.com/sirupsen/logrus`.
    73  	// https://github.com/sirupsen/logrus/issues/543
    74  	modVersionsToAttempt[strings.ToLower(mod)] = version
    75  
    76  	var errs error
    77  	for mod, version := range modVersionsToAttempt {
    78  		modFile, err := g.getGoMod(mod, version)
    79  		if err != nil {
    80  			errs = errors.Join(errs, err)
    81  			continue
    82  		}
    83  
    84  		return modFile, nil
    85  	}
    86  
    87  	return nil, errs
    88  }
    89  
    90  func (g *getter) getDeps(deps map[string]string, mod, version string) error {
    91  	modFile, err := g.getGoModWithFallback(mod, version)
    92  	if err != nil {
    93  		return err
    94  	}
    95  
    96  	for _, req := range modFile.Require {
    97  		oldVer, ok := deps[req.Mod.Path]
    98  		if !ok || semver.Compare(oldVer, req.Mod.Version) < 0 {
    99  			deps[req.Mod.Path] = req.Mod.Version
   100  			if err := g.getDeps(deps, req.Mod.Path, req.Mod.Version); err != nil {
   101  				return err
   102  			}
   103  		}
   104  	}
   105  
   106  	return nil
   107  }
   108  
   109  func (g *getter) goGet(mods []string) error {
   110  	deps := map[string]string{}
   111  	for _, mod := range mods {
   112  		path, version, ok := strings.Cut(mod, "@")
   113  		if !ok {
   114  			log.Fatalf("Module spec %s is missing a version; must be in the format golang.org/x/sys@v1.0.0", mod)
   115  		}
   116  		deps[path] = version
   117  	}
   118  
   119  	for mod, ver := range deps {
   120  		if err := g.getDeps(deps, mod, ver); err != nil {
   121  			return err
   122  		}
   123  	}
   124  
   125  	for mod, ver := range deps {
   126  		fmt.Printf("go_repo(module=\"%s\", version=\"%s\")\n", mod, ver)
   127  	}
   128  	return nil
   129  }
   130  
   131  // GoGet is used to spit out a new go_get rule. The plan is to build this out into a tool to add new third party
   132  // modules to the repo.
   133  func GoGet(mods []string) error {
   134  	return newGetter().goGet(mods)
   135  }
   136  
   137  func GetMod(path string) error {
   138  	g := newGetter()
   139  	bs, err := os.ReadFile(path)
   140  	if err != nil {
   141  		return err
   142  	}
   143  
   144  	modFile, err := modfile.ParseLax(path, bs, nil)
   145  	if err != nil {
   146  		return err
   147  	}
   148  
   149  	paths := make([]string, len(modFile.Require))
   150  	for i, req := range modFile.Require {
   151  		paths[i] = fmt.Sprintf("%v@%v", req.Mod.Path, req.Mod.Version)
   152  	}
   153  
   154  	return g.goGet(paths)
   155  }