github.com/bgentry/go@v0.0.0-20150121062915-6cf5a733d54d/src/cmd/fix/main_test.go

     1  // Copyright 2011 The Go Authors.  All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"go/ast"
     9  	"go/parser"
    10  	"strings"
    11  	"testing"
    12  )
    13  
    14  type testCase struct {
    15  	Name string
    16  	Fn   func(*ast.File) bool
    17  	In   string
    18  	Out  string
    19  }
    20  
    21  var testCases []testCase
    22  
    23  func addTestCases(t []testCase, fn func(*ast.File) bool) {
    24  	// Fill in fn to avoid repetition in definitions.
    25  	if fn != nil {
    26  		for i := range t {
    27  			if t[i].Fn == nil {
    28  				t[i].Fn = fn
    29  			}
    30  		}
    31  	}
    32  	testCases = append(testCases, t...)
    33  }
    34  
    35  func fnop(*ast.File) bool { return false }
    36  
    37  func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) {
    38  	file, err := parser.ParseFile(fset, desc, in, parserMode)
    39  	if err != nil {
    40  		t.Errorf("%s: parsing: %v", desc, err)
    41  		return
    42  	}
    43  
    44  	outb, err := gofmtFile(file)
    45  	if err != nil {
    46  		t.Errorf("%s: printing: %v", desc, err)
    47  		return
    48  	}
    49  	if s := string(outb); in != s && mustBeGofmt {
    50  		t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
    51  			desc, desc, in, desc, s)
    52  		tdiff(t, in, s)
    53  		return
    54  	}
    55  
    56  	if fn == nil {
    57  		for _, fix := range fixes {
    58  			if fix.f(file) {
    59  				fixed = true
    60  			}
    61  		}
    62  	} else {
    63  		fixed = fn(file)
    64  	}
    65  
    66  	outb, err = gofmtFile(file)
    67  	if err != nil {
    68  		t.Errorf("%s: printing: %v", desc, err)
    69  		return
    70  	}
    71  
    72  	return string(outb), fixed, true
    73  }
    74  
    75  func TestRewrite(t *testing.T) {
    76  	for _, tt := range testCases {
    77  		// Apply fix: should get tt.Out.
    78  		out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true)
    79  		if !ok {
    80  			continue
    81  		}
    82  
    83  		// reformat to get printing right
    84  		out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false)
    85  		if !ok {
    86  			continue
    87  		}
    88  
    89  		if out != tt.Out {
    90  			t.Errorf("%s: incorrect output.\n", tt.Name)
    91  			if !strings.HasPrefix(tt.Name, "testdata/") {
    92  				t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out)
    93  			}
    94  			tdiff(t, out, tt.Out)
    95  			continue
    96  		}
    97  
    98  		if changed := out != tt.In; changed != fixed {
    99  			t.Errorf("%s: changed=%v != fixed=%v", tt.Name, changed, fixed)
   100  			continue
   101  		}
   102  
   103  		// Should not change if run again.
   104  		out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true)
   105  		if !ok {
   106  			continue
   107  		}
   108  
   109  		if fixed2 {
   110  			t.Errorf("%s: applied fixes during second round", tt.Name)
   111  			continue
   112  		}
   113  
   114  		if out2 != out {
   115  			t.Errorf("%s: changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s",
   116  				tt.Name, out, out2)
   117  			tdiff(t, out, out2)
   118  		}
   119  	}
   120  }
   121  
   122  func tdiff(t *testing.T, a, b string) {
   123  	data, err := diff([]byte(a), []byte(b))
   124  	if err != nil {
   125  		t.Error(err)
   126  		return
   127  	}
   128  	t.Error(string(data))
   129  }