github.com/pingcap/failpoint@v0.0.0-20240412033321-fd0796e60f86/code/restorer_test.go (about)

     1  // Copyright 2021 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package code_test
    16  
    17  import (
    18  	"io/ioutil"
    19  	"os"
    20  	"path/filepath"
    21  	"strings"
    22  	"testing"
    23  
    24  	"github.com/pingcap/failpoint/code"
    25  	"github.com/stretchr/testify/require"
    26  )
    27  
    28  func TestRestore(t *testing.T) {
    29  	restorer := code.NewRestorer("not-exists-path")
    30  	err := restorer.Restore()
    31  	require.EqualError(t, err, `lstat not-exists-path: no such file or directory`)
    32  }
    33  
    34  func TestRestoreModification(t *testing.T) {
    35  	var cases = []struct {
    36  		filepath string
    37  		original string
    38  		modified string
    39  		expected string
    40  	}{
    41  		{
    42  			filepath: "modified-test.go",
    43  			original: `
    44  package rewriter_test
    45  
    46  import (
    47  	"fmt"
    48  
    49  	"github.com/pingcap/failpoint"
    50  )
    51  
    52  func unittest() {
    53  	failpoint.Inject("failpoint-name", func(val failpoint.Value) {
    54  		fmt.Println("unit-test", val)
    55  	})
    56  }
    57  `,
    58  			modified: `
    59  package rewriter_test
    60  
    61  import (
    62  	"fmt"
    63  
    64  	"github.com/pingcap/failpoint"
    65  )
    66  
    67  func unittest() {
    68  	if val, _err_ := failpoint.Eval(_curpkg_("failpoint-name")); _err_ == nil {
    69  		fmt.Println("extra add line")
    70  		fmt.Println("unit-test", val)
    71  	}
    72  }
    73  `,
    74  			expected: `
    75  package rewriter_test
    76  
    77  import (
    78  	"fmt"
    79  
    80  	"github.com/pingcap/failpoint"
    81  )
    82  
    83  func unittest() {
    84  	failpoint.Inject("failpoint-name", func(val failpoint.Value) {
    85  		fmt.Println("extra add line")
    86  		fmt.Println("unit-test", val)
    87  	})
    88  }
    89  `,
    90  		},
    91  
    92  		{
    93  			filepath: "modified-test-2.go",
    94  			original: `
    95  package rewriter_test
    96  
    97  import (
    98  	"fmt"
    99  
   100  	"github.com/pingcap/failpoint"
   101  )
   102  
   103  func unittest() {
   104  	failpoint.Inject("failpoint-name", func(val failpoint.Value) {
   105  		fmt.Println("unit-test", val)
   106  	})
   107  }
   108  `,
   109  			modified: `
   110  package rewriter_test
   111  
   112  import (
   113  	"fmt"
   114  
   115  	"github.com/pingcap/failpoint"
   116  )
   117  
   118  func unittest() {
   119  	if val, _err_ := failpoint.Eval(_curpkg_("failpoint-name")); _err_ == nil {
   120  		fmt.Println("extra add line")
   121  		fmt.Println("unit-test", val)
   122  	}
   123  	fmt.Println("extra add line2")
   124  }
   125  `,
   126  			expected: `
   127  package rewriter_test
   128  
   129  import (
   130  	"fmt"
   131  
   132  	"github.com/pingcap/failpoint"
   133  )
   134  
   135  func unittest() {
   136  	failpoint.Inject("failpoint-name", func(val failpoint.Value) {
   137  		fmt.Println("extra add line")
   138  		fmt.Println("unit-test", val)
   139  	})
   140  	fmt.Println("extra add line2")
   141  }
   142  `,
   143  		},
   144  
   145  		{
   146  			filepath: "modified-test-3.go",
   147  			original: `
   148  package rewriter_test
   149  
   150  import (
   151  	"fmt"
   152  
   153  	"github.com/pingcap/failpoint"
   154  )
   155  
   156  func unittest() {
   157  	failpoint.Inject("failpoint-name", func(val failpoint.Value) {
   158  		fmt.Println("unit-test", val)
   159  	})
   160  }
   161  `,
   162  			modified: `
   163  package rewriter_test
   164  
   165  import (
   166  	"fmt"
   167  
   168  	"github.com/pingcap/failpoint"
   169  )
   170  
   171  func unittest() {
   172  	if val, _err_ := failpoint.Eval(_curpkg_("failpoint-name-extra-part")); _err_ == nil {
   173  		fmt.Println("extra add line")
   174  		fmt.Println("unit-test", val)
   175  	}
   176  	fmt.Println("extra add line2")
   177  }
   178  `,
   179  			expected: `
   180  package rewriter_test
   181  
   182  import (
   183  	"fmt"
   184  
   185  	"github.com/pingcap/failpoint"
   186  )
   187  
   188  func unittest() {
   189  	failpoint.Inject("failpoint-name-extra-part", func(val failpoint.Value) {
   190  		fmt.Println("extra add line")
   191  		fmt.Println("unit-test", val)
   192  	})
   193  	fmt.Println("extra add line2")
   194  }
   195  `,
   196  		},
   197  	}
   198  
   199  	// Create temp files
   200  	err := os.MkdirAll(restorePath, 0755)
   201  	require.NoError(t, err)
   202  	for _, cs := range cases {
   203  		original := filepath.Join(restorePath, cs.filepath)
   204  		err := ioutil.WriteFile(original, []byte(cs.original), 0644)
   205  		require.NoError(t, err)
   206  	}
   207  
   208  	// Clean all temp files
   209  	defer func() {
   210  		err := os.RemoveAll(restorePath)
   211  		require.NoError(t, err)
   212  	}()
   213  
   214  	rewriter := code.NewRewriter(restorePath)
   215  	err = rewriter.Rewrite()
   216  	require.NoError(t, err)
   217  
   218  	for _, cs := range cases {
   219  		modified := filepath.Join(restorePath, cs.filepath)
   220  		err := ioutil.WriteFile(modified, []byte(cs.modified), 0644)
   221  		require.NoError(t, err)
   222  	}
   223  
   224  	// Restore workspace
   225  	restorer := code.NewRestorer(restorePath)
   226  	err = restorer.Restore()
   227  	require.NoError(t, err)
   228  
   229  	for _, cs := range cases {
   230  		expected := filepath.Join(restorePath, cs.filepath)
   231  		content, err := ioutil.ReadFile(expected)
   232  		require.NoError(t, err)
   233  		require.Equalf(t, strings.TrimSpace(cs.expected), strings.TrimSpace(string(content)), "%v", cs.filepath)
   234  	}
   235  }
   236  
   237  func TestRestoreModificationBad(t *testing.T) {
   238  	var cases = []struct {
   239  		filepath string
   240  		original string
   241  		modified string
   242  	}{
   243  		{
   244  			filepath: "bad-modification-test.go",
   245  			original: `
   246  package rewriter_test
   247  
   248  import (
   249  	"fmt"
   250  
   251  	"github.com/pingcap/failpoint"
   252  )
   253  
   254  func unittest() {
   255  	failpoint.Inject("failpoint-name", func(val failpoint.Value) {
   256  		fmt.Println("unit-test", val)
   257  	})
   258  }
   259  `,
   260  			modified: `
   261  package rewriter_test
   262  
   263  import (
   264  	"fmt"
   265  
   266  	"github.com/pingcap/failpoint"
   267  )
   268  
   269  func unittest() {
   270  	if val, _err_ := failpoint.EvalContext(nil, _curpkg_("failpoint-name-extra-part")); _err_ == nil {
   271  		fmt.Println("extra add line")
   272  		fmt.Println("unit-test", val)
   273  	}
   274  }
   275  `,
   276  		},
   277  	}
   278  
   279  	// Create temp files
   280  	err := os.MkdirAll(restorePath, 0755)
   281  	require.NoError(t, err)
   282  	for _, cs := range cases {
   283  		original := filepath.Join(restorePath, cs.filepath)
   284  		err := ioutil.WriteFile(original, []byte(cs.original), 0644)
   285  		require.NoError(t, err)
   286  	}
   287  
   288  	// Clean all temp files
   289  	defer func() {
   290  		err := os.RemoveAll(restorePath)
   291  		require.NoError(t, err)
   292  	}()
   293  
   294  	rewriter := code.NewRewriter(restorePath)
   295  	err = rewriter.Rewrite()
   296  	require.NoError(t, err)
   297  
   298  	for _, cs := range cases {
   299  		modified := filepath.Join(restorePath, cs.filepath)
   300  		err := ioutil.WriteFile(modified, []byte(cs.modified), 0644)
   301  		require.NoError(t, err)
   302  	}
   303  
   304  	restorer := code.NewRestorer(restorePath)
   305  	err = restorer.Restore()
   306  	require.Error(t, err)
   307  	require.Regexp(t, `cannot merge modifications back automatically.*`, err.Error())
   308  }