github.com/Deiz/tracegen@v0.2.2/cmd/tracegen/update_test.go (about)

     1  package main
     2  
     3  import (
     4  	"os"
     5  	"path/filepath"
     6  	"testing"
     7  
     8  	"github.com/Deiz/tracegen"
     9  )
    10  
    11  const gomod = `module test
    12  
    13  go 1.17
    14  
    15  require github.com/opentracing/opentracing-go v1.2.0`
    16  
    17  const input0 = `package main
    18  
    19  import "context"
    20  
    21  func Foo(ctx context.Context) {}
    22  `
    23  
    24  const output0 = `package main
    25  
    26  import (
    27  	"context"
    28  
    29  	"github.com/opentracing/opentracing-go"
    30  )
    31  
    32  func Foo(ctx context.Context) {
    33  	span, ctx := opentracing.StartSpanFromContext(ctx, "Foo")
    34  	defer span.Finish()
    35  }
    36  `
    37  
    38  const input1 = output0
    39  const output1 = input0
    40  
    41  const input2 = `package main
    42  
    43  import (
    44  	"context"
    45  
    46  	"github.com/opentracing/opentracing-go"
    47  )
    48  
    49  //trace:skip
    50  func Foo(ctx context.Context) {
    51  	span, ctx := opentracing.StartSpanFromContext(ctx, "Foo")
    52  	defer span.Finish()
    53  }
    54  `
    55  
    56  const output2 = `package main
    57  
    58  import "context"
    59  
    60  //trace:skip
    61  func Foo(ctx context.Context) {}
    62  `
    63  
    64  func check(t *testing.T, err error) {
    65  	t.Helper()
    66  
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  }
    71  
    72  func writeModule(t *testing.T, sample string) (path string) {
    73  	t.Helper()
    74  
    75  	dir, err := os.MkdirTemp("", "")
    76  	check(t, err)
    77  
    78  	path = filepath.Join(dir, "sample.go")
    79  	err = os.WriteFile(path, []byte(sample), 0644)
    80  	check(t, err)
    81  
    82  	err = os.WriteFile(filepath.Join(dir, "go.mod"), []byte(gomod), 0644)
    83  	check(t, err)
    84  
    85  	return path
    86  }
    87  
    88  func TestUpdater(t *testing.T) {
    89  	tests := map[string]struct {
    90  		input    string
    91  		expected string
    92  		settings tracegen.Settings
    93  	}{
    94  		"add span":           {input0, output0, tracegen.Settings{}},
    95  		"remove span":        {input1, output1, tracegen.Settings{Methods: true}},
    96  		"remove span (skip)": {input2, output2, tracegen.Settings{}},
    97  	}
    98  
    99  	for name, test := range tests {
   100  		t.Run(name, func(t *testing.T) {
   101  			path := writeModule(t, test.input)
   102  
   103  			err := os.Chdir(filepath.Dir(path))
   104  			check(t, err)
   105  
   106  			err = tracegen.Process(
   107  				test.settings,
   108  				[]string{"."},
   109  				update,
   110  				getResolver,
   111  			)
   112  			check(t, err)
   113  
   114  			data, err := os.ReadFile(path)
   115  			check(t, err)
   116  
   117  			if string(data) != test.expected {
   118  				t.Fatalf("mismatched output in %s:\ngot:\n%s\nexpected:\n%s", path, string(data), test.expected)
   119  			}
   120  		})
   121  	}
   122  }