github.com/tencent/goom@v1.0.1/when.go (about)

     1  // Package mocker 定义了 mock 的外层用户使用 API 定义,
     2  // 包括函数、方法、接口、未导出函数(或方法的)的 Mocker 的实现。
     3  // 当前文件实现了按照参数条件进行匹配, 返回对应的 mock return 值,
     4  // 支持了 mocker.When(XXX).Return(YYY)的高效匹配。
     5  package mocker
     6  
     7  import (
     8  	"reflect"
     9  
    10  	"github.com/tencent/goom/arg"
    11  	"github.com/tencent/goom/erro"
    12  )
    13  
    14  // Matcher 参数匹配接口
    15  type Matcher interface {
    16  	// Match 匹配执行方法
    17  	Match(args []reflect.Value) bool
    18  	// Result 匹配成功返回的结果
    19  	Result() []reflect.Value
    20  	// AddResult 添加返回结果
    21  	AddResult([]interface{})
    22  }
    23  
    24  // When Mock 条件匹配。
    25  // 当参数等于指定的值时,会 return 对应的指定值
    26  type When struct {
    27  	ExportedMocker
    28  	funcTyp        reflect.Type
    29  	funcDef        interface{}
    30  	isMethod       bool
    31  	matches        []Matcher
    32  	defaultReturns Matcher
    33  	// curMatch 当前指定的参数匹配
    34  	curMatch Matcher
    35  }
    36  
    37  // CreateWhen 构造条件判断
    38  // param 参数条件
    39  // defaultReturns 默认返回值
    40  // isMethod 是否为方法类型
    41  func CreateWhen(m ExportedMocker, funcDef interface{}, args []interface{},
    42  	defaultReturns []interface{}, isMethod bool) (*When, error) {
    43  	impTyp := reflect.TypeOf(funcDef)
    44  	err := checkParams(funcDef, impTyp, args, defaultReturns, isMethod)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  
    49  	var (
    50  		curMatch     Matcher
    51  		defaultMatch Matcher
    52  	)
    53  	if defaultReturns != nil {
    54  		curMatch = newAlwaysMatch(defaultReturns, impTyp)
    55  	} else if len(outTypes(impTyp)) == 0 {
    56  		curMatch = newEmptyMatch()
    57  	}
    58  
    59  	defaultMatch = curMatch
    60  	if args != nil {
    61  		curMatch = newDefaultMatch(args, nil, isMethod, impTyp)
    62  	}
    63  	return &When{
    64  		ExportedMocker: m,
    65  		defaultReturns: defaultMatch,
    66  		funcTyp:        impTyp,
    67  		funcDef:        funcDef,
    68  		isMethod:       isMethod,
    69  		matches:        make([]Matcher, 0),
    70  		curMatch:       curMatch,
    71  	}, nil
    72  }
    73  
    74  // checkParams 检查参数
    75  func checkParams(funcDef interface{}, impTyp reflect.Type,
    76  	args []interface{}, returns []interface{}, isMethod bool) error {
    77  	if returns != nil && len(returns) < impTyp.NumOut() {
    78  		return erro.NewReturnsNotMatchError(funcDef, len(returns), impTyp.NumOut())
    79  	}
    80  	if isMethod {
    81  		if args != nil && len(args)+1 < impTyp.NumIn() {
    82  			return erro.NewArgsNotMatchError(funcDef, len(args), impTyp.NumIn()-1)
    83  		}
    84  	} else {
    85  		if args != nil && len(args) < impTyp.NumIn() {
    86  			return erro.NewArgsNotMatchError(funcDef, len(args), impTyp.NumIn())
    87  		}
    88  	}
    89  	return nil
    90  }
    91  
    92  // NewWhen 创建默认 When
    93  func NewWhen(funTyp reflect.Type) *When {
    94  	return &When{
    95  		ExportedMocker: nil,
    96  		funcTyp:        funTyp,
    97  		matches:        make([]Matcher, 0),
    98  		defaultReturns: nil,
    99  		curMatch:       nil,
   100  	}
   101  }
   102  
   103  // When 当参数符合一定的条件, 使用 DefaultMatcher
   104  // 入参个数必须和函数或方法参数个数一致,
   105  // 比如: When(
   106  //
   107  //	In(3, 4), // 第一个参数是 In
   108  //	Any()) // 第二个参数是 Any
   109  func (w *When) When(specArgOrExpr ...interface{}) *When {
   110  	w.curMatch = newDefaultMatch(specArgOrExpr, nil, w.isMethod, w.funcTyp)
   111  	return w
   112  }
   113  
   114  // In 当参数包含其中之一, 使用 ContainsMatcher
   115  // 当参数为多个时, In 的每个条件各使用一个数组表示:
   116  // .In([]interface{}{3, Any()}, []interface{}{4, Any()})
   117  func (w *When) In(specArgsOrExprs ...interface{}) *When {
   118  	w.curMatch = newContainsMatch(specArgsOrExprs, nil, w.isMethod, w.funcTyp)
   119  	return w
   120  }
   121  
   122  // Return 指定返回值
   123  func (w *When) Return(value ...interface{}) *When {
   124  	if w.curMatch != nil {
   125  		w.curMatch.AddResult(value)
   126  		w.matches = append(w.matches, w.curMatch)
   127  		return w
   128  	}
   129  
   130  	if w.defaultReturns == nil {
   131  		w.defaultReturns = newAlwaysMatch(value, w.funcTyp)
   132  	} else {
   133  		w.defaultReturns.AddResult(value)
   134  	}
   135  	return w
   136  }
   137  
   138  // AndReturn 指定第二次调用返回值,之后的调用以最后一个指定的值返回
   139  func (w *When) AndReturn(value ...interface{}) *When {
   140  	if w.curMatch == nil {
   141  		return w.Return(value...)
   142  	}
   143  	w.curMatch.AddResult(value)
   144  	return w
   145  }
   146  
   147  // Matches 多个条件匹配
   148  func (w *When) Matches(argAndRet ...arg.Pair) *When {
   149  	if len(argAndRet) == 0 {
   150  		return w
   151  	}
   152  	for _, v := range argAndRet {
   153  		args, ok := v.Args.([]interface{})
   154  		if !ok {
   155  			args = []interface{}{v.Args}
   156  		}
   157  
   158  		results, ok := v.Return.([]interface{})
   159  		if !ok {
   160  			results = []interface{}{v.Return}
   161  		}
   162  
   163  		w.Return(results...)
   164  		matcher := newDefaultMatch(args, results, w.isMethod, w.funcTyp)
   165  		w.matches = append(w.matches, matcher)
   166  	}
   167  	return w
   168  }
   169  
   170  // Returns 按顺序依次返回值
   171  func (w *When) Returns(values ...interface{}) *When {
   172  	if len(values) == 0 {
   173  		return w
   174  	}
   175  
   176  	for i, v := range values {
   177  		ret, ok := v.([]interface{})
   178  		if !ok {
   179  			ret = []interface{}{v}
   180  		}
   181  		if i == 0 {
   182  			w.Return(ret...)
   183  		} else {
   184  			w.AndReturn(ret...)
   185  		}
   186  	}
   187  	return w
   188  }
   189  
   190  // invoke 执行 When 参数匹配并返回值
   191  func (w *When) invoke(args1 []reflect.Value) (results []reflect.Value) {
   192  	if len(w.matches) != 0 {
   193  		for _, c := range w.matches {
   194  			if c.Match(args1) {
   195  				return c.Result()
   196  			}
   197  		}
   198  	}
   199  	return w.returnDefaults()
   200  }
   201  
   202  // Eval 执行 when 子句
   203  func (w *When) Eval(args ...interface{}) []interface{} {
   204  	argVs := arg.I2V(args, inTypes(w.isMethod, w.funcTyp))
   205  	resultVs := w.invoke(argVs)
   206  	return arg.V2I(resultVs, outTypes(w.funcTyp))
   207  }
   208  
   209  // returnDefaults 返回默认值
   210  func (w *When) returnDefaults() []reflect.Value {
   211  	if w.defaultReturns == nil && w.funcTyp.NumOut() != 0 {
   212  		panic("there is no suitable condition matched, or set default return with: mocker.Return(...)")
   213  	}
   214  	return w.defaultReturns.Result()
   215  }