github.com/octohelm/cuemod@v0.9.4/tool/internalfork/main.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "flag" 6 "fmt" 7 "go/ast" 8 "go/build" 9 "go/format" 10 "go/parser" 11 "go/token" 12 "log" 13 "os" 14 "path" 15 "path/filepath" 16 "strconv" 17 "strings" 18 19 "golang.org/x/mod/modfile" 20 "golang.org/x/tools/go/ast/astutil" 21 ) 22 23 type Packages []string 24 25 func (i *Packages) String() string { 26 return "go packages" 27 } 28 29 func (i *Packages) Set(value string) error { 30 *i = append(*i, strings.TrimSpace(value)) 31 return nil 32 } 33 34 var importPaths = Packages{} 35 36 func main() { 37 flag.Var(&importPaths, "p", "import path") 38 flag.Parse() 39 40 output := flag.Args()[0] 41 42 if !(strings.HasSuffix(output, "internal")) { 43 panic(fmt.Errorf("output should be '*/internal', but got `%s`", output)) 44 } 45 46 if err := cleanup(output); err != nil { 47 panic(err) 48 } 49 50 prefixes := map[string]bool{} 51 52 for _, importPath := range importPaths { 53 prefixes[strings.Split(importPath, "/")[0]] = true 54 } 55 56 task, err := TaskFor(output, prefixes) 57 if err != nil { 58 panic(err) 59 } 60 61 for _, importPath := range importPaths { 62 if err := task.Scan(importPath); err != nil { 63 panic(err) 64 } 65 } 66 67 if err := task.Sync(); err != nil { 68 panic(err) 69 } 70 } 71 72 func cleanup(output string) error { 73 list, err := filepath.Glob(filepath.Join(output, "./*")) 74 if err != nil { 75 return err 76 } 77 78 for i := range list { 79 p := list[i] 80 81 f, err := os.Lstat(p) 82 if err != nil { 83 return err 84 } 85 86 if f.IsDir() && !strings.HasSuffix(f.Name(), "__gen__") { 87 log.Println("REMOVE", p) 88 if err := os.RemoveAll(p); err != nil { 89 return err 90 } 91 } 92 } 93 94 return nil 95 } 96 97 func TaskFor(dir string, prefixes map[string]bool) (*Task, error) { 98 if !filepath.IsAbs(dir) { 99 output, _ := os.Getwd() 100 dir = path.Join(output, dir) 101 } 102 103 d := dir 104 105 for d != "/" { 106 gmodfile := filepath.Join(d, "go.mod") 
107 108 if data, err := os.ReadFile(gmodfile); err != nil { 109 if !os.IsNotExist(err) { 110 panic(err) 111 } 112 } else { 113 f, _ := modfile.Parse(gmodfile, data, nil) 114 115 rel, _ := filepath.Rel(d, dir) 116 117 t := &Task{ 118 Dir: filepath.Join(d, rel), 119 PkgPath: filepath.Join(f.Module.Mod.Path, rel), 120 Prefixes: prefixes, 121 } 122 123 return t, nil 124 } 125 126 d = filepath.Join(d, "../") 127 } 128 129 return nil, fmt.Errorf("missing go.mod") 130 } 131 132 type Task struct { 133 Dir string 134 PkgPath string 135 Prefixes map[string]bool 136 137 // map[importPath][filename]*parsedFile 138 packages map[string]map[string]*parsedFile 139 140 pkgInternals map[string]bool 141 pkgUsed map[string][]string 142 } 143 144 func (t *Task) Sync() error { 145 needToForks := map[string]bool{} 146 147 var findUsed func(importPath string) []string 148 findUsed = func(importPath string) (used []string) { 149 for _, importPath := range t.pkgUsed[importPath] { 150 used = append(append(used, importPath), findUsed(importPath)...) 
151 } 152 return 153 } 154 155 for pkgImportPath := range t.pkgInternals { 156 needToForks[pkgImportPath] = true 157 158 for _, p := range findUsed(pkgImportPath) { 159 needToForks[p] = true 160 } 161 } 162 163 for pkgImportPath := range needToForks { 164 files := t.packages[pkgImportPath] 165 166 for filename := range files { 167 f := files[filename] 168 169 astFile := f.file 170 171 for _, i := range astFile.Imports { 172 importPath, _ := strconv.Unquote(i.Path.Value) 173 174 if needToForks[importPath] { 175 fmt.Println(filepath.Join(t.PkgPath, t.replaceInternal(importPath))) 176 177 _ = astutil.RewriteImport( 178 f.fset, astFile, 179 importPath, 180 filepath.Join(t.PkgPath, t.replaceInternal(importPath)), 181 ) 182 } 183 } 184 185 output := filepath.Join(t.Dir, t.replaceInternal(pkgImportPath), filepath.Base(filename)) 186 187 buf := bytes.NewBuffer(nil) 188 if err := format.Node(buf, f.fset, astFile); err != nil { 189 return err 190 } 191 if err := writeFile(output, buf.Bytes()); err != nil { 192 return err 193 } 194 } 195 } 196 197 return nil 198 } 199 200 func (t *Task) Scan(importPath string) error { 201 if _, ok := t.packages[importPath]; ok { 202 return nil 203 } 204 205 pkg, err := build.Import(importPath, "", build.FindOnly) 206 if err != nil { 207 return err 208 } 209 210 if err := t.scanPkg(pkg); err != nil { 211 return err 212 } 213 214 log.Printf("SCANED %s", importPath) 215 216 return nil 217 } 218 219 func (t *Task) isInternalPkg(importPath string) bool { 220 if strings.Contains(importPath, "internal/") { 221 return true 222 } 223 return strings.HasSuffix(importPath, "internal") || strings.HasPrefix(importPath, "internal") 224 } 225 226 func (t *Task) scanPkg(pkg *build.Package) error { 227 files, err := filepath.Glob(pkg.Dir + "/*.go") 228 if err != nil { 229 return err 230 } 231 232 if t.isInternalPkg(pkg.ImportPath) { 233 if t.pkgInternals == nil { 234 t.pkgInternals = map[string]bool{} 235 } 236 t.pkgInternals[pkg.ImportPath] = true 237 } 238 
239 for _, f := range files { 240 // skip test file 241 if strings.HasSuffix(f, "_test.go") { 242 continue 243 } 244 245 if err := t.scanGoFile(f, pkg); err != nil { 246 return err 247 } 248 } 249 250 return nil 251 } 252 253 func (t *Task) scanGoFile(filename string, pkg *build.Package) error { 254 f, err := newParsedFile(filename) 255 if err != nil { 256 return err 257 } 258 259 file := f.file 260 261 pkgImportPath := pkg.ImportPath 262 263 if pkg.Name == "" { 264 if file.Name.Name != "main" { 265 pkg.Name = file.Name.Name 266 } 267 } 268 269 if file.Name.Name != pkg.Name { 270 return nil 271 } 272 273 for _, i := range file.Imports { 274 importPath, _ := strconv.Unquote(i.Path.Value) 275 276 if t.pkgUsed == nil { 277 t.pkgUsed = map[string][]string{} 278 } 279 280 t.pkgUsed[importPath] = append(t.pkgUsed[importPath], pkgImportPath) 281 282 if t.Prefixes[strings.Split(importPath, "/")[0]] { 283 if err := t.Scan(importPath); err != nil { 284 return err 285 } 286 } 287 } 288 289 if t.packages == nil { 290 t.packages = map[string]map[string]*parsedFile{} 291 } 292 293 if t.packages[pkgImportPath] == nil { 294 t.packages[pkgImportPath] = map[string]*parsedFile{} 295 } 296 297 t.packages[pkgImportPath][filename] = f 298 299 return nil 300 } 301 302 func newParsedFile(filename string) (*parsedFile, error) { 303 f := &parsedFile{} 304 305 src, err := os.ReadFile(filename) 306 if err != nil { 307 return nil, err 308 } 309 310 f.fset = token.NewFileSet() 311 312 file, err := parser.ParseFile(f.fset, filename, src, parser.ParseComments) 313 if err != nil { 314 return nil, err 315 } 316 317 f.file = file 318 319 return f, nil 320 } 321 322 type parsedFile struct { 323 fset *token.FileSet 324 file *ast.File 325 } 326 327 func (t *Task) replaceInternal(p string) string { 328 if strings.HasSuffix(p, "internal") { 329 return filepath.Join(filepath.Dir(p), "./internals") 330 } 331 332 return strings.Replace( 333 p, 334 "internal/", 335 "internals/", 336 -1, 337 ) 338 } 339 340 
// writeFile writes data to filename, first creating any missing parent
// directories. Directories are created with os.ModePerm and the file with
// 0o644, both subject to the process umask. Generated Go sources must not get
// the executable bits that os.ModePerm (0777) would grant a regular file.
func writeFile(filename string, data []byte) error {
	if err := os.MkdirAll(filepath.Dir(filename), os.ModePerm); err != nil {
		return err
	}
	return os.WriteFile(filename, data, 0o644)
}