github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/internal/cmd/ast-to-pattern/parse.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/parser"
     8  	"go/scanner"
     9  	"go/token"
    10  	"text/template"
    11  )
    12  
    13  var tmplDecl = template.Must(template.New("").Parse(`` +
    14  	`package p; {{ . }}`))
    15  
    16  var tmplExprs = template.Must(template.New("").Parse(`` +
    17  	`package p; var _ = []interface{}{ {{ . }}, }`))
    18  
    19  var tmplStmts = template.Must(template.New("").Parse(`` +
    20  	`package p; func _() { {{ . }} }`))
    21  
    22  var tmplType = template.Must(template.New("").Parse(`` +
    23  	`package p; var _ {{ . }}`))
    24  
    25  var tmplValSpec = template.Must(template.New("").Parse(`` +
    26  	`package p; var {{ . }}`))
    27  
    28  func execTmpl(tmpl *template.Template, src string) string {
    29  	var buf bytes.Buffer
    30  	if err := tmpl.Execute(&buf, src); err != nil {
    31  		panic(err)
    32  	}
    33  	return buf.String()
    34  }
    35  
    36  func noBadNodes(node ast.Node) bool {
    37  	any := false
    38  	ast.Inspect(node, func(n ast.Node) bool {
    39  		if any {
    40  			return false
    41  		}
    42  		switch n.(type) {
    43  		case *ast.BadExpr, *ast.BadDecl:
    44  			any = true
    45  		}
    46  		return true
    47  	})
    48  	return !any
    49  }
    50  
    51  func parseType(fset *token.FileSet, src string) (ast.Expr, *ast.File, error) {
    52  	asType := execTmpl(tmplType, src)
    53  	f, err := parser.ParseFile(fset, "", asType, parser.SkipObjectResolution)
    54  	if err != nil {
    55  		return nil, nil, err
    56  	}
    57  	vs := f.Decls[0].(*ast.GenDecl).Specs[0].(*ast.ValueSpec)
    58  	return vs.Type, f, nil
    59  }
    60  
    61  // parseDetectingNode tries its best to parse the ast.Node contained in src, as
    62  // one of: *ast.File, ast.Decl, ast.Expr, ast.Stmt, *ast.ValueSpec.
    63  // It also returns the *ast.File used for the parsing, so that the returned node
    64  // can be easily type-checked.
    65  func parseDetectingNode(fset *token.FileSet, src string) (interface{}, error) {
    66  	file := fset.AddFile("", fset.Base(), len(src))
    67  	scan := scanner.Scanner{}
    68  	scan.Init(file, []byte(src), nil, 0)
    69  	if _, tok, _ := scan.Scan(); tok == token.EOF {
    70  		return nil, fmt.Errorf("empty source code")
    71  	}
    72  	var mainErr error
    73  
    74  	// first try as a whole file
    75  	if f, err := parser.ParseFile(fset, "", src, parser.SkipObjectResolution); err == nil && noBadNodes(f) {
    76  		return f, nil
    77  	}
    78  
    79  	// then as a single declaration, or many
    80  	asDecl := execTmpl(tmplDecl, src)
    81  	if f, err := parser.ParseFile(fset, "", asDecl, parser.SkipObjectResolution); err == nil && noBadNodes(f) {
    82  		if len(f.Decls) == 1 {
    83  			return f.Decls[0], nil
    84  		}
    85  		return f, nil
    86  	}
    87  
    88  	// then as value expressions
    89  	asExprs := execTmpl(tmplExprs, src)
    90  	if f, err := parser.ParseFile(fset, "", asExprs, parser.SkipObjectResolution); err == nil && noBadNodes(f) {
    91  		vs := f.Decls[0].(*ast.GenDecl).Specs[0].(*ast.ValueSpec)
    92  		cl := vs.Values[0].(*ast.CompositeLit)
    93  		if len(cl.Elts) == 1 {
    94  			return cl.Elts[0], nil
    95  		}
    96  		return cl.Elts, nil
    97  	}
    98  
    99  	// then try as statements
   100  	asStmts := execTmpl(tmplStmts, src)
   101  	if f, err := parser.ParseFile(fset, "", asStmts, parser.SkipObjectResolution); err == nil && noBadNodes(f) {
   102  		bl := f.Decls[0].(*ast.FuncDecl).Body
   103  		if len(bl.List) == 1 {
   104  			return bl.List[0], nil
   105  		}
   106  		return bl.List, nil
   107  	} else {
   108  		// Statements is what covers most cases, so it will give
   109  		// the best overall error message. Show positions
   110  		// relative to where the user's code is put in the
   111  		// template.
   112  		mainErr = err
   113  	}
   114  
   115  	// type expressions not yet picked up, for e.g. chans and interfaces
   116  	if typ, f, err := parseType(fset, src); err == nil && noBadNodes(f) {
   117  		return typ, nil
   118  	}
   119  
   120  	// value specs
   121  	asValSpec := execTmpl(tmplValSpec, src)
   122  	if f, err := parser.ParseFile(fset, "", asValSpec, parser.SkipObjectResolution); err == nil && noBadNodes(f) {
   123  		vs := f.Decls[0].(*ast.GenDecl).Specs[0].(*ast.ValueSpec)
   124  		return vs, nil
   125  	}
   126  	return nil, mainErr
   127  }