github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/cmd/yuhaiin/main_windows.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"strconv"
     7  	"time"
     8  
     9  	"github.com/Asutorufa/yuhaiin/internal/appapi"
    10  	"github.com/Asutorufa/yuhaiin/internal/version"
    11  	"github.com/Asutorufa/yuhaiin/pkg/log"
    12  	"golang.org/x/sys/windows"
    13  	"golang.org/x/sys/windows/registry"
    14  	"golang.org/x/sys/windows/svc"
    15  	"golang.org/x/sys/windows/svc/mgr"
    16  )
    17  
    18  func init() {
    19  	install = installSystemDaemonWindows
    20  	uninstall = uninstallSystemDaemonWindows
    21  	stop = stopService
    22  	start = startService
    23  
    24  	if !isWindowsService() {
    25  		return
    26  	}
    27  
    28  	log.OutputStderr = false
    29  	run = runService
    30  }
    31  
    32  func runService(app *appapi.Components, errChan chan error, signChannel chan os.Signal) {
    33  	svc.Run(version.AppName, &service{
    34  		app:         app,
    35  		errChan:     errChan,
    36  		signChannel: signChannel,
    37  	})
    38  }
    39  
// The service install/uninstall helpers below are adapted from https://github.com/tailscale/tailscale/blob/main/cmd/tailscaled/install_windows.go
    41  
    42  func installSystemDaemonWindows(args []string) (err error) {
    43  	m, err := mgr.Connect()
    44  	if err != nil {
    45  		return fmt.Errorf("failed to connect to Windows service manager: %v", err)
    46  	}
    47  	defer m.Disconnect()
    48  
    49  	service, err := m.OpenService(version.AppName)
    50  	if err == nil {
    51  		service.Close()
    52  		return fmt.Errorf("service %q is already installed", version.AppName)
    53  	}
    54  
    55  	// no such service; proceed to install the service.
    56  
    57  	exe, err := os.Executable()
    58  	if err != nil {
    59  		return err
    60  	}
    61  
    62  	c := mgr.Config{
    63  		ServiceType:  windows.SERVICE_WIN32_OWN_PROCESS,
    64  		StartType:    mgr.StartAutomatic,
    65  		ErrorControl: mgr.ErrorNormal,
    66  		DisplayName:  version.AppName,
    67  		Description:  "transparent proxy",
    68  	}
    69  
    70  	service, err = m.CreateService(version.AppName, exe, c, args...)
    71  	if err != nil {
    72  		return fmt.Errorf("failed to create %q service: %v", version.AppName, err)
    73  	}
    74  	defer service.Close()
    75  
    76  	// Exponential backoff is often too aggressive, so use (mostly)
    77  	// squares instead.
    78  	ra := []mgr.RecoveryAction{
    79  		{mgr.ServiceRestart, 1 * time.Second},
    80  		{mgr.ServiceRestart, 2 * time.Second},
    81  		{mgr.ServiceRestart, 4 * time.Second},
    82  		{mgr.ServiceRestart, 9 * time.Second},
    83  		{mgr.ServiceRestart, 16 * time.Second},
    84  		{mgr.ServiceRestart, 25 * time.Second},
    85  		{mgr.ServiceRestart, 36 * time.Second},
    86  		{mgr.ServiceRestart, 49 * time.Second},
    87  		{mgr.ServiceRestart, 64 * time.Second},
    88  	}
    89  	const resetPeriodSecs = 60
    90  	err = service.SetRecoveryActions(ra, resetPeriodSecs)
    91  	if err != nil {
    92  		return fmt.Errorf("failed to set service recovery actions: %v", err)
    93  	}
    94  
    95  	return service.Start(args...)
    96  }
    97  
    98  func uninstallSystemDaemonWindows(args []string) (ret error) {
    99  	// Remove file sharing from Windows shell (noop in non-windows)
   100  	// osshare.SetFileSharingEnabled(false, logger.Discard)
   101  
   102  	m, err := mgr.Connect()
   103  	if err != nil {
   104  		return fmt.Errorf("failed to connect to Windows service manager: %v", err)
   105  	}
   106  	defer m.Disconnect()
   107  
   108  	service, err := m.OpenService(version.AppName)
   109  	if err != nil {
   110  		return fmt.Errorf("failed to open %q service: %v", version.AppName, err)
   111  	}
   112  
   113  	st, err := service.Query()
   114  	if err != nil {
   115  		service.Close()
   116  		return fmt.Errorf("failed to query service state: %v", err)
   117  	}
   118  	if st.State != svc.Stopped {
   119  		service.Control(svc.Stop)
   120  	}
   121  	err = service.Delete()
   122  	service.Close()
   123  	if err != nil {
   124  		return fmt.Errorf("failed to delete service: %v", err)
   125  	}
   126  
   127  	end := time.Now().Add(15 * time.Second)
   128  	for time.Until(end) > 0 {
   129  		service, err = m.OpenService(version.AppName)
   130  		if err != nil {
   131  			// service is no longer openable; success!
   132  			break
   133  		}
   134  		service.Close()
   135  	}
   136  	return nil
   137  }
   138  
   139  func isWindowsService() bool {
   140  	ok, err := svc.IsWindowsService()
   141  	if err != nil {
   142  		log.Error("failed to check if we are running in Windows service", "err", err)
   143  		panic(err)
   144  	}
   145  
   146  	return ok
   147  }
   148  
   149  type service struct {
   150  	app         *appapi.Components
   151  	errChan     chan error
   152  	signChannel chan os.Signal
   153  }
   154  
   155  func (ss *service) Execute(args []string, r <-chan svc.ChangeRequest, s chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
   156  
   157  	s <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown}
   158  
   159  	for {
   160  		select {
   161  		case err := <-ss.errChan:
   162  			log.Error("http server stop", "err", err)
   163  			s <- svc.Status{State: svc.Stopped}
   164  			return false, windows.NO_ERROR
   165  		case <-ss.signChannel:
   166  			ss.app.HttpListener.Close()
   167  			s <- svc.Status{State: svc.Stopped}
   168  		case c := <-r:
   169  			log.Info("Got Windows Service event", "cmd", cmdName(c.Cmd))
   170  			switch c.Cmd {
   171  			case svc.Interrogate:
   172  				s <- c.CurrentStatus
   173  			case svc.Stop, svc.Shutdown:
   174  				ss.app.HttpListener.Close()
   175  			}
   176  		}
   177  	}
   178  
   179  }
   180  
   181  func cmdName(c svc.Cmd) string {
   182  	switch c {
   183  	case svc.Stop:
   184  		return "Stop"
   185  	case svc.Pause:
   186  		return "Pause"
   187  	case svc.Continue:
   188  		return "Continue"
   189  	case svc.Interrogate:
   190  		return "Interrogate"
   191  	case svc.Shutdown:
   192  		return "Shutdown"
   193  	case svc.ParamChange:
   194  		return "ParamChange"
   195  	case svc.NetBindAdd:
   196  		return "NetBindAdd"
   197  	case svc.NetBindRemove:
   198  		return "NetBindRemove"
   199  	case svc.NetBindEnable:
   200  		return "NetBindEnable"
   201  	case svc.NetBindDisable:
   202  		return "NetBindDisable"
   203  	case svc.DeviceEvent:
   204  		return "DeviceEvent"
   205  	case svc.HardwareProfileChange:
   206  		return "HardwareProfileChange"
   207  	case svc.PowerEvent:
   208  		return "PowerEvent"
   209  	case svc.SessionChange:
   210  		return "SessionChange"
   211  	case svc.PreShutdown:
   212  		return "PreShutdown"
   213  	}
   214  	return fmt.Sprintf("Unknown-Service-Cmd-%d", c)
   215  }
   216  
   217  func stopService(args []string) error {
   218  	m, err := mgr.Connect()
   219  	if err != nil {
   220  		return err
   221  	}
   222  	defer m.Disconnect()
   223  
   224  	s, err := m.OpenService(version.AppName)
   225  	if err != nil {
   226  		return err
   227  	}
   228  	defer s.Close()
   229  
   230  	status, err := s.Control(svc.Stop)
   231  	if err != nil {
   232  		return err
   233  	}
   234  
   235  	timeDuration := time.Millisecond * 50
   236  
   237  	timeout := time.After(getStopTimeout() + (timeDuration * 2))
   238  	tick := time.NewTicker(timeDuration)
   239  	defer tick.Stop()
   240  
   241  	for status.State != svc.Stopped {
   242  		select {
   243  		case <-tick.C:
   244  			status, err = s.Query()
   245  			if err != nil {
   246  				return err
   247  			}
   248  		case <-timeout:
   249  			break
   250  		}
   251  	}
   252  
   253  	return nil
   254  }
   255  
   256  func startService(args []string) error {
   257  	m, err := mgr.Connect()
   258  	if err != nil {
   259  		return err
   260  	}
   261  	defer m.Disconnect()
   262  
   263  	s, err := m.OpenService(version.AppName)
   264  	if err != nil {
   265  		return err
   266  	}
   267  	defer s.Close()
   268  
   269  	err = s.Start()
   270  	if err != nil {
   271  		return err
   272  	}
   273  
   274  	return nil
   275  }
   276  
   277  // getStopTimeout fetches the time before windows will kill the service.
   278  func getStopTimeout() time.Duration {
   279  	// For default and paths see https://support.microsoft.com/en-us/kb/146092
   280  	defaultTimeout := time.Millisecond * 20000
   281  	key, err := registry.OpenKey(registry.LOCAL_MACHINE, `SYSTEM\CurrentControlSet\Control`, registry.READ)
   282  	if err != nil {
   283  		return defaultTimeout
   284  	}
   285  	sv, _, err := key.GetStringValue("WaitToKillServiceTimeout")
   286  	if err != nil {
   287  		return defaultTimeout
   288  	}
   289  	v, err := strconv.Atoi(sv)
   290  	if err != nil {
   291  		return defaultTimeout
   292  	}
   293  	return time.Millisecond * time.Duration(v)
   294  }