github.com/alkemics/goflow@v0.2.1/gfutil/gfgo/nodes.go (about)

     1  package gfgo
     2  
import (
	"fmt"
	"go/ast"
	"go/doc"
	"go/parser"
	"go/token"
	"go/types"
	"path"
	"path/filepath"
	"sort"
	"strings"
	"sync"

	"golang.org/x/tools/go/packages"

	"github.com/alkemics/goflow"
)
    18  
// Node describes a goflow node found by loadNodes in a Go package: either an
// exported function (Constructor and Method are then empty), or an exported
// method of an exported type that has a New* constructor.
type Node struct {
	// Pkg is the name of the package declaring the node.
	Pkg          string
	// Typ is the function name for a function node, or the receiver type name
	// for a method node.
	Typ          string
	// PkgPath is the import path of the package declaring the node.
	PkgPath      string
	// Doc is the doc comment of the function or method.
	Doc          string
	// Constructor is the name of the New* function building the receiver of a
	// method node; it is empty for function nodes.
	Constructor  string
	// Method is the method name of a method node; it is empty for function
	// nodes. Match addresses a method named "Run" like a function node, as
	// "<Pkg>.<Typ>".
	Method       string
	// Filename is the file declaring the function or, for a method node, the
	// file declaring the receiver type.
	Filename     string
	// Imports lists the imports returned by ParseSignature for the node's own
	// signature, plus the node's package and, for a method node, the
	// constructor's imports.
	Imports      []goflow.Import
	// Dependencies holds, for a method node, the dependencies ParseSignature
	// returned for the constructor (presumably its parameters; confirm in
	// ParseSignature). It is left empty for function nodes.
	Dependencies []goflow.Field
	// Inputs and Outputs are the inputs and outputs ParseSignature returned
	// for the function or method signature.
	Inputs       []goflow.Field
	Outputs      []goflow.Field
}
    32  
    33  func (n Node) Match(typ string) bool {
    34  	if n.Method != "" && n.Method != "Run" {
    35  		return typ == fmt.Sprintf("%s.%s.%s", n.Pkg, n.Typ, n.Method)
    36  	}
    37  	return typ == fmt.Sprintf("%s.%s", n.Pkg, n.Typ)
    38  }
    39  
// NodeLoader holds the nodes loaded from a set of packages and lets callers
// query and refresh them. Every method takes mu, so a NodeLoader may be shared
// between goroutines; it must not be copied after first use. The zero value is
// ready to use.
type NodeLoader struct {
	// mu guards pkgs and nodes.
	mu    sync.Mutex
	// pkgs maps package names to package paths for every package registered
	// by Load; Refresh looks packages up here.
	pkgs  map[string]string
	// nodes holds the currently known nodes.
	nodes []Node
}
    45  
    46  func (l *NodeLoader) Load(pkgs ...string) error {
    47  	nodes, pkgMap, err := loadNodes(pkgs)
    48  	if err != nil {
    49  		return err
    50  	}
    51  
    52  	l.mu.Lock()
    53  	defer l.mu.Unlock()
    54  
    55  	if l.pkgs == nil {
    56  		l.pkgs = make(map[string]string)
    57  	}
    58  
    59  	for k, v := range pkgMap {
    60  		l.pkgs[k] = v
    61  	}
    62  
    63  	l.nodes = nodes
    64  
    65  	return nil
    66  }
    67  
    68  func (l *NodeLoader) Refresh(pkgName string) error {
    69  	l.mu.Lock()
    70  	defer l.mu.Unlock()
    71  
    72  	pkgPath := l.pkgs[pkgName]
    73  	if pkgPath == "" {
    74  		return fmt.Errorf("unkown pkg: %s", pkgName)
    75  	}
    76  
    77  	nodes, _, err := loadNodes([]string{pkgPath})
    78  	if err != nil {
    79  		return err
    80  	}
    81  
    82  	otherNodes := make([]Node, 0, len(l.nodes))
    83  
    84  	for _, node := range l.nodes {
    85  		if node.PkgPath != pkgPath {
    86  			otherNodes = append(otherNodes, node)
    87  		}
    88  	}
    89  
    90  	l.nodes = append(otherNodes, nodes...)
    91  
    92  	return nil
    93  }
    94  
    95  func (l *NodeLoader) All() []Node {
    96  	l.mu.Lock()
    97  	defer l.mu.Unlock()
    98  
    99  	nodes := make([]Node, len(l.nodes))
   100  	copy(nodes, l.nodes)
   101  
   102  	return nodes
   103  }
   104  
   105  func (l *NodeLoader) Find(typ string) Node {
   106  	l.mu.Lock()
   107  	defer l.mu.Unlock()
   108  
   109  	for _, node := range l.nodes {
   110  		if node.Match(typ) {
   111  			return node
   112  		}
   113  	}
   114  
   115  	return Node{}
   116  }
   117  
// typeConstructor describes the New* function that loadNodes selected as the
// constructor of an exported type.
type typeConstructor struct {
	// name is the constructor's function name.
	name         string
	// imports and dependencies are the first two results of ParseSignature on
	// the constructor's signature.
	imports      []goflow.Import
	dependencies []goflow.Field
}
   123  
   124  func loadNodes(nodesPackages []string) ([]Node, map[string]string, error) {
   125  	pkgs, err := packages.Load(&packages.Config{
   126  		Mode: packages.NeedName | packages.NeedFiles | packages.NeedImports | packages.NeedDeps | packages.NeedSyntax | packages.NeedTypes | packages.NeedTypesInfo,
   127  	}, nodesPackages...)
   128  	if err != nil {
   129  		return nil, nil, err
   130  	}
   131  
   132  	errs := make([]error, 0)
   133  	nodes := make([]Node, 0)
   134  	pkgMap := make(map[string]string)
   135  	for _, pkg := range pkgs {
   136  		if pkg.TypesInfo == nil {
   137  			continue
   138  		}
   139  
   140  		pkgMap[pkg.Name] = pkg.PkgPath
   141  
   142  		typesDoc, err := parseTypesDoc(pkg)
   143  		if err != nil {
   144  			errs = append(errs, PkgError{
   145  				PkgPath: pkg.PkgPath,
   146  				Err:     err,
   147  			})
   148  		}
   149  
   150  		// Get constructors first.
   151  		// Start by registering all types.
   152  		allTypes := make(map[string]struct{})
   153  		ignoredTypes := make(map[string]struct{})
   154  		for k, v := range pkg.TypesInfo.Defs {
   155  			if v == nil ||
   156  				!v.Exported() ||
   157  				k.Obj == nil ||
   158  				k.Obj.Decl == nil ||
   159  				k.Obj.Kind != ast.Typ {
   160  				continue
   161  			}
   162  			decl, ok := k.Obj.Decl.(*ast.TypeSpec)
   163  			if !ok {
   164  				continue
   165  			}
   166  			typ := decl.Name.String()
   167  			allTypes[typ] = struct{}{}
   168  			if shouldIgnoreNode(typesDoc[typ]) {
   169  				ignoredTypes[typ] = struct{}{}
   170  			}
   171  		}
   172  
   173  		// Then register all constructors by type.
   174  		constructorNames := make(map[string]struct{})
   175  		constructors := make(map[string]typeConstructor)
   176  		for k, v := range pkg.TypesInfo.Defs {
   177  			if v == nil ||
   178  				!v.Exported() ||
   179  				k.Obj == nil ||
   180  				k.Obj.Decl == nil ||
   181  				k.Obj.Kind != ast.Fun ||
   182  				!strings.HasPrefix(k.Obj.Name, "New") {
   183  				continue
   184  			}
   185  
   186  			signature, ok := v.Type().(*types.Signature)
   187  			if !ok || signature.Results() == nil {
   188  				continue
   189  			}
   190  
   191  			decl, ok := k.Obj.Decl.(*ast.FuncDecl)
   192  			if !ok {
   193  				continue
   194  			}
   195  
   196  			imports, dependencies, results, err := ParseSignature(signature)
   197  			if err != nil {
   198  				errs = append(errs, PkgError{
   199  					PkgPath: pkg.PkgPath,
   200  					Err:     fmt.Errorf("parsing %s: %w", decl.Name.String(), err),
   201  				})
   202  			}
   203  
   204  			if len(results) != 1 {
   205  				continue
   206  			}
   207  
   208  			// Get the type of returned by the constructor.
   209  			typ := strings.TrimPrefix(results[0].Type, "*")
   210  			split := strings.Split(typ, ".")
   211  			if len(split) > 1 {
   212  				typ = split[len(split)-1]
   213  			}
   214  
   215  			if _, ok := allTypes[typ]; !ok {
   216  				// It's not a constructor.
   217  				continue
   218  			}
   219  
   220  			constructorNames[decl.Name.String()] = struct{}{}
   221  			if _, ok := ignoredTypes[typ]; ok {
   222  				continue
   223  			}
   224  
   225  			constructors[typ] = typeConstructor{
   226  				name:         decl.Name.String(),
   227  				imports:      imports,
   228  				dependencies: dependencies,
   229  			}
   230  		}
   231  
   232  		for k, v := range pkg.TypesInfo.Defs {
   233  			if v == nil || !v.Exported() || k.Obj == nil || k.Obj.Decl == nil {
   234  				continue
   235  			}
   236  			switch k.Obj.Kind {
   237  			case ast.Fun:
   238  				// Load the functions (we don't have the methods here).
   239  				signature, ok := v.Type().(*types.Signature)
   240  				if !ok {
   241  					continue
   242  				}
   243  				decl, ok := k.Obj.Decl.(*ast.FuncDecl)
   244  				if !ok {
   245  					continue
   246  				}
   247  				if _, ok := constructorNames[decl.Name.String()]; ok {
   248  					// Ignore constructors.
   249  					continue
   250  				}
   251  				if shouldIgnoreNode(decl.Doc.Text()) {
   252  					continue
   253  				}
   254  				node, err := createGoNodeFromFunc(decl, signature, pkg)
   255  				if err != nil {
   256  					errs = append(errs, PkgError{
   257  						PkgPath: pkg.PkgPath,
   258  						Err:     fmt.Errorf("parsing %s: %w", decl.Name, err),
   259  					})
   260  				} else {
   261  					nodes = append(nodes, node)
   262  				}
   263  			case ast.Typ:
   264  				// Load the types with their methods.
   265  				named, ok := v.Type().(*types.Named)
   266  				if !ok {
   267  					continue
   268  				}
   269  				decl, ok := k.Obj.Decl.(*ast.TypeSpec)
   270  				if !ok {
   271  					continue
   272  				}
   273  				typ := decl.Name.String()
   274  				if _, ok := ignoredTypes[typ]; ok {
   275  					continue
   276  				}
   277  				constructor := constructors[typ]
   278  				if constructor.name == "" {
   279  					continue
   280  				}
   281  				ns, err := createGoNodeFromType(decl, constructor, named, pkg)
   282  				if err != nil {
   283  					errs = append(errs, PkgError{
   284  						PkgPath: pkg.PkgPath,
   285  						Err:     fmt.Errorf("parsing %s: %w", decl.Name, err),
   286  					})
   287  				} else {
   288  					nodes = append(nodes, ns...)
   289  				}
   290  			default:
   291  				continue
   292  			}
   293  		}
   294  	}
   295  
   296  	if len(errs) > 0 {
   297  		err = goflow.MultiError{Errs: errs}
   298  	}
   299  
   300  	return nodes, pkgMap, err
   301  }
   302  
   303  func createGoNodeFromFunc(decl *ast.FuncDecl, signature *types.Signature, pkg *packages.Package) (Node, error) {
   304  	imports, inputs, outputs, err := ParseSignature(signature)
   305  	if err != nil {
   306  		return Node{}, err
   307  	}
   308  
   309  	return Node{
   310  		Pkg:     pkg.Name,
   311  		Typ:     decl.Name.String(),
   312  		PkgPath: pkg.PkgPath,
   313  		Doc:     decl.Doc.Text(),
   314  		// Add node import.
   315  		Imports: append(
   316  			imports,
   317  			goflow.Import{
   318  				Pkg: pkg.Name,
   319  				Dir: pkg.PkgPath,
   320  			},
   321  		),
   322  		Inputs:   inputs,
   323  		Outputs:  outputs,
   324  		Filename: pkg.Fset.File(decl.Pos()).Name(),
   325  	}, nil
   326  }
   327  
   328  func createGoNodeFromType(decl *ast.TypeSpec, constructor typeConstructor, named *types.Named, pkg *packages.Package) ([]Node, error) {
   329  	baseImports := append(constructor.imports, goflow.Import{
   330  		Pkg: pkg.Name,
   331  		Dir: pkg.PkgPath,
   332  	})
   333  
   334  	// Add one node per exported method.
   335  	errs := make([]error, 0)
   336  	nodes := make([]Node, 0, named.NumMethods())
   337  	for i := 0; i < named.NumMethods(); i++ {
   338  		method := named.Method(i)
   339  		if !method.Exported() {
   340  			continue
   341  		}
   342  
   343  		doc := getDocFromPackage(pkg, method.Pos())
   344  		if shouldIgnoreNode(doc) {
   345  			continue
   346  		}
   347  
   348  		signature, ok := method.Type().(*types.Signature)
   349  		if !ok {
   350  			continue
   351  		}
   352  
   353  		imports, inputs, outputs, err := ParseSignature(signature)
   354  		if err != nil {
   355  			errs = append(errs, err)
   356  			continue
   357  		}
   358  
   359  		nodes = append(nodes, Node{
   360  			Pkg:     pkg.Name,
   361  			Typ:     decl.Name.String(),
   362  			PkgPath: pkg.PkgPath,
   363  			Doc:     doc,
   364  			// TODO: parse deps from struct rather than from constructor if not supplied
   365  			Dependencies: constructor.dependencies,
   366  			Constructor:  constructor.name,
   367  			Method:       method.Name(),
   368  			Imports:      append(imports, baseImports...),
   369  			Inputs:       inputs,
   370  			Outputs:      outputs,
   371  			Filename:     pkg.Fset.File(decl.Pos()).Name(),
   372  		})
   373  	}
   374  
   375  	var err error
   376  	if len(errs) > 0 {
   377  		err = goflow.MultiError{Errs: errs}
   378  	}
   379  
   380  	return nodes, err
   381  }
   382  
   383  func shouldIgnoreNode(doc string) bool {
   384  	for _, line := range strings.Split(doc, "\n") {
   385  		if line == "node:ignore" {
   386  			return true
   387  		}
   388  	}
   389  	return false
   390  }
   391  
   392  // getDocFromPackage loads the doc directly from the syntax if we don't have it already.
   393  func getDocFromPackage(pkg *packages.Package, pos token.Pos) string {
   394  	for _, f := range pkg.Syntax {
   395  		for _, d := range f.Decls {
   396  			switch decl := d.(type) {
   397  			case *ast.FuncDecl:
   398  				if decl.Name.NamePos == pos {
   399  					return decl.Doc.Text()
   400  				}
   401  			}
   402  		}
   403  	}
   404  	return ""
   405  }
   406  
   407  // parseTypesDoc hacks the thing by grabbing info of types using doc.
   408  // TODO: review all this and improve if possible...
   409  //       shall we try with https://pkg.go.dev/go/ast?tab=doc#CommentGroup maybe?
   410  func parseTypesDoc(pkg *packages.Package) (map[string]string, error) {
   411  	if len(pkg.GoFiles) == 0 {
   412  		return nil, nil
   413  	}
   414  
   415  	pkgs, err := parser.ParseDir(pkg.Fset, path.Dir(pkg.GoFiles[0]), nil, parser.ParseComments)
   416  	if err != nil {
   417  		return nil, err
   418  	}
   419  
   420  	if _, ok := pkgs[pkg.Name]; !ok {
   421  		return nil, PkgError{
   422  			PkgPath: pkg.PkgPath,
   423  			Err:     fmt.Errorf("package %s not found", pkg.Name),
   424  		}
   425  	}
   426  
   427  	allDecls := doc.New(pkgs[pkg.Name], "", doc.AllDecls)
   428  	mappedDoc := make(map[string]string)
   429  	for _, t := range allDecls.Types {
   430  		mappedDoc[t.Name] = t.Doc
   431  	}
   432  	return mappedDoc, nil
   433  }