github.com/pkujhd/goloader@v0.0.0-20240411034752-1a28096bd7bd/examples/loader/loader.go (about)

     1  package main
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"net/http"
     7  	"os"
     8  	"runtime"
     9  	"strings"
    10  	"sync"
    11  	"unsafe"
    12  
    13  	"github.com/pkujhd/goloader"
    14  )
    15  
    16  type arrayFlags struct {
    17  	File    []string
    18  	PkgPath []string
    19  }
    20  
    21  func (i *arrayFlags) String() string {
    22  	return "my string representation"
    23  }
    24  
    25  func (i *arrayFlags) Set(value string) error {
    26  	s := strings.Split(value, ":")
    27  	i.File = append(i.File, s[0])
    28  	var path string
    29  	if len(s) > 1 {
    30  		path = s[1]
    31  	}
    32  	i.PkgPath = append(i.PkgPath, path)
    33  	return nil
    34  }
    35  
    36  func main() {
    37  	var files arrayFlags
    38  	flag.Var(&files, "o", "load go object file")
    39  	var pkgpath = flag.String("p", "", "package path")
    40  	var parseFile = flag.String("parse", "", "parse go object file")
    41  	var run = flag.String("run", "main.main", "run function")
    42  	var times = flag.Int("times", 1, "run count")
    43  
    44  	flag.Parse()
    45  
    46  	if *parseFile != "" {
    47  		parse(*parseFile, *pkgpath)
    48  		return
    49  	}
    50  
    51  	if len(files.File) == 0 {
    52  		flag.PrintDefaults()
    53  		return
    54  	}
    55  
    56  	symPtr := make(map[string]uintptr)
    57  	err := goloader.RegSymbol(symPtr)
    58  	if err != nil {
    59  		fmt.Println(err)
    60  		return
    61  	}
    62  
    63  	// most of time you don't need to register function, but if loader complain about it, you have to.
    64  	w := sync.WaitGroup{}
    65  	str := make([]string, 0)
    66  	goloader.RegTypes(symPtr, http.ListenAndServe, http.Dir("/"),
    67  		http.Handler(http.FileServer(http.Dir("/"))), http.FileServer, http.HandleFunc,
    68  		&http.Request{}, &http.Server{}, (&http.ServeMux{}).Handle)
    69  	goloader.RegTypes(symPtr, runtime.LockOSThread, &w, w.Wait)
    70  	goloader.RegTypes(symPtr, fmt.Sprint, str)
    71  
    72  	linker, err := goloader.ReadObjs(files.File, files.PkgPath)
    73  	if err != nil {
    74  		fmt.Println(err)
    75  		return
    76  	}
    77  
    78  	var mmapByte []byte
    79  	for i := 0; i < *times; i++ {
    80  		codeModule, err := goloader.Load(linker, symPtr)
    81  		if err != nil {
    82  			fmt.Println("Load error:", err)
    83  			return
    84  		}
    85  		runFuncPtr := codeModule.Syms[*run]
    86  		if runFuncPtr == 0 {
    87  			fmt.Println("Load error! not find function:", *run)
    88  			return
    89  		}
    90  		funcPtrContainer := (uintptr)(unsafe.Pointer(&runFuncPtr))
    91  		runFunc := *(*func())(unsafe.Pointer(&funcPtrContainer))
    92  		runFunc()
    93  		os.Stdout.Sync()
    94  		codeModule.Unload()
    95  
    96  		// a strict test, try to make mmap random
    97  		if mmapByte == nil {
    98  			mmapByte, err = goloader.Mmap(1024)
    99  			if err != nil {
   100  				fmt.Println(err)
   101  			}
   102  			b := make([]byte, 1024)
   103  			copy(mmapByte, b) // reset all bytes
   104  		} else {
   105  			goloader.Munmap(mmapByte)
   106  			mmapByte = nil
   107  		}
   108  	}
   109  
   110  }
   111  
   112  func parse(file, pkgpath string) {
   113  	if file == "" {
   114  		flag.PrintDefaults()
   115  		return
   116  	}
   117  	obj, err := goloader.Parse(file, pkgpath)
   118  	fmt.Printf("%# v\n", obj)
   119  	if err != nil {
   120  		fmt.Printf("error reading %s: %v\n", file, err)
   121  		return
   122  	}
   123  }