github.com/eh-steve/goloader@v0.0.0-20240111193454-90ff3cfdae39/mmap/manager.go (about)

     1  package mmap
     2  
     3  import (
     4  	"fmt"
     5  	"github.com/eh-steve/goloader/mmap/mapping"
     6  	"math"
     7  	"sort"
     8  	"sync"
     9  	"syscall"
    10  	"unsafe"
    11  )
    12  
    13  var pageSize = uintptr(syscall.Getpagesize()) // Overridden for windows to use GetAllocationGranularity()
    14  
    15  func roundPageUp(p uintptr) uintptr {
    16  	return (p & ^(pageSize - 1)) + pageSize
    17  }
    18  
    19  func roundPageDown(p uintptr) uintptr {
    20  	return p & ^(pageSize - 1)
    21  }
    22  
    23  type gap struct {
    24  	startAddr uintptr
    25  	endAddr   uintptr
    26  }
    27  
    28  func findNextFreeAddressesAfterTarget(targetAddr uintptr, size int, mappings []mapping.Mapping) (gaps []gap, err error) {
    29  	sort.Slice(mappings, func(i, j int) bool {
    30  		return mappings[i].StartAddr < mappings[j].StartAddr
    31  	})
    32  
    33  	var allGaps []gap
    34  	for i, mapping := range mappings[:len(mappings)-1] {
    35  		allGaps = append(allGaps, gap{
    36  			startAddr: roundPageUp(mapping.EndAddr),
    37  			endAddr:   roundPageDown(mappings[i+1].StartAddr),
    38  		})
    39  	}
    40  	allGaps = append(allGaps, gap{
    41  		startAddr: roundPageUp(mappings[len(mappings)-1].EndAddr),
    42  		endAddr:   math.MaxUint64, // We really shouldn't be in this situation...
    43  	})
    44  	var suitableGaps []gap
    45  	for _, g := range allGaps {
    46  		if g.startAddr > targetAddr && int(g.endAddr-g.startAddr) >= size {
    47  			suitableGaps = append(suitableGaps, g)
    48  		}
    49  	}
    50  	if len(suitableGaps) == 0 {
    51  		return suitableGaps, fmt.Errorf("could not find free address range with size 0x%x after target 0x%x", size, targetAddr)
    52  	}
    53  	return suitableGaps, nil
    54  }
    55  
// activeModules is bound by go:linkname to the unexported runtime function
// runtime.activeModules and has no Go body in this file. AcquireMapping uses
// element 0 of the result as the anchor address that new mappings must stay
// within 32 bits of.
//
// NOTE(review): presumably this is the runtime's list of loaded module data,
// with the element type erased to unsafe.Pointer. This relies on a
// runtime-internal symbol, so its signature must match the runtime for every
// Go version this package supports; verify on toolchain upgrades.
//go:linkname activeModules runtime.activeModules
func activeModules() []unsafe.Pointer
    58  
// mmapLock serializes AcquireMapping, so concurrent callers in this package do
// not race between reading the current process mappings and mapping into a gap
// found there.
//
// This isn't fully concurrency safe: code outside goloader might mmap something
// in the same region we're trying to use. Ideally we would skip such collisions
// via MAP_FIXED_NOREPLACE, but that isn't portable.
var mmapLock sync.Mutex
    62  
    63  func AcquireMapping(size int, mapFunc func(size int, addr uintptr) ([]byte, error)) ([]byte, error) {
    64  	mmapLock.Lock()
    65  	defer mmapLock.Unlock()
    66  
    67  	firstModuleAddr := uintptr(activeModules()[0])
    68  	mappings, err := getCurrentProcMaps()
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	gaps, err := findNextFreeAddressesAfterTarget(firstModuleAddr, int(roundPageUp(uintptr(size))), mappings)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	for _, gap := range gaps {
    79  		start := gap.startAddr
    80  		end := gap.endAddr
    81  		for i := start; i+roundPageUp(uintptr(size)) < end; i += roundPageUp(uintptr(size)) {
    82  			mapping, err := mapFunc(int(roundPageUp(uintptr(size))), i)
    83  			if err != nil {
    84  				// Keep going, try again
    85  			} else {
    86  				if uintptr(unsafe.Pointer(&mapping[len(mapping)-1]))-firstModuleAddr > 1<<32 {
    87  					err = Munmap(mapping)
    88  					if err != nil {
    89  						return nil, fmt.Errorf("failed to acquire a mapping within 32 bits of the first module address, wanted 0x%x, got %p - %p, also failed to munmap: %w", firstModuleAddr, &mapping[0], &mapping[len(mapping)-1], err)
    90  					}
    91  					return nil, fmt.Errorf("failed to acquire a mapping within 32 bits of the first module address, wanted 0x%x, got %p - %p", firstModuleAddr, &mapping[0], &mapping[len(mapping)-1])
    92  				}
    93  				return mapping, nil
    94  			}
    95  		}
    96  	}
    97  
    98  	return nil, fmt.Errorf("failed to aquire mapping between taken mappings: \n%s", formatTakenMappings(mappings))
    99  }
   100  
   101  func formatTakenMappings(mappings []mapping.Mapping) string {
   102  	var mappingTakenRanges string
   103  	for _, m := range mappings {
   104  		mappingTakenRanges += fmt.Sprintf("0x%x - 0x%x  %s\n", m.StartAddr, m.EndAddr, m.PathName)
   105  	}
   106  	return mappingTakenRanges
   107  }