github.com/bytedance/mockey@v1.2.10/mock_condition.go (about)

     1  /*
     2   * Copyright 2022 ByteDance Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package mockey
    18  
    19  import (
    20  	"reflect"
    21  
    22  	"github.com/bytedance/mockey/internal/tool"
    23  )
    24  
    25  type mockCondition struct {
    26  	when interface{} // condition
    27  	hook interface{} // mock function
    28  
    29  	builder *MockBuilder
    30  }
    31  
    32  func (m *mockCondition) Complete() bool {
    33  	return m.when != nil && m.hook != nil
    34  }
    35  
    36  func (m *mockCondition) SetWhen(when interface{}) {
    37  	tool.Assert(m.when == nil, "re-set builder when")
    38  	m.SetWhenForce(when)
    39  }
    40  
    41  func (m *mockCondition) SetWhenForce(when interface{}) {
    42  	wVal := reflect.ValueOf(when)
    43  	tool.Assert(wVal.Type().NumOut() == 1, "when func ret value not bool")
    44  	out1 := wVal.Type().Out(0)
    45  	tool.Assert(out1.Kind() == reflect.Bool, "when func ret value not bool")
    46  
    47  	hookType := m.builder.hookType()
    48  	inTypes := []reflect.Type{}
    49  	for i := 0; i < hookType.NumIn(); i++ {
    50  		inTypes = append(inTypes, hookType.In(i))
    51  	}
    52  
    53  	hasGeneric, hasReceiver := m.checkGenericAndReceiver(wVal.Type())
    54  	whenType := reflect.FuncOf(inTypes, []reflect.Type{out1}, hookType.IsVariadic())
    55  	m.when = reflect.MakeFunc(whenType, func(args []reflect.Value) (results []reflect.Value) {
    56  		results = tool.ReflectCall(wVal, m.adaptArgsForReflectCall(args, hasGeneric, hasReceiver))
    57  		return
    58  	}).Interface()
    59  }
    60  
    61  func (m *mockCondition) SetReturn(results ...interface{}) {
    62  	tool.Assert(m.hook == nil, "re-set builder hook")
    63  	m.SetReturnForce(results...)
    64  }
    65  
    66  func (m *mockCondition) SetReturnForce(results ...interface{}) {
    67  	getResult := func() []interface{} { return results }
    68  	if len(results) == 1 {
    69  		seq, ok := results[0].(SequenceOpt)
    70  		if ok {
    71  			getResult = seq.GetNext
    72  		}
    73  	}
    74  
    75  	hookType := m.builder.hookType()
    76  	m.hook = reflect.MakeFunc(hookType, func(_ []reflect.Value) []reflect.Value {
    77  		results := getResult()
    78  		tool.CheckReturnType(m.builder.target, results...)
    79  		valueResults := make([]reflect.Value, 0)
    80  		for i, result := range results {
    81  			rValue := reflect.Zero(hookType.Out(i))
    82  			if result != nil {
    83  				rValue = reflect.ValueOf(result).Convert(hookType.Out(i))
    84  			}
    85  			valueResults = append(valueResults, rValue)
    86  		}
    87  		return valueResults
    88  	}).Interface()
    89  }
    90  
    91  func (m *mockCondition) SetTo(to interface{}) {
    92  	tool.Assert(m.hook == nil, "re-set builder hook")
    93  	m.SetToForce(to)
    94  }
    95  
    96  func (m *mockCondition) SetToForce(to interface{}) {
    97  	hType := reflect.TypeOf(to)
    98  	tool.Assert(hType.Kind() == reflect.Func, "to a is not a func")
    99  	hasGeneric, hasReceiver := m.checkGenericAndReceiver(hType)
   100  	tool.Assert(m.builder.generic || !hasGeneric, "non-generic function should not have 'GenericInfo' as first argument")
   101  	m.hook = reflect.MakeFunc(m.builder.hookType(), func(args []reflect.Value) (results []reflect.Value) {
   102  		results = tool.ReflectCall(reflect.ValueOf(to), m.adaptArgsForReflectCall(args, hasGeneric, hasReceiver))
   103  		return
   104  	}).Interface()
   105  }
   106  
   107  // checkGenericAndReceiver check if typ has GenericsInfo and selfReceiver as argument
   108  //
   109  // The hook function will looks like func(_ GenericInfo, self *struct, arg0 int ...)
   110  // When we use 'When' or 'To', our input hook function will looks like:
   111  //  1. func(arg0 int ...)
   112  //  2. func(info GenericInfo, arg0 int ...)
   113  //  3. func(self *struct, arg0 int ...)
   114  //  4. func(info GenericInfo, self *struct, arg0 int ...)
   115  //
   116  // All above input hooks are legal, but we need to make an adaptation when calling then
   117  func (m *mockCondition) checkGenericAndReceiver(typ reflect.Type) (bool, bool) {
   118  	targetType := reflect.TypeOf(m.builder.target)
   119  	tool.Assert(typ.Kind() == reflect.Func, "Param(%v) a is not a func", typ.Kind())
   120  	tool.Assert(targetType.IsVariadic() == typ.IsVariadic(), "target:%v, hook:%v args not match", targetType, typ)
   121  
   122  	shiftTyp := 0
   123  	if typ.NumIn() > 0 && typ.In(0) == genericInfoType {
   124  		shiftTyp = 1
   125  	}
   126  
   127  	// has receiver
   128  	if tool.CheckFuncArgs(targetType, typ, 0, shiftTyp) {
   129  		return shiftTyp == 1, true
   130  	}
   131  
   132  	if tool.CheckFuncArgs(targetType, typ, 1, shiftTyp) {
   133  		return shiftTyp == 1, false
   134  	}
   135  	tool.Assert(false, "target:%v, hook:%v args not match", targetType, typ)
   136  	return false, false
   137  }
   138  
   139  // adaptArgsForReflectCall makes an adaption for reflect call
   140  //
   141  // see (*mockCondition).checkGenericAndReceiver for more info
   142  func (m *mockCondition) adaptArgsForReflectCall(args []reflect.Value, hasGeneric, hasReceiver bool) []reflect.Value {
   143  	adaption := []reflect.Value{}
   144  	if m.builder.generic {
   145  		if hasGeneric {
   146  			adaption = append(adaption, args[0])
   147  		}
   148  		args = args[1:]
   149  	}
   150  	if !hasReceiver {
   151  		args = args[1:]
   152  	}
   153  	adaption = append(adaption, args...)
   154  	return adaption
   155  }