golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/tunnel/service.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tunnel
     7  
     8  import (
     9  	"bytes"
    10  	"fmt"
    11  	"log"
    12  	"os"
    13  	"runtime"
    14  	"time"
    15  
    16  	"golang.org/x/sys/windows"
    17  	"golang.org/x/sys/windows/svc"
    18  	"golang.org/x/sys/windows/svc/mgr"
    19  	"golang.zx2c4.com/wireguard/windows/conf"
    20  	"golang.zx2c4.com/wireguard/windows/driver"
    21  	"golang.zx2c4.com/wireguard/windows/elevate"
    22  	"golang.zx2c4.com/wireguard/windows/ringlogger"
    23  	"golang.zx2c4.com/wireguard/windows/services"
    24  	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
    25  )
    26  
// tunnelService is the handler type that Run passes to svc.Run; its Execute
// method brings a single tunnel up and tears it down again.
type tunnelService struct {
	// Path is the configuration file path that Execute loads with
	// conf.LoadFromPath.
	Path string
}
    30  
    31  func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
    32  	serviceState := svc.StartPending
    33  	changes <- svc.Status{State: serviceState}
    34  
    35  	var watcher *interfaceWatcher
    36  	var adapter *driver.Adapter
    37  	var luid winipcfg.LUID
    38  	var config *conf.Config
    39  	var err error
    40  	serviceError := services.ErrorSuccess
    41  
    42  	defer func() {
    43  		svcSpecificEC, exitCode = services.DetermineErrorCode(err, serviceError)
    44  		logErr := services.CombineErrors(err, serviceError)
    45  		if logErr != nil {
    46  			log.Println(logErr)
    47  		}
    48  		serviceState = svc.StopPending
    49  		changes <- svc.Status{State: serviceState}
    50  
    51  		stopIt := make(chan bool, 1)
    52  		go func() {
    53  			t := time.NewTicker(time.Second * 30)
    54  			for {
    55  				select {
    56  				case <-t.C:
    57  					t.Stop()
    58  					buf := make([]byte, 1024)
    59  					for {
    60  						n := runtime.Stack(buf, true)
    61  						if n < len(buf) {
    62  							buf = buf[:n]
    63  							break
    64  						}
    65  						buf = make([]byte, 2*len(buf))
    66  					}
    67  					lines := bytes.Split(buf, []byte{'\n'})
    68  					log.Println("Failed to shutdown after 30 seconds. Probably dead locked. Printing stack and killing.")
    69  					for _, line := range lines {
    70  						if len(bytes.TrimSpace(line)) > 0 {
    71  							log.Println(string(line))
    72  						}
    73  					}
    74  					os.Exit(777)
    75  					return
    76  				case <-stopIt:
    77  					t.Stop()
    78  					return
    79  				}
    80  			}
    81  		}()
    82  
    83  		if logErr == nil && adapter != nil && config != nil {
    84  			logErr = runScriptCommand(config.Interface.PreDown, config.Name)
    85  		}
    86  		if watcher != nil {
    87  			watcher.Destroy()
    88  		}
    89  		if adapter != nil {
    90  			adapter.Close()
    91  		}
    92  		if logErr == nil && adapter != nil && config != nil {
    93  			_ = runScriptCommand(config.Interface.PostDown, config.Name)
    94  		}
    95  		stopIt <- true
    96  		log.Println("Shutting down")
    97  	}()
    98  
    99  	var logFile string
   100  	logFile, err = conf.LogFile(true)
   101  	if err != nil {
   102  		serviceError = services.ErrorRingloggerOpen
   103  		return
   104  	}
   105  	err = ringlogger.InitGlobalLogger(logFile, "TUN")
   106  	if err != nil {
   107  		serviceError = services.ErrorRingloggerOpen
   108  		return
   109  	}
   110  
   111  	config, err = conf.LoadFromPath(service.Path)
   112  	if err != nil {
   113  		serviceError = services.ErrorLoadConfiguration
   114  		return
   115  	}
   116  	config.DeduplicateNetworkEntries()
   117  
   118  	log.SetPrefix(fmt.Sprintf("[%s] ", config.Name))
   119  
   120  	services.PrintStarting()
   121  
   122  	if services.StartedAtBoot() {
   123  		if m, err := mgr.Connect(); err == nil {
   124  			if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked {
   125  				/* If we don't do this, then the driver installation will block forever, because
   126  				 * installing a network adapter starts the driver service too. Apparently at boot time,
   127  				 * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't
   128  				 * announce that we're running before starting additional services.
   129  				 */
   130  				log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner)
   131  				serviceState = svc.Running
   132  				changes <- svc.Status{State: serviceState}
   133  			}
   134  			m.Disconnect()
   135  		}
   136  	}
   137  
   138  	evaluateStaticPitfalls()
   139  
   140  	log.Println("Watching network interfaces")
   141  	watcher, err = watchInterface()
   142  	if err != nil {
   143  		serviceError = services.ErrorSetNetConfig
   144  		return
   145  	}
   146  
   147  	log.Println("Resolving DNS names")
   148  	err = config.ResolveEndpoints()
   149  	if err != nil {
   150  		serviceError = services.ErrorDNSLookup
   151  		return
   152  	}
   153  
   154  	log.Println("Creating network adapter")
   155  	for i := 0; i < 15; i++ {
   156  		if i > 0 {
   157  			time.Sleep(time.Second)
   158  			log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
   159  		}
   160  		adapter, err = driver.CreateAdapter(config.Name, "WireGuard", deterministicGUID(config))
   161  		if err == nil || !services.StartedAtBoot() {
   162  			break
   163  		}
   164  	}
   165  	if err != nil {
   166  		err = fmt.Errorf("Error creating adapter: %w", err)
   167  		serviceError = services.ErrorCreateNetworkAdapter
   168  		return
   169  	}
   170  	luid = adapter.LUID()
   171  	driverVersion, err := driver.RunningVersion()
   172  	if err != nil {
   173  		log.Printf("Warning: unable to determine driver version: %v", err)
   174  	} else {
   175  		log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff)
   176  	}
   177  	err = adapter.SetLogging(driver.AdapterLogOn)
   178  	if err != nil {
   179  		err = fmt.Errorf("Error enabling adapter logging: %w", err)
   180  		serviceError = services.ErrorCreateNetworkAdapter
   181  		return
   182  	}
   183  
   184  	err = runScriptCommand(config.Interface.PreUp, config.Name)
   185  	if err != nil {
   186  		serviceError = services.ErrorRunScript
   187  		return
   188  	}
   189  
   190  	err = enableFirewall(config, luid)
   191  	if err != nil {
   192  		serviceError = services.ErrorFirewall
   193  		return
   194  	}
   195  
   196  	log.Println("Dropping privileges")
   197  	err = elevate.DropAllPrivileges(true)
   198  	if err != nil {
   199  		serviceError = services.ErrorDropPrivileges
   200  		return
   201  	}
   202  
   203  	log.Println("Setting interface configuration")
   204  	err = adapter.SetConfiguration(config.ToDriverConfiguration())
   205  	if err != nil {
   206  		serviceError = services.ErrorDeviceSetConfig
   207  		return
   208  	}
   209  	err = adapter.SetAdapterState(driver.AdapterStateUp)
   210  	if err != nil {
   211  		serviceError = services.ErrorDeviceBringUp
   212  		return
   213  	}
   214  	watcher.Configure(adapter, config, luid)
   215  
   216  	err = runScriptCommand(config.Interface.PostUp, config.Name)
   217  	if err != nil {
   218  		serviceError = services.ErrorRunScript
   219  		return
   220  	}
   221  
   222  	changes <- svc.Status{State: serviceState, Accepts: svc.AcceptStop | svc.AcceptShutdown}
   223  
   224  	var started bool
   225  	for {
   226  		select {
   227  		case c := <-r:
   228  			switch c.Cmd {
   229  			case svc.Stop, svc.Shutdown:
   230  				return
   231  			case svc.Interrogate:
   232  				changes <- c.CurrentStatus
   233  			default:
   234  				log.Printf("Unexpected service control request #%d\n", c)
   235  			}
   236  		case <-watcher.started:
   237  			if !started {
   238  				serviceState = svc.Running
   239  				changes <- svc.Status{State: serviceState, Accepts: svc.AcceptStop | svc.AcceptShutdown}
   240  				log.Println("Startup complete")
   241  				started = true
   242  			}
   243  		case e := <-watcher.errors:
   244  			serviceError, err = e.serviceError, e.err
   245  			return
   246  		}
   247  	}
   248  }
   249  
   250  func Run(confPath string) error {
   251  	name, err := conf.NameFromPath(confPath)
   252  	if err != nil {
   253  		return err
   254  	}
   255  	serviceName, err := conf.ServiceNameOfTunnel(name)
   256  	if err != nil {
   257  		return err
   258  	}
   259  	return svc.Run(serviceName, &tunnelService{confPath})
   260  }