github.com/pingcap/failpoint@v0.0.0-20240412033321-fd0796e60f86/code/expr_rewriter.go (about) 1 // Copyright 2019 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package code 16 17 import ( 18 "fmt" 19 "go/ast" 20 "go/token" 21 "strings" 22 ) 23 24 type exprRewriter func(rewriter *Rewriter, call *ast.CallExpr) (rewritten bool, result ast.Stmt, err error) 25 26 var exprRewriters = map[string]exprRewriter{ 27 "Inject": (*Rewriter).rewriteInject, 28 "InjectContext": (*Rewriter).rewriteInjectContext, 29 "Break": (*Rewriter).rewriteBreak, 30 "Continue": (*Rewriter).rewriteContinue, 31 "Label": (*Rewriter).rewriteLabel, 32 "Goto": (*Rewriter).rewriteGoto, 33 "Fallthrough": (*Rewriter).rewriteFallthrough, 34 "Return": (*Rewriter).rewriteReturn, 35 } 36 37 func (r *Rewriter) rewriteInject(call *ast.CallExpr) (bool, ast.Stmt, error) { 38 if len(call.Args) != 2 { 39 return false, nil, fmt.Errorf("failpoint.Inject: expect 2 arguments but got %v in %s", len(call.Args), r.pos(call.Pos())) 40 } 41 // First argument need not to be a string literal, any string type stuff is ok. 42 // Type safe is convinced by compiler. 43 fpname, ok := call.Args[0].(ast.Expr) 44 if !ok { 45 return false, nil, fmt.Errorf("failpoint.Inject: first argument expect a valid expression in %s", r.pos(call.Pos())) 46 } 47 48 // failpoint.Inject("failpoint-name", nil) 49 ident, ok := call.Args[1].(*ast.Ident) 50 isNilFunc := ok && ident.Name == "nil" 51 52 // failpoint.Inject("failpoint-name", func(){...}) 53 // failpoint.Inject("failpoint-name", func(val failpoint.Value){...}) 54 fpbody, isFuncLit := call.Args[1].(*ast.FuncLit) 55 if !isNilFunc && !isFuncLit { 56 return false, nil, fmt.Errorf("failpoint.Inject: second argument expect closure in %s", r.pos(call.Pos())) 57 } 58 if isFuncLit { 59 if len(fpbody.Type.Params.List) > 1 { 60 return false, nil, fmt.Errorf("failpoint.Inject: closure signature illegal in %s", r.pos(call.Pos())) 61 } 62 63 if len(fpbody.Type.Params.List) == 1 && len(fpbody.Type.Params.List[0].Names) > 1 { 64 return false, nil, fmt.Errorf("failpoint.Inject: closure signature illegal in %s", r.pos(call.Pos())) 65 } 66 } 67 68 fpnameExtendCall := &ast.CallExpr{ 69 Fun: ast.NewIdent(ExtendPkgName), 70 Args: []ast.Expr{fpname}, 71 } 72 73 checkCall := &ast.CallExpr{ 74 Fun: &ast.SelectorExpr{ 75 X: &ast.Ident{NamePos: call.Pos(), Name: r.failpointName}, 76 Sel: ast.NewIdent(evalFunction), 77 }, 78 Args: []ast.Expr{fpnameExtendCall}, 79 } 80 if isNilFunc || len(fpbody.Body.List) < 1 { 81 return true, &ast.ExprStmt{X: checkCall}, nil 82 } 83 84 ifBody := &ast.BlockStmt{ 85 Lbrace: call.Pos(), 86 List: fpbody.Body.List, 87 Rbrace: call.End(), 88 } 89 90 // closure signature: 91 // func(val failpoint.Value) {...} 92 // func() {...} 93 var argName *ast.Ident 94 if len(fpbody.Type.Params.List) > 0 { 95 arg := fpbody.Type.Params.List[0] 96 selector, ok := arg.Type.(*ast.SelectorExpr) 97 if !ok || selector.Sel.Name != "Value" || selector.X.(*ast.Ident).Name != r.failpointName { 98 return false, nil, fmt.Errorf("failpoint.Inject: invalid signature in %s", r.pos(call.Pos())) 99 } 100 argName = arg.Names[0] 101 } else { 102 argName = ast.NewIdent("_") 103 } 104 105 err := ast.NewIdent("_err_") 106 init := &ast.AssignStmt{ 107 Lhs: []ast.Expr{argName, err}, 108 Rhs: []ast.Expr{checkCall}, 109 Tok: token.DEFINE, 110 } 111 112 cond := &ast.BinaryExpr{ 113 X: err, 114 Op: token.EQL, 115 Y: ast.NewIdent("nil"), 116 } 117 stmt := &ast.IfStmt{ 118 If: call.Pos(), 119 Init: init, 120 Cond: cond, 121 Body: ifBody, 122 } 123 return true, stmt, nil 124 } 125 126 func (r *Rewriter) rewriteInjectContext(call *ast.CallExpr) (bool, ast.Stmt, error) { 127 if len(call.Args) != 3 { 128 return false, nil, fmt.Errorf("failpoint.InjectContext: expect 3 arguments but got %v in %s", len(call.Args), r.pos(call.Pos())) 129 } 130 131 // Second argument need not to be a identifier, any context type token (e.g. selector) is OK. 132 // Type safe is convinced by compiler. 133 ctxname, ok := call.Args[0].(ast.Expr) 134 if !ok { 135 return false, nil, fmt.Errorf("failpoint.InjectContext: first argument expect context in %s, which must be an expression", r.pos(call.Pos())) 136 } 137 // Second argument need not to be a string literal, any string type stuff is ok. 138 // Type safe is convinced by compiler. 139 fpname, ok := call.Args[1].(ast.Expr) 140 if !ok { 141 return false, nil, fmt.Errorf("failpoint.InjectContext: second argument expect a valid expression in %s", r.pos(call.Pos())) 142 } 143 144 // failpoint.InjectContext("failpoint-name", ctx, nil) 145 ident, ok := call.Args[2].(*ast.Ident) 146 isNilFunc := ok && ident.Name == "nil" 147 148 // failpoint.InjectContext("failpoint-name", ctx, func(){...}) 149 // failpoint.InjectContext("failpoint-name", ctx, func(val failpoint.Value){...}) 150 fpbody, isFuncLit := call.Args[2].(*ast.FuncLit) 151 if !isNilFunc && !isFuncLit { 152 return false, nil, fmt.Errorf("failpoint.InjectContext: third argument expect closure in %s", r.pos(call.Pos())) 153 } 154 155 if isFuncLit { 156 if len(fpbody.Type.Params.List) > 1 { 157 return false, nil, fmt.Errorf("failpoint.InjectContext: closure signature illegal in %s", r.pos(call.Pos())) 158 } 159 160 if len(fpbody.Type.Params.List) == 1 && len(fpbody.Type.Params.List[0].Names) > 1 { 161 return false, nil, fmt.Errorf("failpoint.InjectContext: closure signature illegal in %s", r.pos(call.Pos())) 162 } 163 } 164 165 fpnameExtendCall := &ast.CallExpr{ 166 Fun: ast.NewIdent(ExtendPkgName), 167 Args: []ast.Expr{fpname}, 168 } 169 170 checkCall := &ast.CallExpr{ 171 Fun: &ast.SelectorExpr{ 172 X: &ast.Ident{NamePos: call.Pos(), Name: r.failpointName}, 173 Sel: ast.NewIdent(evalCtxFunction), 174 }, 175 Args: []ast.Expr{ctxname, fpnameExtendCall}, 176 } 177 if isNilFunc || len(fpbody.Body.List) < 1 { 178 return true, &ast.ExprStmt{X: checkCall}, nil 179 } 180 181 ifBody := &ast.BlockStmt{ 182 Lbrace: call.Pos(), 183 List: fpbody.Body.List, 184 Rbrace: call.End(), 185 } 186 187 // closure signature: 188 // func(val failpoint.Value) {...} 189 // func() {...} 190 var argName *ast.Ident 191 if len(fpbody.Type.Params.List) > 0 { 192 arg := fpbody.Type.Params.List[0] 193 selector, ok := arg.Type.(*ast.SelectorExpr) 194 if !ok || selector.Sel.Name != "Value" || selector.X.(*ast.Ident).Name != r.failpointName { 195 return false, nil, fmt.Errorf("failpoint.InjectContext: invalid signature in %s", r.pos(call.Pos())) 196 } 197 argName = arg.Names[0] 198 } else { 199 argName = ast.NewIdent("_") 200 } 201 202 err := ast.NewIdent("_err_") 203 init := &ast.AssignStmt{ 204 Lhs: []ast.Expr{argName, err}, 205 Rhs: []ast.Expr{checkCall}, 206 Tok: token.DEFINE, 207 } 208 209 cond := &ast.BinaryExpr{ 210 X: err, 211 Op: token.EQL, 212 Y: ast.NewIdent("nil"), 213 } 214 stmt := &ast.IfStmt{ 215 If: call.Pos(), 216 Init: init, 217 Cond: cond, 218 Body: ifBody, 219 } 220 return true, stmt, nil 221 } 222 223 func (r *Rewriter) rewriteBreak(call *ast.CallExpr) (bool, ast.Stmt, error) { 224 if count := len(call.Args); count > 1 { 225 return false, nil, fmt.Errorf("failpoint.Break expect 1 or 0 arguments, but got %v in %s", count, r.pos(call.Pos())) 226 } 227 var stmt *ast.BranchStmt 228 if len(call.Args) > 0 { 229 label := call.Args[0].(*ast.BasicLit).Value 230 label = strings.Trim(label, "`\"") 231 stmt = &ast.BranchStmt{ 232 TokPos: call.Pos(), 233 Tok: token.BREAK, 234 Label: ast.NewIdent(label), 235 } 236 } else { 237 stmt = &ast.BranchStmt{ 238 TokPos: call.Pos(), 239 Tok: token.BREAK, 240 } 241 } 242 return true, stmt, nil 243 } 244 245 func (r *Rewriter) rewriteContinue(call *ast.CallExpr) (bool, ast.Stmt, error) { 246 if count := len(call.Args); count > 1 { 247 return false, nil, fmt.Errorf("failpoint.Continue expect 1 or 0 arguments, but got %v in %s", count, r.pos(call.Pos())) 248 } 249 var stmt *ast.BranchStmt 250 if len(call.Args) > 0 { 251 label := call.Args[0].(*ast.BasicLit).Value 252 label = strings.Trim(label, "`\"") 253 stmt = &ast.BranchStmt{ 254 TokPos: call.Pos(), 255 Tok: token.CONTINUE, 256 Label: ast.NewIdent(label), 257 } 258 } else { 259 stmt = &ast.BranchStmt{ 260 TokPos: call.Pos(), 261 Tok: token.CONTINUE, 262 } 263 } 264 return true, stmt, nil 265 } 266 267 func (r *Rewriter) rewriteLabel(call *ast.CallExpr) (bool, ast.Stmt, error) { 268 if count := len(call.Args); count != 1 { 269 return false, nil, fmt.Errorf("failpoint.Label expect 1 arguments, but got %v in %s", count, r.pos(call.Pos())) 270 } 271 label := call.Args[0].(*ast.BasicLit).Value 272 label = strings.Trim(label, "`\"") 273 stmt := &ast.LabeledStmt{ 274 Colon: call.Pos(), 275 Label: ast.NewIdent(label + labelSuffix), // It's a trick here 276 } 277 return true, stmt, nil 278 } 279 280 func (r *Rewriter) rewriteGoto(call *ast.CallExpr) (bool, ast.Stmt, error) { 281 if count := len(call.Args); count != 1 { 282 return false, nil, fmt.Errorf("failpoint.Goto expect 1 arguments, but got %v in %s", count, r.pos(call.Pos())) 283 } 284 label := call.Args[0].(*ast.BasicLit).Value 285 label = strings.Trim(label, "`\"") 286 stmt := &ast.BranchStmt{ 287 TokPos: call.Pos(), 288 Tok: token.GOTO, 289 Label: ast.NewIdent(label), 290 } 291 return true, stmt, nil 292 } 293 294 func (r *Rewriter) rewriteFallthrough(call *ast.CallExpr) (bool, ast.Stmt, error) { 295 stmt := &ast.BranchStmt{ 296 TokPos: call.Pos(), 297 Tok: token.FALLTHROUGH, 298 } 299 return true, stmt, nil 300 } 301 302 func (r *Rewriter) rewriteReturn(call *ast.CallExpr) (bool, ast.Stmt, error) { 303 stmt := &ast.ReturnStmt{ 304 Return: call.Pos(), 305 Results: call.Args, 306 } 307 return true, stmt, nil 308 }