github.com/bytedance/mockey@v1.2.10/utils.go (about)

     1  /*
     2   * Copyright 2022 ByteDance Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package mockey
    18  
    19  import (
    20  	"reflect"
    21  	"unsafe"
    22  
    23  	"github.com/bytedance/mockey/internal/tool"
    24  	"github.com/bytedance/mockey/internal/unsafereflect"
    25  )
    26  
    27  // GetMethod resolve a certain public method from an instance.
    28  func GetMethod(instance interface{}, methodName string) interface{} {
    29  	if typ := reflect.TypeOf(instance); typ != nil {
    30  		if m, ok := getNestedMethod(reflect.ValueOf(instance), methodName); ok {
    31  			return m.Func.Interface()
    32  		}
    33  		if m, ok := typ.MethodByName(methodName); ok {
    34  			return m.Func.Interface()
    35  		}
    36  		if m, ok := getFieldMethod(instance, methodName); ok {
    37  			return m
    38  		}
    39  		ch0 := methodName[0]
    40  		if !(ch0 >= 'A' && ch0 <= 'Z') {
    41  			return unsafeMethodByName(instance, methodName)
    42  		}
    43  	}
    44  	tool.Assert(false, "can't reflect instance method :%v", methodName)
    45  	return nil
    46  }
    47  
    48  // getFieldMethod gets a functional field's value as an instance
    49  // The return instance is not original field but a new function object points to
    50  // the same function.
    51  // for example:
    52  //
    53  //	  type Fn func()
    54  //	  type Foo struct {
    55  //			privateField Fn
    56  //	  }
    57  //	  func NewFoo() Foo { return Foo{ privateField: func() { /*do nothing*/ } }}
    58  //
    59  // getFieldMethod(NewFoo(),"privateField") will return a function object which
    60  // points to the anonymous function in NewFoo
    61  func getFieldMethod(instance interface{}, fieldName string) (interface{}, bool) {
    62  	v := reflect.Indirect(reflect.ValueOf(instance))
    63  	if v.Kind() != reflect.Struct {
    64  		return nil, false
    65  	}
    66  
    67  	field := v.FieldByName(fieldName)
    68  	if !field.IsValid() || field.Kind() != reflect.Func {
    69  		return nil, false
    70  	}
    71  
    72  	carrier := reflect.MakeFunc(field.Type(), nil)
    73  	type function struct {
    74  		_      uintptr
    75  		fnAddr *uintptr
    76  	}
    77  	*(*function)(unsafe.Pointer(&carrier)).fnAddr = field.Pointer()
    78  	return carrier.Interface(), true
    79  }
    80  
    81  // GetPrivateMethod resolve a certain public method from an instance.
    82  // Deprecated, this is an old API in mockito. Please use GetMethod instead.
    83  func GetPrivateMethod(instance interface{}, methodName string) interface{} {
    84  	m, ok := reflect.TypeOf(instance).MethodByName(methodName)
    85  	if ok {
    86  		return m.Func.Interface()
    87  	}
    88  	tool.Assert(false, "can't reflect instance method :%v", methodName)
    89  	return nil
    90  }
    91  
    92  // GetNestedMethod resolves a certain public method in anonymous structs, it will
    93  // look for the specific method in every anonymous struct field recursively.
    94  // Deprecated, this is an old API in mockito. Please use GetMethod instead.
    95  func GetNestedMethod(instance interface{}, methodName string) interface{} {
    96  	if typ := reflect.TypeOf(instance); typ != nil {
    97  		if m, ok := getNestedMethod(reflect.ValueOf(instance), methodName); ok {
    98  			return m.Func.Interface()
    99  		}
   100  	}
   101  	tool.Assert(false, "can't reflect instance method :%v", methodName)
   102  	return nil
   103  }
   104  
   105  func getNestedMethod(val reflect.Value, methodName string) (reflect.Method, bool) {
   106  	typ := val.Type()
   107  	kind := typ.Kind()
   108  	if kind == reflect.Ptr || kind == reflect.Interface {
   109  		val = val.Elem()
   110  	}
   111  	if !val.IsValid() {
   112  		return reflect.Method{}, false
   113  	}
   114  
   115  	typ = val.Type()
   116  	kind = typ.Kind()
   117  	if kind == reflect.Struct {
   118  		for i := 0; i < typ.NumField(); i++ {
   119  			if !typ.Field(i).Anonymous {
   120  				// there is no need to acquire non-anonymous method
   121  				continue
   122  			}
   123  			if m, ok := getNestedMethod(val.Field(i), methodName); ok {
   124  				return m, true
   125  			}
   126  		}
   127  	}
   128  	// a struct receiver is prior to the corresponding pointer receiver
   129  	if m, ok := typ.MethodByName(methodName); ok {
   130  		return m, true
   131  	}
   132  	return reflect.PtrTo(typ).MethodByName(methodName)
   133  }
   134  
   135  // unsafeMethodByName resolve a method from an instance, include private method.
   136  //
   137  // THIS IS UNSAFE FOR LOWER GO VERSION(<1.12)
   138  //
   139  // for example:
   140  //
   141  //	unsafeMethodByName(&bytes.Buffer{}, "empty")
   142  //	unsafeMethodByName(sha256.New(), "checkSum")
   143  func unsafeMethodByName(instance interface{}, methodName string) interface{} {
   144  	typ, tfn, ok := unsafereflect.MethodByName(instance, methodName)
   145  	if !ok {
   146  		tool.Assert(false, "can't reflect instance method :%v", methodName)
   147  		return nil
   148  	}
   149  	if typ == nil {
   150  		tool.Assert(false, "failed to determine %v's type", methodName)
   151  	}
   152  
   153  	if typ.Kind() != reflect.Func {
   154  		tool.Assert(false, "invalid instance method type: %v,%v", methodName, typ.Kind().String())
   155  		return nil
   156  	}
   157  
   158  	in := []reflect.Type{reflect.TypeOf(instance)}
   159  	out := []reflect.Type{}
   160  	for i := 0; i < typ.NumIn(); i++ {
   161  		in = append(in, typ.In(i))
   162  	}
   163  	for i := 0; i < typ.NumOut(); i++ {
   164  		out = append(out, typ.Out(i))
   165  	}
   166  
   167  	hook := reflect.FuncOf(in, out, typ.IsVariadic())
   168  	vt := reflect.Zero(hook).Interface()
   169  	*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&vt)) + 8)) = uintptr(unsafe.Pointer(tfn))
   170  	return vt
   171  }
   172  
   173  // GetGoroutineId gets the current goroutine ID
   174  func GetGoroutineId() int64 {
   175  	return tool.GetGoroutineID()
   176  }