github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/testutils/lint/passes/timer/timer.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  // Package timer defines an Analyzer that detects correct use of
    12  // timeutil.Timer.
    13  package timer
    14  
    15  import (
    16  	"fmt"
    17  	"go/ast"
    18  	"go/token"
    19  	"go/types"
    20  
    21  	"golang.org/x/tools/go/analysis"
    22  	"golang.org/x/tools/go/analysis/passes/inspect"
    23  	"golang.org/x/tools/go/ast/inspector"
    24  )
    25  
    26  // Doc documents this pass.
    27  const Doc = `check for correct use of timeutil.Timer`
    28  
    29  // Analyzer defines this pass.
    30  var Analyzer = &analysis.Analyzer{
    31  	Name:     "timer",
    32  	Doc:      Doc,
    33  	Requires: []*analysis.Analyzer{inspect.Analyzer},
    34  	Run:      run,
    35  }
    36  
    37  // timerChecker assures that timeutil.Timer objects are used correctly, to
    38  // avoid race conditions and deadlocks. These timers require callers to set
    39  // their Read field to true when their channel has been received on. If this
    40  // field is not set and the timer's Reset method is called, we will deadlock.
    41  // This lint assures that the Read field is set in the most common case where
    42  // Reset is used, within a for-loop where each iteration blocks on a select
    43  // statement. The timers are usually used as timeouts on these select
    44  // statements, and need to be reset after each iteration.
    45  //
    46  // for {
    47  //   timer.Reset(...)
    48  //   select {
    49  //     case <-timer.C:
    50  //       timer.Read = true   <--  lint verifies that this line is present
    51  //     case ...:
    52  //   }
    53  // }
    54  //
    55  func run(pass *analysis.Pass) (interface{}, error) {
    56  	selectorIsTimer := func(s *ast.SelectorExpr) bool {
    57  		tv, ok := pass.TypesInfo.Types[s.X]
    58  		if !ok {
    59  			return false
    60  		}
    61  		typ := tv.Type.Underlying()
    62  		for {
    63  			ptr, pok := typ.(*types.Pointer)
    64  			if !pok {
    65  				break
    66  			}
    67  			typ = ptr.Elem()
    68  		}
    69  		named, ok := typ.(*types.Named)
    70  		if !ok {
    71  			return false
    72  		}
    73  		if named.Obj().Type().String() != "github.com/cockroachdb/cockroach/pkg/util/timeutil.Timer" {
    74  			return false
    75  		}
    76  		return true
    77  	}
    78  
    79  	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    80  
    81  	nodeFilter := []ast.Node{
    82  		(*ast.ForStmt)(nil),
    83  	}
    84  	inspect.Preorder(nodeFilter, func(n ast.Node) {
    85  		fr, ok := n.(*ast.ForStmt)
    86  		if !ok {
    87  			return
    88  		}
    89  		walkStmts(fr.Body.List, func(s ast.Stmt) bool {
    90  			return walkSelectStmts(s, func(s ast.Stmt) bool {
    91  				comm, ok := s.(*ast.CommClause)
    92  				if !ok || comm.Comm == nil /* default: */ {
    93  					return true
    94  				}
    95  
    96  				// if receiving on a timer's C chan.
    97  				var unary ast.Expr
    98  				switch v := comm.Comm.(type) {
    99  				case *ast.AssignStmt:
   100  					// case `now := <-timer.C:`
   101  					unary = v.Rhs[0]
   102  				case *ast.ExprStmt:
   103  					// case `<-timer.C:`
   104  					unary = v.X
   105  				default:
   106  					return true
   107  				}
   108  				chanRead, ok := unary.(*ast.UnaryExpr)
   109  				if !ok || chanRead.Op != token.ARROW {
   110  					return true
   111  				}
   112  				selector, ok := chanRead.X.(*ast.SelectorExpr)
   113  				if !ok {
   114  					return true
   115  				}
   116  				if !selectorIsTimer(selector) {
   117  					return true
   118  				}
   119  				selectorName := fmt.Sprint(selector.X)
   120  				if selector.Sel.String() != timerChanName {
   121  					return true
   122  				}
   123  
   124  				// Verify that the case body contains `timer.Read = true`.
   125  				noRead := walkStmts(comm.Body, func(s ast.Stmt) bool {
   126  					assign, ok := s.(*ast.AssignStmt)
   127  					if !ok || assign.Tok != token.ASSIGN {
   128  						return true
   129  					}
   130  					for i := range assign.Lhs {
   131  						l, r := assign.Lhs[i], assign.Rhs[i]
   132  
   133  						// if assignment to correct field in timer.
   134  						assignSelector, ok := l.(*ast.SelectorExpr)
   135  						if !ok {
   136  							return true
   137  						}
   138  						if !selectorIsTimer(assignSelector) {
   139  							return true
   140  						}
   141  						if fmt.Sprint(assignSelector.X) != selectorName {
   142  							return true
   143  						}
   144  						if assignSelector.Sel.String() != "Read" {
   145  							return true
   146  						}
   147  
   148  						// if assigning `true`.
   149  						val, ok := r.(*ast.Ident)
   150  						if !ok {
   151  							return true
   152  						}
   153  						if val.String() == "true" {
   154  							// returning false will short-circuit walkStmts and assign
   155  							// noRead to false instead of the default value of true.
   156  							return false
   157  						}
   158  					}
   159  					return true
   160  				})
   161  				if noRead {
   162  					pass.Reportf(comm.Pos(), "must set timer.Read = true after reading from timer.C (see timeutil/timer.go)")
   163  				}
   164  				return true
   165  			})
   166  		})
   167  	})
   168  
   169  	return nil, nil
   170  }
   171  
   172  const timerChanName = "C"
   173  
   174  func walkSelectStmts(n ast.Node, fn func(ast.Stmt) bool) bool {
   175  	sel, ok := n.(*ast.SelectStmt)
   176  	if !ok {
   177  		return true
   178  	}
   179  	return walkStmts(sel.Body.List, fn)
   180  }
   181  
   182  func walkStmts(stmts []ast.Stmt, fn func(ast.Stmt) bool) bool {
   183  	for _, stmt := range stmts {
   184  		if !fn(stmt) {
   185  			return false
   186  		}
   187  	}
   188  	return true
   189  }