github.com/eh-steve/goloader@v0.0.0-20240111193454-90ff3cfdae39/convert.go (about)

     1  //go:build go1.18
     2  // +build go1.18
     3  
     4  package goloader
     5  
     6  import (
     7  	"fmt"
     8  	"github.com/eh-steve/goloader/goversion"
     9  	"github.com/eh-steve/goloader/mprotect"
    10  	"log"
    11  	"reflect"
    12  	"regexp"
    13  	"runtime"
    14  	"runtime/debug"
    15  	"strings"
    16  	"unsafe"
    17  )
    18  
    19  func CanAttemptConversion(oldValue interface{}, newType reflect.Type) bool {
    20  	oldT := efaceOf(&oldValue)._type
    21  	newT := fromRType(newType)
    22  	seen := map[_typePair]struct{}{}
    23  	return typesEqual(oldT, newT, seen)
    24  }
    25  
    26  func ConvertTypesAcrossModules(oldModule, newModule *CodeModule, oldValue interface{}, newType reflect.Type) (res interface{}, err error) {
    27  	defer func() {
    28  		if v := recover(); v != nil {
    29  			err = fmt.Errorf("unexpected panic (this is a bug): %v\n stack trace: %s", v, debug.Stack())
    30  		}
    31  	}()
    32  
    33  	// You can't just do a reflect.cvtDirect() across modules if composite types of oldValue/newValue contain an interface.
    34  	// The value stored in the interface could point to the itab from the old module, which might get unloaded
    35  	// So we need to recurse over the entire structure, and find any itabs and replace them with the equivalent from the new module
    36  
    37  	oldT := efaceOf(&oldValue)._type
    38  	newT := fromRType(newType)
    39  	seen := map[_typePair]struct{}{}
    40  	if !typesEqual(oldT, newT, seen) {
    41  		return nil, fmt.Errorf("old type %T and new type %s are not equal", oldValue, newType)
    42  	}
    43  
    44  	// Need to take data in old value and copy into new value one field at a time, but check that
    45  	// the type is either shared (first module) or translated from the old to the new modules
    46  	oldV := Indirect(ValueOf(&oldValue)).Elem()
    47  
    48  	cycleDetector := map[uintptr]*Value{}
    49  	typeHash := make(map[uint32][]*_type, len(newModule.module.typelinks))
    50  	buildModuleTypeHash(activeModules()[0], typeHash)
    51  	buildModuleTypeHash(newModule.module, typeHash)
    52  
    53  	cvt(oldModule, newModule, Value{oldV}, AsType(newT), nil, cycleDetector, typeHash)
    54  
    55  	return oldV.ConvertWithInterface(AsType(newT)).Interface(), err
    56  }
    57  
    58  func toType(t Type) *_type {
    59  	var x interface{} = t
    60  	return (*_type)(efaceOf(&x).data)
    61  }
    62  
    63  func fromRType(t reflect.Type) *_type {
    64  	var x interface{} = t
    65  	return (*_type)(efaceOf(&x).data)
    66  }
    67  
// fakeValue is overlaid, via an unsafe.Pointer cast, on a Value so that code
// in this file can read and overwrite the Value's type pointer and data
// pointer directly.
// NOTE(review): this only works if the field order and sizes match the Value
// header it is cast from (type word, data word, flag word); confirm against
// the Value definition if either changes.
type fakeValue struct {
	typ  *_type         // type descriptor of the value
	ptr  unsafe.Pointer // data word of the value
	flag uintptr        // flag word; not read or written in this file
}
    73  
    74  func AsType(_typ *_type) Type {
    75  	var t interface{} = TypeOf("")
    76  	eface := efaceOf(&t)
    77  	eface.data = unsafe.Pointer(_typ)
    78  	return t.(Type)
    79  }
    80  
    81  func AsRType(_typ *_type) reflect.Type {
    82  	var t interface{} = reflect.TypeOf("")
    83  	eface := efaceOf(&t)
    84  	eface.data = unsafe.Pointer(_typ)
    85  	return t.(reflect.Type)
    86  }
    87  
    88  var closureFuncRegex = regexp.MustCompile(`^.*\.func[0-9]+$`)
    89  
    90  func cvt(oldModule, newModule *CodeModule, oldValue Value, newType Type, oldValueBeforeElem *Value, cycleDetector map[uintptr]*Value, typeHash map[uint32][]*_type) {
    91  	// By this point we're sure that types are structurally equal, but their *_type addresses might not be
    92  
    93  	kind := oldValue.Kind()
    94  
    95  	if newType.Kind() != kind {
    96  		panic(fmt.Sprintf("old value's kind (%s) and new type (%s - %s) don't match", kind, newType.String(), newType.Kind()))
    97  	}
    98  
    99  	// Non-composite types of equal kind have same underlying type
   100  	if Bool <= kind && kind <= Complex128 || kind == String || kind == UnsafePointer {
   101  		return
   102  	}
   103  
   104  	switch kind {
   105  	case Array, Ptr, Slice:
   106  		elemKind := oldValue.Type().Elem().Kind()
   107  		if Bool <= elemKind && elemKind <= Complex128 || elemKind == String || elemKind == UnsafePointer {
   108  			// Shortcut for non-composite types
   109  			return
   110  		}
   111  	}
   112  
   113  	// Composite types.
   114  	switch kind {
   115  	case Interface:
   116  		innerVal := oldValue.Elem()
   117  		if innerVal.Kind() == Invalid {
   118  			return
   119  		}
   120  		oldTInner := toType(innerVal.Type())
   121  		oldTOuter := toType(oldValue.Type())
   122  		var newTypeInner *_type
   123  		var newTypeOuter *_type
   124  		types := typeHash[oldTInner.hash]
   125  		for _, _typeNew := range types {
   126  			seen := map[_typePair]struct{}{}
   127  			if oldTInner == _typeNew || typesEqual(oldTInner, _typeNew, seen) {
   128  				newTypeInner = _typeNew
   129  				break
   130  			}
   131  		}
   132  
   133  		types = typeHash[oldTOuter.hash]
   134  		for _, _typeNew := range types {
   135  			seen := map[_typePair]struct{}{}
   136  			if oldTOuter == _typeNew || typesEqual(oldTOuter, _typeNew, seen) {
   137  				newTypeOuter = _typeNew
   138  				break
   139  			}
   140  		}
   141  
   142  		if newTypeInner == nil {
   143  			oldTAddr := uintptr(unsafe.Pointer(oldTInner))
   144  			if innerVal.Type().PkgPath() == "" || (firstmoduledata.types >= oldTAddr && oldTAddr < firstmoduledata.etypes) {
   145  				newTypeInner = oldTInner
   146  			} else {
   147  				panic(fmt.Sprintf("new module does not contain equivalent type for %s (hash %d)", innerVal.Type(), toType(innerVal.Type()).hash))
   148  			}
   149  		}
   150  		if newTypeOuter == nil {
   151  			oldTAddr := uintptr(unsafe.Pointer(oldTOuter))
   152  			if oldValue.Type().PkgPath() == "" || (firstmoduledata.types >= oldTAddr && oldTAddr < firstmoduledata.etypes) {
   153  				newTypeOuter = oldTOuter
   154  			} else {
   155  				panic(fmt.Sprintf("new module does not contain equivalent type for %s (hash %d)", oldValue.Type(), toType(oldValue.Type()).hash))
   156  			}
   157  		}
   158  
   159  		newInnerType := AsType(newTypeInner)
   160  		newOuterType := AsType(newTypeOuter)
   161  		tt := (*interfacetype)(unsafe.Pointer(newTypeOuter))
   162  
   163  		if len(tt.mhdr) > 0 {
   164  			iface := (*nonEmptyInterface)(((*fakeValue)(unsafe.Pointer(&oldValue))).ptr)
   165  			if iface.itab == nil {
   166  				// nil value in interface, no further work required
   167  				return
   168  			} else {
   169  				// Need to check whether itab points at old module, and find the equivalent itab in the new module and point to that instead
   170  
   171  				var oldItab *itab
   172  				for _, o := range oldModule.module.itablinks {
   173  					if iface.itab == o {
   174  						oldItab = o
   175  						break
   176  					}
   177  				}
   178  				if oldItab != nil {
   179  					var newItab *itab
   180  					for _, n := range newModule.module.itablinks {
   181  						// Need to compare these types carefully
   182  						if oldItab.inter.typ.hash == n.inter.typ.hash && oldItab._type.hash == n._type.hash {
   183  							seen := map[_typePair]struct{}{}
   184  							if typesEqual(&oldItab.inter.typ, &n.inter.typ, seen) && typesEqual(oldItab._type, n._type, seen) {
   185  								newItab = n
   186  								break
   187  							}
   188  						}
   189  					}
   190  					if newItab == nil {
   191  						panic(fmt.Sprintf("could not find equivalent itab for interface %s type %s in new module.", oldValue.Type().String(), oldValue.Elem().Type().String()))
   192  					}
   193  					iface.itab = newItab
   194  				}
   195  			}
   196  		} else {
   197  			eface := (*emptyInterface)(((*fakeValue)(unsafe.Pointer(&oldValue))).ptr)
   198  			eface._type = newTypeInner
   199  		}
   200  
   201  		innerValKind := innerVal.Kind()
   202  		if !(Bool <= innerValKind && innerValKind <= Complex128 || innerValKind == String || innerValKind == UnsafePointer) {
   203  			cvt(oldModule, newModule, Value{innerVal}, newInnerType, &oldValue, cycleDetector, typeHash)
   204  		} else {
   205  			if innerVal.CanConvert(newInnerType) {
   206  				newVal := innerVal.Convert(newInnerType)
   207  				if !oldValue.CanSet() {
   208  					if !oldValue.CanAddr() {
   209  						if oldValueBeforeElem != nil && oldValueBeforeElem.Kind() == Interface {
   210  							oldValueBeforeElem.Set(newVal)
   211  						} else {
   212  							panic(fmt.Sprintf("can't set old value of type %s with new value %s (can't address or indirect)", oldValue.Type(), newVal.Type()))
   213  						}
   214  					} else {
   215  						NewAt(newOuterType, unsafe.Pointer(oldValue.UnsafeAddr())).Elem().Set(newVal)
   216  					}
   217  				} else {
   218  					oldValue.Set(newVal)
   219  				}
   220  			} else {
   221  				panic(fmt.Sprintf("can't convert old value of type %s with new value %s", innerVal.Type(), newInnerType))
   222  			}
   223  		}
   224  	case Func:
   225  		oldPtr := oldValue.Pointer()
   226  		if oldPtr != 0 {
   227  			if oldPtr < firstmoduledata.text || oldPtr >= firstmoduledata.etext {
   228  				if oldPtr >= oldModule.module.text && oldPtr < oldModule.module.etext {
   229  					// If the func points at code inside the old module, we need to either find the address of
   230  					// the equivalent func by name, or error if we can't find it
   231  					oldF := runtime.FuncForPC(oldPtr)
   232  					oldFName := oldF.Name()
   233  					if oldFName == "" {
   234  						panic(fmt.Sprintf("old value's function pointer 0x%x does not have a name - cannot convert anonymous functions", oldPtr))
   235  					}
   236  					found := false
   237  					for _, f := range newModule.module.ftab {
   238  						_func := (*_func)(unsafe.Pointer(&(newModule.module.pclntable[f.funcoff])))
   239  						name := getfuncname(_func, newModule.module)
   240  						if name == oldFName {
   241  							entry := getfuncentry(_func, newModule.module.text)
   242  							// This is actually unsafe, because there's no guarantee that the new version
   243  							// of the function has the same signature as the old, and there's no way of accessing
   244  							// the function *_type from just a PC addr, unless the compiler populated a ptab.
   245  							log.Printf("WARNING - converting functions %s by name - no guarantees that signatures will match \n", oldFName)
   246  							newValue := oldValue
   247  							manipulation := (*fakeValue)(unsafe.Pointer(&newValue))
   248  							var funcContainer unsafe.Pointer
   249  							if strings.HasSuffix(oldFName, "-fm") {
   250  								// This is a method, so the data pointer in the value is actually to a closure struct { F uintptr; R *receiver }
   251  								// and the function pointer is to a wrapper func which accepts this struct as its argument
   252  								closure := *(**struct {
   253  									F uintptr
   254  									R unsafe.Pointer
   255  								})(manipulation.ptr)
   256  
   257  								// We need to not only set the func entrypoint, but also convert the receiver and set that too
   258  								// TODO - how can we find out the receiver's type in order to convert across modules?
   259  								//  This code might not be safe if the receivers then call other methods?
   260  
   261  								// Now check whether the old closure.F is an itab method or a concrete type
   262  								var oldItab *itab
   263  
   264  								// This deref of the receiver into an 8 byte word is 100% unsafe, but I can't figure out how to find out what the type of R is...
   265  								recvVal := (*itab)(closure.R)
   266  								for _, itab := range oldModule.module.itablinks {
   267  									if itab.inter == recvVal.inter && itab._type == recvVal._type {
   268  										oldItab = itab
   269  									}
   270  								}
   271  								if oldItab != nil {
   272  									var newItab *itab
   273  									for _, n := range newModule.module.itablinks {
   274  										// Need to compare these types carefully
   275  										if oldItab.inter.typ.hash == n.inter.typ.hash && oldItab._type.hash == n._type.hash {
   276  											seen := map[_typePair]struct{}{}
   277  											if typesEqual(&oldItab.inter.typ, &n.inter.typ, seen) && typesEqual(oldItab._type, n._type, seen) {
   278  												newItab = n
   279  												break
   280  											}
   281  										}
   282  									}
   283  									if newItab == nil {
   284  										panic(fmt.Sprintf("could not find equivalent itab for interface %s type %s in new module.", oldValue.Type().String(), oldValue.Elem().Type().String()))
   285  									}
   286  									closure.R = unsafe.Pointer(newItab)
   287  								}
   288  
   289  								funcContainer = unsafe.Pointer(closure)
   290  								closure.F = entry
   291  							} else if closureFuncRegex.MatchString(oldFName) {
   292  								containerSym, haveContainerSym := newModule.Syms[oldFName+"·f"]
   293  								if haveContainerSym && goversion.GoVersion() > 18 {
   294  									funcContainer = unsafe.Pointer(containerSym)
   295  								} else {
   296  									// This is a closure which is unlikely to be safe since the variables it closes over might be in the old module's memory
   297  									closure := *(**struct {
   298  										F uintptr
   299  										// ... <- variables which are captured by the closure would follow, but we can't know how many they are or what their types are - the best we can do is switch the function implementation and keep the variables the same
   300  									})(manipulation.ptr)
   301  									if runtime.GOARCH == "arm64" && runtime.GOOS == "darwin" {
   302  										err := mprotect.MprotectMakeWritable(mprotect.GetPage(uintptr(unsafe.Pointer(closure))))
   303  										if err != nil {
   304  											panic(fmt.Sprintf("failed to make page of closure writable: %s", err))
   305  										}
   306  									}
   307  									closure.F = entry
   308  									funcContainer = unsafe.Pointer(closure)
   309  									log.Printf("EVEN BIGGER WARNING - converting anonymous function %s by name - no guarantees that signatures, or the closed over variable sizes, or types will match. This is dangerous! \n", oldFName)
   310  								}
   311  							} else {
   312  								containerSym, haveContainerSym := newModule.Syms[oldFName+"·f"]
   313  								if haveContainerSym {
   314  									funcContainer = unsafe.Pointer(containerSym)
   315  								} else {
   316  									// PC addresses for functions are 2 levels of indirection from a reflect value's word addr,
   317  									// so we allocate addresses on the heap to hold the indirections
   318  									// Normally the RODATA has a pkgname.FuncName·f symbol which stores this - Ideally we would use that instead of the heap
   319  									// TODO - is this definitely safe from GC?
   320  									funcPtr := new(uintptr)
   321  									*funcPtr = entry
   322  									funcContainer = unsafe.Pointer(funcPtr)
   323  								}
   324  							}
   325  							funcPtrContainer := new(unsafe.Pointer)
   326  							*funcPtrContainer = funcContainer
   327  							manipulation.ptr = unsafe.Pointer(funcPtrContainer)
   328  							manipulation.typ = toType(newType)
   329  							doSet := true
   330  							if !oldValue.CanSet() {
   331  								if oldValue.CanAddr() {
   332  									oldValue = Value{NewAt(newType, unsafe.Pointer(oldValue.UnsafeAddr())).Elem()}
   333  								} else {
   334  									if oldValueBeforeElem != nil && oldValueBeforeElem.Kind() == Interface {
   335  										doSet = false
   336  										oldValueBeforeElem.Set(newValue.Value)
   337  									} else {
   338  										panic(fmt.Sprintf("can't set old func of type %s with new value 0x%x (can't address or indirect)", oldValue.Type(), entry))
   339  									}
   340  								}
   341  							}
   342  							if doSet {
   343  								oldValue.Set(newValue.Value)
   344  							}
   345  							found = true
   346  							break
   347  						}
   348  					}
   349  					if !found {
   350  						panic(fmt.Sprintf("old value's function pointer 0x%x with name %s has no equivalent name in new module - cannot convert", oldPtr, oldFName))
   351  					}
   352  				} else {
   353  					panic(fmt.Sprintf("old value's function pointer 0x%x not in first module (0x%x - 0x%x) nor old module (0x%x - 0x%x) - cannot convert", oldPtr, firstmoduledata.text, firstmoduledata.etext, oldModule.module.text, oldModule.module.etext))
   354  				}
   355  			}
   356  		}
   357  	case Array, Slice:
   358  		for i := 0; i < oldValue.Len(); i++ {
   359  			cvt(oldModule, newModule, Value{oldValue.Index(i)}, newType.Elem(), nil, cycleDetector, typeHash)
   360  		}
   361  	case Map:
   362  		if oldValue.Len() == 0 {
   363  			return
   364  		}
   365  		keyType := oldValue.Type().Key()
   366  		valType := oldValue.Type().Elem()
   367  		mvKind := valType.Kind()
   368  		mkKind := keyType.Kind()
   369  		if !(Bool <= mvKind && mvKind <= Complex128 || mvKind == String || mvKind == UnsafePointer) ||
   370  			!(Bool <= mkKind && mkKind <= Complex128 || mkKind == String || mkKind == UnsafePointer) {
   371  			// Need to recreate map entirely since they aren't mutable
   372  			newMap := MakeMapWithSize(oldValue.Type(), oldValue.Len())
   373  			mapKeys := oldValue.MapKeys()
   374  			for _, mapKey := range mapKeys {
   375  				mapValue := oldValue.MapIndex(mapKey)
   376  				var nk Value
   377  				if mkKind == Ptr {
   378  					nk = Value{New(keyType.Elem())}
   379  				} else {
   380  					nk = Value{Indirect(New(keyType))}
   381  				}
   382  				var nv Value
   383  				if mvKind == Ptr {
   384  					nv = Value{New(valType.Elem())}
   385  				} else {
   386  					nv = Value{Indirect(New(valType))}
   387  				}
   388  				if mkKind == Ptr {
   389  					nk.Elem().Set(mapKey.Elem())
   390  				} else {
   391  					nk.Set(mapKey)
   392  				}
   393  				if mvKind == Ptr {
   394  					nv.Elem().Set(mapValue.Elem())
   395  				} else {
   396  					nv.Set(mapValue)
   397  				}
   398  
   399  				cvt(oldModule, newModule, nv, newType.Elem(), &oldValue, cycleDetector, typeHash)
   400  				cvt(oldModule, newModule, nk, newType.Key(), &oldValue, cycleDetector, typeHash)
   401  				newMap.SetMapIndex(nk.Value, nv.Value)
   402  			}
   403  			doSet := true
   404  			if !oldValue.CanSet() {
   405  				if oldValue.CanAddr() {
   406  					oldValue = Value{NewAt(oldValue.Type(), unsafe.Pointer(oldValue.UnsafeAddr())).Elem()}
   407  				} else {
   408  					if oldValueBeforeElem != nil && oldValueBeforeElem.Kind() == Interface {
   409  						doSet = false
   410  						oldValueBeforeElem.Set(newMap)
   411  					} else {
   412  						panic(fmt.Sprintf("can't set old map of type %s with new value (can't address or indirect)", oldValue.Type()))
   413  					}
   414  				}
   415  			}
   416  			if doSet {
   417  				oldValue.Set(newMap)
   418  			}
   419  		}
   420  	case Ptr:
   421  		if !oldValue.IsNil() {
   422  			up := oldValue.Pointer()
   423  			if _, cyclic := cycleDetector[up]; cyclic {
   424  				return
   425  			} else {
   426  				cycleDetector[up] = &oldValue
   427  				cvt(oldModule, newModule, Value{oldValue.Elem()}, newType.Elem(), &oldValue, cycleDetector, typeHash)
   428  			}
   429  		}
   430  	case Struct:
   431  		for i := 0; i < oldValue.NumField(); i++ {
   432  			field := oldValue.Field(i)
   433  			fieldKind := field.Kind()
   434  			if !(Bool <= fieldKind && fieldKind <= Complex128 || fieldKind == String || fieldKind == UnsafePointer) {
   435  				cvt(oldModule, newModule, Value{field}, newType.Field(i).Type, nil, cycleDetector, typeHash)
   436  			}
   437  		}
   438  	}
   439  }