github.com/pingcap/failpoint@v0.0.0-20240412033321-fd0796e60f86/code/expr_rewriter.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package code
    16  
    17  import (
    18  	"fmt"
    19  	"go/ast"
    20  	"go/token"
    21  	"strings"
    22  )
    23  
// exprRewriter is the signature shared by every marker-call rewriter. Given the
// Rewriter and the marker call expression, it reports whether the call was
// rewritten, the statement that replaces the call, and any error found while
// validating the call.
type exprRewriter func(rewriter *Rewriter, call *ast.CallExpr) (rewritten bool, result ast.Stmt, err error)
    25  
// exprRewriters maps a marker function name (for example "Inject") to the
// Rewriter method expression that rewrites calls to that marker.
var exprRewriters = map[string]exprRewriter{
	"Inject":        (*Rewriter).rewriteInject,
	"InjectContext": (*Rewriter).rewriteInjectContext,
	"Break":         (*Rewriter).rewriteBreak,
	"Continue":      (*Rewriter).rewriteContinue,
	"Label":         (*Rewriter).rewriteLabel,
	"Goto":          (*Rewriter).rewriteGoto,
	"Fallthrough":   (*Rewriter).rewriteFallthrough,
	"Return":        (*Rewriter).rewriteReturn,
}
    36  
    37  func (r *Rewriter) rewriteInject(call *ast.CallExpr) (bool, ast.Stmt, error) {
    38  	if len(call.Args) != 2 {
    39  		return false, nil, fmt.Errorf("failpoint.Inject: expect 2 arguments but got %v in %s", len(call.Args), r.pos(call.Pos()))
    40  	}
    41  	// First argument need not to be a string literal, any string type stuff is ok.
    42  	// Type safe is convinced by compiler.
    43  	fpname, ok := call.Args[0].(ast.Expr)
    44  	if !ok {
    45  		return false, nil, fmt.Errorf("failpoint.Inject: first argument expect a valid expression in %s", r.pos(call.Pos()))
    46  	}
    47  
    48  	// failpoint.Inject("failpoint-name", nil)
    49  	ident, ok := call.Args[1].(*ast.Ident)
    50  	isNilFunc := ok && ident.Name == "nil"
    51  
    52  	// failpoint.Inject("failpoint-name", func(){...})
    53  	// failpoint.Inject("failpoint-name", func(val failpoint.Value){...})
    54  	fpbody, isFuncLit := call.Args[1].(*ast.FuncLit)
    55  	if !isNilFunc && !isFuncLit {
    56  		return false, nil, fmt.Errorf("failpoint.Inject: second argument expect closure in %s", r.pos(call.Pos()))
    57  	}
    58  	if isFuncLit {
    59  		if len(fpbody.Type.Params.List) > 1 {
    60  			return false, nil, fmt.Errorf("failpoint.Inject: closure signature illegal in %s", r.pos(call.Pos()))
    61  		}
    62  
    63  		if len(fpbody.Type.Params.List) == 1 && len(fpbody.Type.Params.List[0].Names) > 1 {
    64  			return false, nil, fmt.Errorf("failpoint.Inject: closure signature illegal in %s", r.pos(call.Pos()))
    65  		}
    66  	}
    67  
    68  	fpnameExtendCall := &ast.CallExpr{
    69  		Fun:  ast.NewIdent(ExtendPkgName),
    70  		Args: []ast.Expr{fpname},
    71  	}
    72  
    73  	checkCall := &ast.CallExpr{
    74  		Fun: &ast.SelectorExpr{
    75  			X:   &ast.Ident{NamePos: call.Pos(), Name: r.failpointName},
    76  			Sel: ast.NewIdent(evalFunction),
    77  		},
    78  		Args: []ast.Expr{fpnameExtendCall},
    79  	}
    80  	if isNilFunc || len(fpbody.Body.List) < 1 {
    81  		return true, &ast.ExprStmt{X: checkCall}, nil
    82  	}
    83  
    84  	ifBody := &ast.BlockStmt{
    85  		Lbrace: call.Pos(),
    86  		List:   fpbody.Body.List,
    87  		Rbrace: call.End(),
    88  	}
    89  
    90  	// closure signature:
    91  	// func(val failpoint.Value) {...}
    92  	// func() {...}
    93  	var argName *ast.Ident
    94  	if len(fpbody.Type.Params.List) > 0 {
    95  		arg := fpbody.Type.Params.List[0]
    96  		selector, ok := arg.Type.(*ast.SelectorExpr)
    97  		if !ok || selector.Sel.Name != "Value" || selector.X.(*ast.Ident).Name != r.failpointName {
    98  			return false, nil, fmt.Errorf("failpoint.Inject: invalid signature in %s", r.pos(call.Pos()))
    99  		}
   100  		argName = arg.Names[0]
   101  	} else {
   102  		argName = ast.NewIdent("_")
   103  	}
   104  
   105  	err := ast.NewIdent("_err_")
   106  	init := &ast.AssignStmt{
   107  		Lhs: []ast.Expr{argName, err},
   108  		Rhs: []ast.Expr{checkCall},
   109  		Tok: token.DEFINE,
   110  	}
   111  
   112  	cond := &ast.BinaryExpr{
   113  		X:  err,
   114  		Op: token.EQL,
   115  		Y:  ast.NewIdent("nil"),
   116  	}
   117  	stmt := &ast.IfStmt{
   118  		If:   call.Pos(),
   119  		Init: init,
   120  		Cond: cond,
   121  		Body: ifBody,
   122  	}
   123  	return true, stmt, nil
   124  }
   125  
   126  func (r *Rewriter) rewriteInjectContext(call *ast.CallExpr) (bool, ast.Stmt, error) {
   127  	if len(call.Args) != 3 {
   128  		return false, nil, fmt.Errorf("failpoint.InjectContext: expect 3 arguments but got %v in %s", len(call.Args), r.pos(call.Pos()))
   129  	}
   130  
   131  	// Second argument need not to be a identifier, any context type token (e.g. selector) is OK.
   132  	// Type safe is convinced by compiler.
   133  	ctxname, ok := call.Args[0].(ast.Expr)
   134  	if !ok {
   135  		return false, nil, fmt.Errorf("failpoint.InjectContext: first argument expect context in %s, which must be an expression", r.pos(call.Pos()))
   136  	}
   137  	// Second argument need not to be a string literal, any string type stuff is ok.
   138  	// Type safe is convinced by compiler.
   139  	fpname, ok := call.Args[1].(ast.Expr)
   140  	if !ok {
   141  		return false, nil, fmt.Errorf("failpoint.InjectContext: second argument expect a valid expression in %s", r.pos(call.Pos()))
   142  	}
   143  
   144  	// failpoint.InjectContext("failpoint-name", ctx, nil)
   145  	ident, ok := call.Args[2].(*ast.Ident)
   146  	isNilFunc := ok && ident.Name == "nil"
   147  
   148  	// failpoint.InjectContext("failpoint-name", ctx, func(){...})
   149  	// failpoint.InjectContext("failpoint-name", ctx, func(val failpoint.Value){...})
   150  	fpbody, isFuncLit := call.Args[2].(*ast.FuncLit)
   151  	if !isNilFunc && !isFuncLit {
   152  		return false, nil, fmt.Errorf("failpoint.InjectContext: third argument expect closure in %s", r.pos(call.Pos()))
   153  	}
   154  
   155  	if isFuncLit {
   156  		if len(fpbody.Type.Params.List) > 1 {
   157  			return false, nil, fmt.Errorf("failpoint.InjectContext: closure signature illegal in %s", r.pos(call.Pos()))
   158  		}
   159  
   160  		if len(fpbody.Type.Params.List) == 1 && len(fpbody.Type.Params.List[0].Names) > 1 {
   161  			return false, nil, fmt.Errorf("failpoint.InjectContext: closure signature illegal in %s", r.pos(call.Pos()))
   162  		}
   163  	}
   164  
   165  	fpnameExtendCall := &ast.CallExpr{
   166  		Fun:  ast.NewIdent(ExtendPkgName),
   167  		Args: []ast.Expr{fpname},
   168  	}
   169  
   170  	checkCall := &ast.CallExpr{
   171  		Fun: &ast.SelectorExpr{
   172  			X:   &ast.Ident{NamePos: call.Pos(), Name: r.failpointName},
   173  			Sel: ast.NewIdent(evalCtxFunction),
   174  		},
   175  		Args: []ast.Expr{ctxname, fpnameExtendCall},
   176  	}
   177  	if isNilFunc || len(fpbody.Body.List) < 1 {
   178  		return true, &ast.ExprStmt{X: checkCall}, nil
   179  	}
   180  
   181  	ifBody := &ast.BlockStmt{
   182  		Lbrace: call.Pos(),
   183  		List:   fpbody.Body.List,
   184  		Rbrace: call.End(),
   185  	}
   186  
   187  	// closure signature:
   188  	// func(val failpoint.Value) {...}
   189  	// func() {...}
   190  	var argName *ast.Ident
   191  	if len(fpbody.Type.Params.List) > 0 {
   192  		arg := fpbody.Type.Params.List[0]
   193  		selector, ok := arg.Type.(*ast.SelectorExpr)
   194  		if !ok || selector.Sel.Name != "Value" || selector.X.(*ast.Ident).Name != r.failpointName {
   195  			return false, nil, fmt.Errorf("failpoint.InjectContext: invalid signature in %s", r.pos(call.Pos()))
   196  		}
   197  		argName = arg.Names[0]
   198  	} else {
   199  		argName = ast.NewIdent("_")
   200  	}
   201  
   202  	err := ast.NewIdent("_err_")
   203  	init := &ast.AssignStmt{
   204  		Lhs: []ast.Expr{argName, err},
   205  		Rhs: []ast.Expr{checkCall},
   206  		Tok: token.DEFINE,
   207  	}
   208  
   209  	cond := &ast.BinaryExpr{
   210  		X:  err,
   211  		Op: token.EQL,
   212  		Y:  ast.NewIdent("nil"),
   213  	}
   214  	stmt := &ast.IfStmt{
   215  		If:   call.Pos(),
   216  		Init: init,
   217  		Cond: cond,
   218  		Body: ifBody,
   219  	}
   220  	return true, stmt, nil
   221  }
   222  
   223  func (r *Rewriter) rewriteBreak(call *ast.CallExpr) (bool, ast.Stmt, error) {
   224  	if count := len(call.Args); count > 1 {
   225  		return false, nil, fmt.Errorf("failpoint.Break expect 1 or 0 arguments, but got %v in %s", count, r.pos(call.Pos()))
   226  	}
   227  	var stmt *ast.BranchStmt
   228  	if len(call.Args) > 0 {
   229  		label := call.Args[0].(*ast.BasicLit).Value
   230  		label = strings.Trim(label, "`\"")
   231  		stmt = &ast.BranchStmt{
   232  			TokPos: call.Pos(),
   233  			Tok:    token.BREAK,
   234  			Label:  ast.NewIdent(label),
   235  		}
   236  	} else {
   237  		stmt = &ast.BranchStmt{
   238  			TokPos: call.Pos(),
   239  			Tok:    token.BREAK,
   240  		}
   241  	}
   242  	return true, stmt, nil
   243  }
   244  
   245  func (r *Rewriter) rewriteContinue(call *ast.CallExpr) (bool, ast.Stmt, error) {
   246  	if count := len(call.Args); count > 1 {
   247  		return false, nil, fmt.Errorf("failpoint.Continue expect 1 or 0 arguments, but got %v in %s", count, r.pos(call.Pos()))
   248  	}
   249  	var stmt *ast.BranchStmt
   250  	if len(call.Args) > 0 {
   251  		label := call.Args[0].(*ast.BasicLit).Value
   252  		label = strings.Trim(label, "`\"")
   253  		stmt = &ast.BranchStmt{
   254  			TokPos: call.Pos(),
   255  			Tok:    token.CONTINUE,
   256  			Label:  ast.NewIdent(label),
   257  		}
   258  	} else {
   259  		stmt = &ast.BranchStmt{
   260  			TokPos: call.Pos(),
   261  			Tok:    token.CONTINUE,
   262  		}
   263  	}
   264  	return true, stmt, nil
   265  }
   266  
   267  func (r *Rewriter) rewriteLabel(call *ast.CallExpr) (bool, ast.Stmt, error) {
   268  	if count := len(call.Args); count != 1 {
   269  		return false, nil, fmt.Errorf("failpoint.Label expect 1 arguments, but got %v in %s", count, r.pos(call.Pos()))
   270  	}
   271  	label := call.Args[0].(*ast.BasicLit).Value
   272  	label = strings.Trim(label, "`\"")
   273  	stmt := &ast.LabeledStmt{
   274  		Colon: call.Pos(),
   275  		Label: ast.NewIdent(label + labelSuffix), // It's a trick here
   276  	}
   277  	return true, stmt, nil
   278  }
   279  
   280  func (r *Rewriter) rewriteGoto(call *ast.CallExpr) (bool, ast.Stmt, error) {
   281  	if count := len(call.Args); count != 1 {
   282  		return false, nil, fmt.Errorf("failpoint.Goto expect 1 arguments, but got %v in %s", count, r.pos(call.Pos()))
   283  	}
   284  	label := call.Args[0].(*ast.BasicLit).Value
   285  	label = strings.Trim(label, "`\"")
   286  	stmt := &ast.BranchStmt{
   287  		TokPos: call.Pos(),
   288  		Tok:    token.GOTO,
   289  		Label:  ast.NewIdent(label),
   290  	}
   291  	return true, stmt, nil
   292  }
   293  
   294  func (r *Rewriter) rewriteFallthrough(call *ast.CallExpr) (bool, ast.Stmt, error) {
   295  	stmt := &ast.BranchStmt{
   296  		TokPos: call.Pos(),
   297  		Tok:    token.FALLTHROUGH,
   298  	}
   299  	return true, stmt, nil
   300  }
   301  
   302  func (r *Rewriter) rewriteReturn(call *ast.CallExpr) (bool, ast.Stmt, error) {
   303  	stmt := &ast.ReturnStmt{
   304  		Return:  call.Pos(),
   305  		Results: call.Args,
   306  	}
   307  	return true, stmt, nil
   308  }