github.com/please-build/go-rules/tools/please_go@v0.0.0-20240319165128-ea27d6f5caba/goget/go_get.go (about) 1 package goget 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "log" 8 "net/http" 9 "os" 10 "strings" 11 12 "golang.org/x/mod/modfile" 13 "golang.org/x/mod/semver" 14 ) 15 16 var client = http.DefaultClient 17 18 type moduleVersion struct { 19 mod, ver string 20 } 21 22 type getter struct { 23 queryResults map[moduleVersion]*modfile.File 24 proxyUrl string 25 } 26 27 func newGetter() *getter { 28 return &getter{ 29 queryResults: map[moduleVersion]*modfile.File{}, 30 proxyUrl: "https://proxy.golang.org", 31 } 32 } 33 34 func (g *getter) getGoMod(mod, ver string) (*modfile.File, error) { 35 modVer := moduleVersion{mod, ver} 36 if modFile, ok := g.queryResults[modVer]; ok { 37 return modFile, nil 38 } 39 40 file := fmt.Sprintf("%s/%s/@v/%s.mod", g.proxyUrl, mod, ver) 41 resp, err := client.Get(file) 42 if err != nil { 43 return nil, err 44 } 45 46 body, err := io.ReadAll(resp.Body) 47 if err != nil { 48 return nil, err 49 } 50 51 if resp.StatusCode != 200 { 52 return nil, fmt.Errorf("%v %v: \n%v", file, resp.StatusCode, string(body)) 53 } 54 55 modFile, err := modfile.Parse(file, body, nil) 56 if err != nil { 57 return nil, err 58 } 59 60 g.queryResults[modVer] = modFile 61 return modFile, nil 62 } 63 64 // getGoModWithFallback attempts to get a go.mod for the given module and 65 // version with fallback for supporting modules with case insensitivity. 66 func (g *getter) getGoModWithFallback(mod, version string) (*modfile.File, error) { 67 modVersionsToAttempt := map[string]string{ 68 mod: version, 69 } 70 71 // attempt lowercasing entire mod string for packages like: 72 // - `github.com/Sirupsen/logrus` -> `github.com/sirupsen/logrus`. 73 // https://github.com/sirupsen/logrus/issues/543 74 modVersionsToAttempt[strings.ToLower(mod)] = version 75 76 var errs error 77 for mod, version := range modVersionsToAttempt { 78 modFile, err := g.getGoMod(mod, version) 79 if err != nil { 80 errs = errors.Join(errs, err) 81 continue 82 } 83 84 return modFile, nil 85 } 86 87 return nil, errs 88 } 89 90 func (g *getter) getDeps(deps map[string]string, mod, version string) error { 91 modFile, err := g.getGoModWithFallback(mod, version) 92 if err != nil { 93 return err 94 } 95 96 for _, req := range modFile.Require { 97 oldVer, ok := deps[req.Mod.Path] 98 if !ok || semver.Compare(oldVer, req.Mod.Version) < 0 { 99 deps[req.Mod.Path] = req.Mod.Version 100 if err := g.getDeps(deps, req.Mod.Path, req.Mod.Version); err != nil { 101 return err 102 } 103 } 104 } 105 106 return nil 107 } 108 109 func (g *getter) goGet(mods []string) error { 110 deps := map[string]string{} 111 for _, mod := range mods { 112 path, version, ok := strings.Cut(mod, "@") 113 if !ok { 114 log.Fatalf("Module spec %s is missing a version; must be in the format golang.org/x/sys@v1.0.0", mod) 115 } 116 deps[path] = version 117 } 118 119 for mod, ver := range deps { 120 if err := g.getDeps(deps, mod, ver); err != nil { 121 return err 122 } 123 } 124 125 for mod, ver := range deps { 126 fmt.Printf("go_repo(module=\"%s\", version=\"%s\")\n", mod, ver) 127 } 128 return nil 129 } 130 131 // GoGet is used to spit out a new go_get rule. The plan is to build this out into a tool to add new third party 132 // modules to the repo. 133 func GoGet(mods []string) error { 134 return newGetter().goGet(mods) 135 } 136 137 func GetMod(path string) error { 138 g := newGetter() 139 bs, err := os.ReadFile(path) 140 if err != nil { 141 return err 142 } 143 144 modFile, err := modfile.ParseLax(path, bs, nil) 145 if err != nil { 146 return err 147 } 148 149 paths := make([]string, len(modFile.Require)) 150 for i, req := range modFile.Require { 151 paths[i] = fmt.Sprintf("%v@%v", req.Mod.Path, req.Mod.Version) 152 } 153 154 return g.goGet(paths) 155 }