github.com/octohelm/cuemod@v0.9.4/tool/internalfork/main.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"flag"
     6  	"fmt"
     7  	"go/ast"
     8  	"go/build"
     9  	"go/format"
    10  	"go/parser"
    11  	"go/token"
    12  	"log"
    13  	"os"
    14  	"path"
    15  	"path/filepath"
    16  	"strconv"
    17  	"strings"
    18  
    19  	"golang.org/x/mod/modfile"
    20  	"golang.org/x/tools/go/ast/astutil"
    21  )
    22  
    23  type Packages []string
    24  
    25  func (i *Packages) String() string {
    26  	return "go packages"
    27  }
    28  
    29  func (i *Packages) Set(value string) error {
    30  	*i = append(*i, strings.TrimSpace(value))
    31  	return nil
    32  }
    33  
    34  var importPaths = Packages{}
    35  
    36  func main() {
    37  	flag.Var(&importPaths, "p", "import path")
    38  	flag.Parse()
    39  
    40  	output := flag.Args()[0]
    41  
    42  	if !(strings.HasSuffix(output, "internal")) {
    43  		panic(fmt.Errorf("output should be '*/internal', but got `%s`", output))
    44  	}
    45  
    46  	if err := cleanup(output); err != nil {
    47  		panic(err)
    48  	}
    49  
    50  	prefixes := map[string]bool{}
    51  
    52  	for _, importPath := range importPaths {
    53  		prefixes[strings.Split(importPath, "/")[0]] = true
    54  	}
    55  
    56  	task, err := TaskFor(output, prefixes)
    57  	if err != nil {
    58  		panic(err)
    59  	}
    60  
    61  	for _, importPath := range importPaths {
    62  		if err := task.Scan(importPath); err != nil {
    63  			panic(err)
    64  		}
    65  	}
    66  
    67  	if err := task.Sync(); err != nil {
    68  		panic(err)
    69  	}
    70  }
    71  
    72  func cleanup(output string) error {
    73  	list, err := filepath.Glob(filepath.Join(output, "./*"))
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	for i := range list {
    79  		p := list[i]
    80  
    81  		f, err := os.Lstat(p)
    82  		if err != nil {
    83  			return err
    84  		}
    85  
    86  		if f.IsDir() && !strings.HasSuffix(f.Name(), "__gen__") {
    87  			log.Println("REMOVE", p)
    88  			if err := os.RemoveAll(p); err != nil {
    89  				return err
    90  			}
    91  		}
    92  	}
    93  
    94  	return nil
    95  }
    96  
    97  func TaskFor(dir string, prefixes map[string]bool) (*Task, error) {
    98  	if !filepath.IsAbs(dir) {
    99  		output, _ := os.Getwd()
   100  		dir = path.Join(output, dir)
   101  	}
   102  
   103  	d := dir
   104  
   105  	for d != "/" {
   106  		gmodfile := filepath.Join(d, "go.mod")
   107  
   108  		if data, err := os.ReadFile(gmodfile); err != nil {
   109  			if !os.IsNotExist(err) {
   110  				panic(err)
   111  			}
   112  		} else {
   113  			f, _ := modfile.Parse(gmodfile, data, nil)
   114  
   115  			rel, _ := filepath.Rel(d, dir)
   116  
   117  			t := &Task{
   118  				Dir:      filepath.Join(d, rel),
   119  				PkgPath:  filepath.Join(f.Module.Mod.Path, rel),
   120  				Prefixes: prefixes,
   121  			}
   122  
   123  			return t, nil
   124  		}
   125  
   126  		d = filepath.Join(d, "../")
   127  	}
   128  
   129  	return nil, fmt.Errorf("missing go.mod")
   130  }
   131  
   132  type Task struct {
   133  	Dir      string
   134  	PkgPath  string
   135  	Prefixes map[string]bool
   136  
   137  	// map[importPath][filename]*parsedFile
   138  	packages map[string]map[string]*parsedFile
   139  
   140  	pkgInternals map[string]bool
   141  	pkgUsed      map[string][]string
   142  }
   143  
   144  func (t *Task) Sync() error {
   145  	needToForks := map[string]bool{}
   146  
   147  	var findUsed func(importPath string) []string
   148  	findUsed = func(importPath string) (used []string) {
   149  		for _, importPath := range t.pkgUsed[importPath] {
   150  			used = append(append(used, importPath), findUsed(importPath)...)
   151  		}
   152  		return
   153  	}
   154  
   155  	for pkgImportPath := range t.pkgInternals {
   156  		needToForks[pkgImportPath] = true
   157  
   158  		for _, p := range findUsed(pkgImportPath) {
   159  			needToForks[p] = true
   160  		}
   161  	}
   162  
   163  	for pkgImportPath := range needToForks {
   164  		files := t.packages[pkgImportPath]
   165  
   166  		for filename := range files {
   167  			f := files[filename]
   168  
   169  			astFile := f.file
   170  
   171  			for _, i := range astFile.Imports {
   172  				importPath, _ := strconv.Unquote(i.Path.Value)
   173  
   174  				if needToForks[importPath] {
   175  					fmt.Println(filepath.Join(t.PkgPath, t.replaceInternal(importPath)))
   176  
   177  					_ = astutil.RewriteImport(
   178  						f.fset, astFile,
   179  						importPath,
   180  						filepath.Join(t.PkgPath, t.replaceInternal(importPath)),
   181  					)
   182  				}
   183  			}
   184  
   185  			output := filepath.Join(t.Dir, t.replaceInternal(pkgImportPath), filepath.Base(filename))
   186  
   187  			buf := bytes.NewBuffer(nil)
   188  			if err := format.Node(buf, f.fset, astFile); err != nil {
   189  				return err
   190  			}
   191  			if err := writeFile(output, buf.Bytes()); err != nil {
   192  				return err
   193  			}
   194  		}
   195  	}
   196  
   197  	return nil
   198  }
   199  
   200  func (t *Task) Scan(importPath string) error {
   201  	if _, ok := t.packages[importPath]; ok {
   202  		return nil
   203  	}
   204  
   205  	pkg, err := build.Import(importPath, "", build.FindOnly)
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	if err := t.scanPkg(pkg); err != nil {
   211  		return err
   212  	}
   213  
   214  	log.Printf("SCANED %s", importPath)
   215  
   216  	return nil
   217  }
   218  
   219  func (t *Task) isInternalPkg(importPath string) bool {
   220  	if strings.Contains(importPath, "internal/") {
   221  		return true
   222  	}
   223  	return strings.HasSuffix(importPath, "internal") || strings.HasPrefix(importPath, "internal")
   224  }
   225  
   226  func (t *Task) scanPkg(pkg *build.Package) error {
   227  	files, err := filepath.Glob(pkg.Dir + "/*.go")
   228  	if err != nil {
   229  		return err
   230  	}
   231  
   232  	if t.isInternalPkg(pkg.ImportPath) {
   233  		if t.pkgInternals == nil {
   234  			t.pkgInternals = map[string]bool{}
   235  		}
   236  		t.pkgInternals[pkg.ImportPath] = true
   237  	}
   238  
   239  	for _, f := range files {
   240  		// skip test file
   241  		if strings.HasSuffix(f, "_test.go") {
   242  			continue
   243  		}
   244  
   245  		if err := t.scanGoFile(f, pkg); err != nil {
   246  			return err
   247  		}
   248  	}
   249  
   250  	return nil
   251  }
   252  
   253  func (t *Task) scanGoFile(filename string, pkg *build.Package) error {
   254  	f, err := newParsedFile(filename)
   255  	if err != nil {
   256  		return err
   257  	}
   258  
   259  	file := f.file
   260  
   261  	pkgImportPath := pkg.ImportPath
   262  
   263  	if pkg.Name == "" {
   264  		if file.Name.Name != "main" {
   265  			pkg.Name = file.Name.Name
   266  		}
   267  	}
   268  
   269  	if file.Name.Name != pkg.Name {
   270  		return nil
   271  	}
   272  
   273  	for _, i := range file.Imports {
   274  		importPath, _ := strconv.Unquote(i.Path.Value)
   275  
   276  		if t.pkgUsed == nil {
   277  			t.pkgUsed = map[string][]string{}
   278  		}
   279  
   280  		t.pkgUsed[importPath] = append(t.pkgUsed[importPath], pkgImportPath)
   281  
   282  		if t.Prefixes[strings.Split(importPath, "/")[0]] {
   283  			if err := t.Scan(importPath); err != nil {
   284  				return err
   285  			}
   286  		}
   287  	}
   288  
   289  	if t.packages == nil {
   290  		t.packages = map[string]map[string]*parsedFile{}
   291  	}
   292  
   293  	if t.packages[pkgImportPath] == nil {
   294  		t.packages[pkgImportPath] = map[string]*parsedFile{}
   295  	}
   296  
   297  	t.packages[pkgImportPath][filename] = f
   298  
   299  	return nil
   300  }
   301  
   302  func newParsedFile(filename string) (*parsedFile, error) {
   303  	f := &parsedFile{}
   304  
   305  	src, err := os.ReadFile(filename)
   306  	if err != nil {
   307  		return nil, err
   308  	}
   309  
   310  	f.fset = token.NewFileSet()
   311  
   312  	file, err := parser.ParseFile(f.fset, filename, src, parser.ParseComments)
   313  	if err != nil {
   314  		return nil, err
   315  	}
   316  
   317  	f.file = file
   318  
   319  	return f, nil
   320  }
   321  
   322  type parsedFile struct {
   323  	fset *token.FileSet
   324  	file *ast.File
   325  }
   326  
   327  func (t *Task) replaceInternal(p string) string {
   328  	if strings.HasSuffix(p, "internal") {
   329  		return filepath.Join(filepath.Dir(p), "./internals")
   330  	}
   331  
   332  	return strings.Replace(
   333  		p,
   334  		"internal/",
   335  		"internals/",
   336  		-1,
   337  	)
   338  }
   339  
   340  func writeFile(filename string, data []byte) error {
   341  	if err := os.MkdirAll(filepath.Dir(filename), os.ModePerm); err != nil {
   342  		return err
   343  	}
   344  	return os.WriteFile(filename, data, os.ModePerm)
   345  }