github.com/octohelm/cuekit@v0.0.0-20240424021256-e7df8d743066/internal/cmd/internalfork/main.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"flag"
     6  	"fmt"
     7  	"go/ast"
     8  	"go/build"
     9  	"go/format"
    10  	"go/parser"
    11  	"go/token"
    12  	"log"
    13  	"os"
    14  	"path"
    15  	"path/filepath"
    16  	"strconv"
    17  	"strings"
    18  
    19  	"golang.org/x/mod/modfile"
    20  	"golang.org/x/tools/go/ast/astutil"
    21  )
    22  
    23  type Packages []string
    24  
    25  func (i *Packages) String() string {
    26  	return "go packages"
    27  }
    28  
    29  func (i *Packages) Set(value string) error {
    30  	*i = append(*i, strings.TrimSpace(value))
    31  	return nil
    32  }
    33  
    34  var importPaths = Packages{}
    35  
    36  func main() {
    37  	flag.Var(&importPaths, "p", "import path")
    38  	flag.Parse()
    39  
    40  	output := flag.Args()[0]
    41  
    42  	if !(strings.HasSuffix(output, "internal")) {
    43  		panic(fmt.Errorf("output should be '*/internal', but got `%s`", output))
    44  	}
    45  
    46  	if err := cleanup(output); err != nil {
    47  		panic(err)
    48  	}
    49  
    50  	prefixes := map[string]bool{}
    51  
    52  	for _, importPath := range importPaths {
    53  		prefixes[strings.Split(importPath, "/")[0]] = true
    54  	}
    55  
    56  	task, err := TaskFor(output, prefixes)
    57  	if err != nil {
    58  		panic(err)
    59  	}
    60  
    61  	for _, importPath := range importPaths {
    62  		if err := task.Scan(importPath); err != nil {
    63  			panic(err)
    64  		}
    65  	}
    66  
    67  	if err := task.Sync(); err != nil {
    68  		panic(err)
    69  	}
    70  }
    71  
    72  func cleanup(output string) error {
    73  	list, err := filepath.Glob(filepath.Join(output, "./*"))
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	for i := range list {
    79  		p := list[i]
    80  
    81  		f, err := os.Lstat(p)
    82  		if err != nil {
    83  			return err
    84  		}
    85  
    86  		if f.IsDir() && !strings.HasSuffix(f.Name(), "__gen__") {
    87  			log.Println("REMOVE", p)
    88  			if err := os.RemoveAll(p); err != nil {
    89  				return err
    90  			}
    91  		}
    92  	}
    93  
    94  	return nil
    95  }
    96  
    97  func TaskFor(dir string, prefixes map[string]bool) (*Task, error) {
    98  	if !filepath.IsAbs(dir) {
    99  		output, _ := os.Getwd()
   100  		dir = path.Join(output, dir)
   101  	}
   102  
   103  	d := dir
   104  
   105  	for d != "/" {
   106  		gmodfile := filepath.Join(d, "go.mod")
   107  
   108  		if data, err := os.ReadFile(gmodfile); err != nil {
   109  			if !os.IsNotExist(err) {
   110  				panic(err)
   111  			}
   112  		} else {
   113  			f, _ := modfile.Parse(gmodfile, data, nil)
   114  
   115  			rel, _ := filepath.Rel(d, dir)
   116  
   117  			t := &Task{
   118  				Dir:      filepath.Join(d, rel),
   119  				PkgPath:  filepath.Join(f.Module.Mod.Path, rel),
   120  				Prefixes: prefixes,
   121  			}
   122  
   123  			return t, nil
   124  		}
   125  
   126  		d = filepath.Join(d, "../")
   127  	}
   128  
   129  	return nil, fmt.Errorf("missing go.mod")
   130  }
   131  
   132  type Task struct {
   133  	Dir      string
   134  	PkgPath  string
   135  	Prefixes map[string]bool
   136  
   137  	// map[importPath][filename]*parsedFile
   138  	packages map[string]map[string]*parsedFile
   139  
   140  	pkgInternals map[string]bool
   141  	pkgUsed      map[string][]string
   142  }
   143  
   144  func (t *Task) Sync() error {
   145  	needToForks := map[string]bool{}
   146  
   147  	var findUsed func(importPath string) []string
   148  
   149  	findUsed = func(importPath string) (used []string) {
   150  		for _, importPath := range t.pkgUsed[importPath] {
   151  			used = append(append(used, importPath), findUsed(importPath)...)
   152  		}
   153  		return
   154  	}
   155  
   156  	for pkgImportPath := range t.pkgInternals {
   157  		needToForks[pkgImportPath] = true
   158  
   159  		for _, p := range findUsed(pkgImportPath) {
   160  			needToForks[p] = true
   161  		}
   162  	}
   163  
   164  	for pkgImportPath := range needToForks {
   165  		files := t.packages[pkgImportPath]
   166  
   167  		for filename := range files {
   168  			f := files[filename]
   169  
   170  			astFile := f.file
   171  
   172  			for _, i := range astFile.Imports {
   173  				importPath, _ := strconv.Unquote(i.Path.Value)
   174  
   175  				if needToForks[importPath] {
   176  					fmt.Println(filepath.Join(t.PkgPath, t.replaceInternal(importPath)))
   177  
   178  					_ = astutil.RewriteImport(
   179  						f.fset, astFile,
   180  						importPath,
   181  						filepath.Join(t.PkgPath, t.replaceInternal(importPath)),
   182  					)
   183  				}
   184  			}
   185  
   186  			output := filepath.Join(t.Dir, t.replaceInternal(pkgImportPath), filepath.Base(filename))
   187  
   188  			buf := bytes.NewBuffer(nil)
   189  			if err := format.Node(buf, f.fset, astFile); err != nil {
   190  				return err
   191  			}
   192  			if err := writeFile(output, buf.Bytes()); err != nil {
   193  				return err
   194  			}
   195  		}
   196  	}
   197  
   198  	return nil
   199  }
   200  
   201  func (t *Task) Scan(importPath string) error {
   202  	if _, ok := t.packages[importPath]; ok {
   203  		return nil
   204  	}
   205  
   206  	// skip not internal pkg
   207  	if !t.isInternalPkg(importPath) {
   208  		return nil
   209  	}
   210  
   211  	pkg, err := build.Import(importPath, "", build.FindOnly)
   212  	if err != nil {
   213  		return err
   214  	}
   215  
   216  	if err := t.scanPkg(pkg); err != nil {
   217  		return err
   218  	}
   219  
   220  	log.Printf("SCANED %s", importPath)
   221  
   222  	return nil
   223  }
   224  
   225  func (t *Task) isInternalPkg(importPath string) bool {
   226  	if strings.Contains(importPath, "internal/") {
   227  		return true
   228  	}
   229  	return strings.HasSuffix(importPath, "internal") || strings.HasPrefix(importPath, "internal")
   230  }
   231  
   232  func (t *Task) scanPkg(pkg *build.Package) error {
   233  	files, err := filepath.Glob(pkg.Dir + "/*.go")
   234  	if err != nil {
   235  		return err
   236  	}
   237  
   238  	if t.isInternalPkg(pkg.ImportPath) {
   239  		if t.pkgInternals == nil {
   240  			t.pkgInternals = map[string]bool{}
   241  		}
   242  		t.pkgInternals[pkg.ImportPath] = true
   243  	}
   244  
   245  	for _, f := range files {
   246  		// skip test file
   247  		if strings.HasSuffix(f, "_test.go") {
   248  			continue
   249  		}
   250  
   251  		if err := t.scanGoFile(f, pkg); err != nil {
   252  			return err
   253  		}
   254  	}
   255  
   256  	return nil
   257  }
   258  
   259  func (t *Task) scanGoFile(filename string, pkg *build.Package) error {
   260  	f, err := newParsedFile(filename)
   261  	if err != nil {
   262  		return err
   263  	}
   264  
   265  	file := f.file
   266  
   267  	pkgImportPath := pkg.ImportPath
   268  
   269  	if pkg.Name == "" {
   270  		if file.Name.Name != "main" {
   271  			pkg.Name = file.Name.Name
   272  		}
   273  	}
   274  
   275  	if file.Name.Name != pkg.Name {
   276  		return nil
   277  	}
   278  
   279  	for _, i := range file.Imports {
   280  		importPath, _ := strconv.Unquote(i.Path.Value)
   281  
   282  		if t.pkgUsed == nil {
   283  			t.pkgUsed = map[string][]string{}
   284  		}
   285  
   286  		t.pkgUsed[importPath] = append(t.pkgUsed[importPath], pkgImportPath)
   287  
   288  		if t.Prefixes[strings.Split(importPath, "/")[0]] {
   289  			if err := t.Scan(importPath); err != nil {
   290  				return err
   291  			}
   292  		}
   293  	}
   294  
   295  	if t.packages == nil {
   296  		t.packages = map[string]map[string]*parsedFile{}
   297  	}
   298  
   299  	if t.packages[pkgImportPath] == nil {
   300  		t.packages[pkgImportPath] = map[string]*parsedFile{}
   301  	}
   302  
   303  	t.packages[pkgImportPath][filename] = f
   304  
   305  	return nil
   306  }
   307  
   308  func newParsedFile(filename string) (*parsedFile, error) {
   309  	f := &parsedFile{}
   310  
   311  	src, err := os.ReadFile(filename)
   312  	if err != nil {
   313  		return nil, err
   314  	}
   315  
   316  	f.fset = token.NewFileSet()
   317  
   318  	file, err := parser.ParseFile(f.fset, filename, src, parser.ParseComments)
   319  	if err != nil {
   320  		return nil, err
   321  	}
   322  
   323  	f.file = file
   324  
   325  	return f, nil
   326  }
   327  
   328  type parsedFile struct {
   329  	fset *token.FileSet
   330  	file *ast.File
   331  }
   332  
   333  func (t *Task) replaceInternal(p string) string {
   334  	if strings.HasSuffix(p, "internal") {
   335  		return filepath.Join(filepath.Dir(p), "./internals")
   336  	}
   337  
   338  	return strings.Replace(p, "internal/", "internals/", -1)
   339  }
   340  
   341  func writeFile(filename string, data []byte) error {
   342  	if err := os.MkdirAll(filepath.Dir(filename), os.ModePerm); err != nil {
   343  		return err
   344  	}
   345  	return os.WriteFile(filename, data, os.ModePerm)
   346  }