github.com/octohelm/cuekit@v0.0.0-20240424021256-e7df8d743066/internal/cmd/internalfork/main.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "flag" 6 "fmt" 7 "go/ast" 8 "go/build" 9 "go/format" 10 "go/parser" 11 "go/token" 12 "log" 13 "os" 14 "path" 15 "path/filepath" 16 "strconv" 17 "strings" 18 19 "golang.org/x/mod/modfile" 20 "golang.org/x/tools/go/ast/astutil" 21 ) 22 23 type Packages []string 24 25 func (i *Packages) String() string { 26 return "go packages" 27 } 28 29 func (i *Packages) Set(value string) error { 30 *i = append(*i, strings.TrimSpace(value)) 31 return nil 32 } 33 34 var importPaths = Packages{} 35 36 func main() { 37 flag.Var(&importPaths, "p", "import path") 38 flag.Parse() 39 40 output := flag.Args()[0] 41 42 if !(strings.HasSuffix(output, "internal")) { 43 panic(fmt.Errorf("output should be '*/internal', but got `%s`", output)) 44 } 45 46 if err := cleanup(output); err != nil { 47 panic(err) 48 } 49 50 prefixes := map[string]bool{} 51 52 for _, importPath := range importPaths { 53 prefixes[strings.Split(importPath, "/")[0]] = true 54 } 55 56 task, err := TaskFor(output, prefixes) 57 if err != nil { 58 panic(err) 59 } 60 61 for _, importPath := range importPaths { 62 if err := task.Scan(importPath); err != nil { 63 panic(err) 64 } 65 } 66 67 if err := task.Sync(); err != nil { 68 panic(err) 69 } 70 } 71 72 func cleanup(output string) error { 73 list, err := filepath.Glob(filepath.Join(output, "./*")) 74 if err != nil { 75 return err 76 } 77 78 for i := range list { 79 p := list[i] 80 81 f, err := os.Lstat(p) 82 if err != nil { 83 return err 84 } 85 86 if f.IsDir() && !strings.HasSuffix(f.Name(), "__gen__") { 87 log.Println("REMOVE", p) 88 if err := os.RemoveAll(p); err != nil { 89 return err 90 } 91 } 92 } 93 94 return nil 95 } 96 97 func TaskFor(dir string, prefixes map[string]bool) (*Task, error) { 98 if !filepath.IsAbs(dir) { 99 output, _ := os.Getwd() 100 dir = path.Join(output, dir) 101 } 102 103 d := dir 104 105 for d != "/" { 106 gmodfile := filepath.Join(d, "go.mod") 107 108 if data, err := os.ReadFile(gmodfile); err != nil { 109 if !os.IsNotExist(err) { 110 panic(err) 111 } 112 } else { 113 f, _ := modfile.Parse(gmodfile, data, nil) 114 115 rel, _ := filepath.Rel(d, dir) 116 117 t := &Task{ 118 Dir: filepath.Join(d, rel), 119 PkgPath: filepath.Join(f.Module.Mod.Path, rel), 120 Prefixes: prefixes, 121 } 122 123 return t, nil 124 } 125 126 d = filepath.Join(d, "../") 127 } 128 129 return nil, fmt.Errorf("missing go.mod") 130 } 131 132 type Task struct { 133 Dir string 134 PkgPath string 135 Prefixes map[string]bool 136 137 // map[importPath][filename]*parsedFile 138 packages map[string]map[string]*parsedFile 139 140 pkgInternals map[string]bool 141 pkgUsed map[string][]string 142 } 143 144 func (t *Task) Sync() error { 145 needToForks := map[string]bool{} 146 147 var findUsed func(importPath string) []string 148 149 findUsed = func(importPath string) (used []string) { 150 for _, importPath := range t.pkgUsed[importPath] { 151 used = append(append(used, importPath), findUsed(importPath)...) 152 } 153 return 154 } 155 156 for pkgImportPath := range t.pkgInternals { 157 needToForks[pkgImportPath] = true 158 159 for _, p := range findUsed(pkgImportPath) { 160 needToForks[p] = true 161 } 162 } 163 164 for pkgImportPath := range needToForks { 165 files := t.packages[pkgImportPath] 166 167 for filename := range files { 168 f := files[filename] 169 170 astFile := f.file 171 172 for _, i := range astFile.Imports { 173 importPath, _ := strconv.Unquote(i.Path.Value) 174 175 if needToForks[importPath] { 176 fmt.Println(filepath.Join(t.PkgPath, t.replaceInternal(importPath))) 177 178 _ = astutil.RewriteImport( 179 f.fset, astFile, 180 importPath, 181 filepath.Join(t.PkgPath, t.replaceInternal(importPath)), 182 ) 183 } 184 } 185 186 output := filepath.Join(t.Dir, t.replaceInternal(pkgImportPath), filepath.Base(filename)) 187 188 buf := bytes.NewBuffer(nil) 189 if err := format.Node(buf, f.fset, astFile); err != nil { 190 return err 191 } 192 if err := writeFile(output, buf.Bytes()); err != nil { 193 return err 194 } 195 } 196 } 197 198 return nil 199 } 200 201 func (t *Task) Scan(importPath string) error { 202 if _, ok := t.packages[importPath]; ok { 203 return nil 204 } 205 206 // skip not internal pkg 207 if !t.isInternalPkg(importPath) { 208 return nil 209 } 210 211 pkg, err := build.Import(importPath, "", build.FindOnly) 212 if err != nil { 213 return err 214 } 215 216 if err := t.scanPkg(pkg); err != nil { 217 return err 218 } 219 220 log.Printf("SCANED %s", importPath) 221 222 return nil 223 } 224 225 func (t *Task) isInternalPkg(importPath string) bool { 226 if strings.Contains(importPath, "internal/") { 227 return true 228 } 229 return strings.HasSuffix(importPath, "internal") || strings.HasPrefix(importPath, "internal") 230 } 231 232 func (t *Task) scanPkg(pkg *build.Package) error { 233 files, err := filepath.Glob(pkg.Dir + "/*.go") 234 if err != nil { 235 return err 236 } 237 238 if t.isInternalPkg(pkg.ImportPath) { 239 if t.pkgInternals == nil { 240 t.pkgInternals = map[string]bool{} 241 } 242 t.pkgInternals[pkg.ImportPath] = true 243 } 244 245 for _, f := range files { 246 // skip test file 247 if strings.HasSuffix(f, "_test.go") { 248 continue 249 } 250 251 if err := t.scanGoFile(f, pkg); err != nil { 252 return err 253 } 254 } 255 256 return nil 257 } 258 259 func (t *Task) scanGoFile(filename string, pkg *build.Package) error { 260 f, err := newParsedFile(filename) 261 if err != nil { 262 return err 263 } 264 265 file := f.file 266 267 pkgImportPath := pkg.ImportPath 268 269 if pkg.Name == "" { 270 if file.Name.Name != "main" { 271 pkg.Name = file.Name.Name 272 } 273 } 274 275 if file.Name.Name != pkg.Name { 276 return nil 277 } 278 279 for _, i := range file.Imports { 280 importPath, _ := strconv.Unquote(i.Path.Value) 281 282 if t.pkgUsed == nil { 283 t.pkgUsed = map[string][]string{} 284 } 285 286 t.pkgUsed[importPath] = append(t.pkgUsed[importPath], pkgImportPath) 287 288 if t.Prefixes[strings.Split(importPath, "/")[0]] { 289 if err := t.Scan(importPath); err != nil { 290 return err 291 } 292 } 293 } 294 295 if t.packages == nil { 296 t.packages = map[string]map[string]*parsedFile{} 297 } 298 299 if t.packages[pkgImportPath] == nil { 300 t.packages[pkgImportPath] = map[string]*parsedFile{} 301 } 302 303 t.packages[pkgImportPath][filename] = f 304 305 return nil 306 } 307 308 func newParsedFile(filename string) (*parsedFile, error) { 309 f := &parsedFile{} 310 311 src, err := os.ReadFile(filename) 312 if err != nil { 313 return nil, err 314 } 315 316 f.fset = token.NewFileSet() 317 318 file, err := parser.ParseFile(f.fset, filename, src, parser.ParseComments) 319 if err != nil { 320 return nil, err 321 } 322 323 f.file = file 324 325 return f, nil 326 } 327 328 type parsedFile struct { 329 fset *token.FileSet 330 file *ast.File 331 } 332 333 func (t *Task) replaceInternal(p string) string { 334 if strings.HasSuffix(p, "internal") { 335 return filepath.Join(filepath.Dir(p), "./internals") 336 } 337 338 return strings.Replace(p, "internal/", "internals/", -1) 339 } 340 341 func writeFile(filename string, data []byte) error { 342 if err := os.MkdirAll(filepath.Dir(filename), os.ModePerm); err != nil { 343 return err 344 } 345 return os.WriteFile(filename, data, os.ModePerm) 346 }