github.com/pingcap/failpoint@v0.0.0-20240412033321-fd0796e60f86/code/restorer_test.go (about) 1 // Copyright 2021 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package code_test 16 17 import ( 18 "io/ioutil" 19 "os" 20 "path/filepath" 21 "strings" 22 "testing" 23 24 "github.com/pingcap/failpoint/code" 25 "github.com/stretchr/testify/require" 26 ) 27 28 func TestRestore(t *testing.T) { 29 restorer := code.NewRestorer("not-exists-path") 30 err := restorer.Restore() 31 require.EqualError(t, err, `lstat not-exists-path: no such file or directory`) 32 } 33 34 func TestRestoreModification(t *testing.T) { 35 var cases = []struct { 36 filepath string 37 original string 38 modified string 39 expected string 40 }{ 41 { 42 filepath: "modified-test.go", 43 original: ` 44 package rewriter_test 45 46 import ( 47 "fmt" 48 49 "github.com/pingcap/failpoint" 50 ) 51 52 func unittest() { 53 failpoint.Inject("failpoint-name", func(val failpoint.Value) { 54 fmt.Println("unit-test", val) 55 }) 56 } 57 `, 58 modified: ` 59 package rewriter_test 60 61 import ( 62 "fmt" 63 64 "github.com/pingcap/failpoint" 65 ) 66 67 func unittest() { 68 if val, _err_ := failpoint.Eval(_curpkg_("failpoint-name")); _err_ == nil { 69 fmt.Println("extra add line") 70 fmt.Println("unit-test", val) 71 } 72 } 73 `, 74 expected: ` 75 package rewriter_test 76 77 import ( 78 "fmt" 79 80 "github.com/pingcap/failpoint" 81 ) 82 83 func unittest() { 84 failpoint.Inject("failpoint-name", func(val failpoint.Value) { 85 fmt.Println("extra add line") 86 fmt.Println("unit-test", val) 87 }) 88 } 89 `, 90 }, 91 92 { 93 filepath: "modified-test-2.go", 94 original: ` 95 package rewriter_test 96 97 import ( 98 "fmt" 99 100 "github.com/pingcap/failpoint" 101 ) 102 103 func unittest() { 104 failpoint.Inject("failpoint-name", func(val failpoint.Value) { 105 fmt.Println("unit-test", val) 106 }) 107 } 108 `, 109 modified: ` 110 package rewriter_test 111 112 import ( 113 "fmt" 114 115 "github.com/pingcap/failpoint" 116 ) 117 118 func unittest() { 119 if val, _err_ := failpoint.Eval(_curpkg_("failpoint-name")); _err_ == nil { 120 fmt.Println("extra add line") 121 fmt.Println("unit-test", val) 122 } 123 fmt.Println("extra add line2") 124 } 125 `, 126 expected: ` 127 package rewriter_test 128 129 import ( 130 "fmt" 131 132 "github.com/pingcap/failpoint" 133 ) 134 135 func unittest() { 136 failpoint.Inject("failpoint-name", func(val failpoint.Value) { 137 fmt.Println("extra add line") 138 fmt.Println("unit-test", val) 139 }) 140 fmt.Println("extra add line2") 141 } 142 `, 143 }, 144 145 { 146 filepath: "modified-test-3.go", 147 original: ` 148 package rewriter_test 149 150 import ( 151 "fmt" 152 153 "github.com/pingcap/failpoint" 154 ) 155 156 func unittest() { 157 failpoint.Inject("failpoint-name", func(val failpoint.Value) { 158 fmt.Println("unit-test", val) 159 }) 160 } 161 `, 162 modified: ` 163 package rewriter_test 164 165 import ( 166 "fmt" 167 168 "github.com/pingcap/failpoint" 169 ) 170 171 func unittest() { 172 if val, _err_ := failpoint.Eval(_curpkg_("failpoint-name-extra-part")); _err_ == nil { 173 fmt.Println("extra add line") 174 fmt.Println("unit-test", val) 175 } 176 fmt.Println("extra add line2") 177 } 178 `, 179 expected: ` 180 package rewriter_test 181 182 import ( 183 "fmt" 184 185 "github.com/pingcap/failpoint" 186 ) 187 188 func unittest() { 189 failpoint.Inject("failpoint-name-extra-part", func(val failpoint.Value) { 190 fmt.Println("extra add line") 191 fmt.Println("unit-test", val) 192 }) 193 fmt.Println("extra add line2") 194 } 195 `, 196 }, 197 } 198 199 // Create temp files 200 err := os.MkdirAll(restorePath, 0755) 201 require.NoError(t, err) 202 for _, cs := range cases { 203 original := filepath.Join(restorePath, cs.filepath) 204 err := ioutil.WriteFile(original, []byte(cs.original), 0644) 205 require.NoError(t, err) 206 } 207 208 // Clean all temp files 209 defer func() { 210 err := os.RemoveAll(restorePath) 211 require.NoError(t, err) 212 }() 213 214 rewriter := code.NewRewriter(restorePath) 215 err = rewriter.Rewrite() 216 require.NoError(t, err) 217 218 for _, cs := range cases { 219 modified := filepath.Join(restorePath, cs.filepath) 220 err := ioutil.WriteFile(modified, []byte(cs.modified), 0644) 221 require.NoError(t, err) 222 } 223 224 // Restore workspace 225 restorer := code.NewRestorer(restorePath) 226 err = restorer.Restore() 227 require.NoError(t, err) 228 229 for _, cs := range cases { 230 expected := filepath.Join(restorePath, cs.filepath) 231 content, err := ioutil.ReadFile(expected) 232 require.NoError(t, err) 233 require.Equalf(t, strings.TrimSpace(cs.expected), strings.TrimSpace(string(content)), "%v", cs.filepath) 234 } 235 } 236 237 func TestRestoreModificationBad(t *testing.T) { 238 var cases = []struct { 239 filepath string 240 original string 241 modified string 242 }{ 243 { 244 filepath: "bad-modification-test.go", 245 original: ` 246 package rewriter_test 247 248 import ( 249 "fmt" 250 251 "github.com/pingcap/failpoint" 252 ) 253 254 func unittest() { 255 failpoint.Inject("failpoint-name", func(val failpoint.Value) { 256 fmt.Println("unit-test", val) 257 }) 258 } 259 `, 260 modified: ` 261 package rewriter_test 262 263 import ( 264 "fmt" 265 266 "github.com/pingcap/failpoint" 267 ) 268 269 func unittest() { 270 if val, _err_ := failpoint.EvalContext(nil, _curpkg_("failpoint-name-extra-part")); _err_ == nil { 271 fmt.Println("extra add line") 272 fmt.Println("unit-test", val) 273 } 274 } 275 `, 276 }, 277 } 278 279 // Create temp files 280 err := os.MkdirAll(restorePath, 0755) 281 require.NoError(t, err) 282 for _, cs := range cases { 283 original := filepath.Join(restorePath, cs.filepath) 284 err := ioutil.WriteFile(original, []byte(cs.original), 0644) 285 require.NoError(t, err) 286 } 287 288 // Clean all temp files 289 defer func() { 290 err := os.RemoveAll(restorePath) 291 require.NoError(t, err) 292 }() 293 294 rewriter := code.NewRewriter(restorePath) 295 err = rewriter.Rewrite() 296 require.NoError(t, err) 297 298 for _, cs := range cases { 299 modified := filepath.Join(restorePath, cs.filepath) 300 err := ioutil.WriteFile(modified, []byte(cs.modified), 0644) 301 require.NoError(t, err) 302 } 303 304 restorer := code.NewRestorer(restorePath) 305 err = restorer.Restore() 306 require.Error(t, err) 307 require.Regexp(t, `cannot merge modifications back automatically.*`, err.Error()) 308 }