launchpad.net/gocheck@v0.0.0-20140225173054-000000000087/printer.go (about)

     1  package gocheck
     2  
     3  import (
     4  	"bytes"
     5  	"go/ast"
     6  	"go/parser"
     7  	"go/printer"
     8  	"go/token"
     9  	"os"
    10  )
    11  
    12  func indent(s, with string) (r string) {
    13  	eol := true
    14  	for i := 0; i != len(s); i++ {
    15  		c := s[i]
    16  		switch {
    17  		case eol && c == '\n' || c == '\r':
    18  		case c == '\n' || c == '\r':
    19  			eol = true
    20  		case eol:
    21  			eol = false
    22  			s = s[:i] + with + s[i:]
    23  			i += len(with)
    24  		}
    25  	}
    26  	return s
    27  }
    28  
    29  func printLine(filename string, line int) (string, error) {
    30  	fset := token.NewFileSet()
    31  	file, err := os.Open(filename)
    32  	if err != nil {
    33  		return "", err
    34  	}
    35  	fnode, err := parser.ParseFile(fset, filename, file, parser.ParseComments)
    36  	if err != nil {
    37  		return "", err
    38  	}
    39  	config := &printer.Config{Mode: printer.UseSpaces, Tabwidth: 4}
    40  	lp := &linePrinter{fset: fset, fnode: fnode, line: line, config: config}
    41  	ast.Walk(lp, fnode)
    42  	result := lp.output.Bytes()
    43  	// Comments leave \n at the end.
    44  	n := len(result)
    45  	for n > 0 && result[n-1] == '\n' {
    46  		n--
    47  	}
    48  	return string(result[:n]), nil
    49  }
    50  
    51  type linePrinter struct {
    52  	config *printer.Config
    53  	fset   *token.FileSet
    54  	fnode  *ast.File
    55  	line   int
    56  	output bytes.Buffer
    57  	stmt   ast.Stmt
    58  }
    59  
    60  func (lp *linePrinter) emit() bool {
    61  	if lp.stmt != nil {
    62  		lp.trim(lp.stmt)
    63  		lp.printWithComments(lp.stmt)
    64  		lp.stmt = nil
    65  		return true
    66  	}
    67  	return false
    68  }
    69  
    70  func (lp *linePrinter) printWithComments(n ast.Node) {
    71  	nfirst := lp.fset.Position(n.Pos()).Line
    72  	nlast := lp.fset.Position(n.End()).Line
    73  	for _, g := range lp.fnode.Comments {
    74  		cfirst := lp.fset.Position(g.Pos()).Line
    75  		clast := lp.fset.Position(g.End()).Line
    76  		if clast == nfirst-1 && lp.fset.Position(n.Pos()).Column == lp.fset.Position(g.Pos()).Column {
    77  			for _, c := range g.List {
    78  				lp.output.WriteString(c.Text)
    79  				lp.output.WriteByte('\n')
    80  			}
    81  		}
    82  		if cfirst >= nfirst && cfirst <= nlast && n.End() <= g.List[0].Slash {
    83  			// The printer will not include the comment if it starts past
    84  			// the node itself. Trick it into printing by overlapping the
    85  			// slash with the end of the statement.
    86  			g.List[0].Slash = n.End() - 1
    87  		}
    88  	}
    89  	node := &printer.CommentedNode{n, lp.fnode.Comments}
    90  	lp.config.Fprint(&lp.output, lp.fset, node)
    91  }
    92  
    93  func (lp *linePrinter) Visit(n ast.Node) (w ast.Visitor) {
    94  	if n == nil {
    95  		if lp.output.Len() == 0 {
    96  			lp.emit()
    97  		}
    98  		return nil
    99  	}
   100  	first := lp.fset.Position(n.Pos()).Line
   101  	last := lp.fset.Position(n.End()).Line
   102  	if first <= lp.line && last >= lp.line {
   103  		// Print the innermost statement containing the line.
   104  		if stmt, ok := n.(ast.Stmt); ok {
   105  			if _, ok := n.(*ast.BlockStmt); !ok {
   106  				lp.stmt = stmt
   107  			}
   108  		}
   109  		if first == lp.line && lp.emit() {
   110  			return nil
   111  		}
   112  		return lp
   113  	}
   114  	return nil
   115  }
   116  
   117  func (lp *linePrinter) trim(n ast.Node) bool {
   118  	stmt, ok := n.(ast.Stmt)
   119  	if !ok {
   120  		return true
   121  	}
   122  	line := lp.fset.Position(n.Pos()).Line
   123  	if line != lp.line {
   124  		return false
   125  	}
   126  	switch stmt := stmt.(type) {
   127  	case *ast.IfStmt:
   128  		stmt.Body = lp.trimBlock(stmt.Body)
   129  	case *ast.SwitchStmt:
   130  		stmt.Body = lp.trimBlock(stmt.Body)
   131  	case *ast.TypeSwitchStmt:
   132  		stmt.Body = lp.trimBlock(stmt.Body)
   133  	case *ast.CaseClause:
   134  		stmt.Body = lp.trimList(stmt.Body)
   135  	case *ast.CommClause:
   136  		stmt.Body = lp.trimList(stmt.Body)
   137  	case *ast.BlockStmt:
   138  		stmt.List = lp.trimList(stmt.List)
   139  	}
   140  	return true
   141  }
   142  
   143  func (lp *linePrinter) trimBlock(stmt *ast.BlockStmt) *ast.BlockStmt {
   144  	if !lp.trim(stmt) {
   145  		return lp.emptyBlock(stmt)
   146  	}
   147  	stmt.Rbrace = stmt.Lbrace
   148  	return stmt
   149  }
   150  
   151  func (lp *linePrinter) trimList(stmts []ast.Stmt) []ast.Stmt {
   152  	for i := 0; i != len(stmts); i++ {
   153  		if !lp.trim(stmts[i]) {
   154  			stmts[i] = lp.emptyStmt(stmts[i])
   155  			break
   156  		}
   157  	}
   158  	return stmts
   159  }
   160  
   161  func (lp *linePrinter) emptyStmt(n ast.Node) *ast.ExprStmt {
   162  	return &ast.ExprStmt{&ast.Ellipsis{n.Pos(), nil}}
   163  }
   164  
   165  func (lp *linePrinter) emptyBlock(n ast.Node) *ast.BlockStmt {
   166  	p := n.Pos()
   167  	return &ast.BlockStmt{p, []ast.Stmt{lp.emptyStmt(n)}, p}
   168  }