gitee.com/aurawing/surguard-go@v0.3.1-0.20240409071558-96509a61ecf3/main.go (about)

     1  //go:build !windows
     2  
     3  /* SPDX-License-Identifier: MIT
     4   *
     5   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     6   */
     7  
     8  package main
     9  
    10  import (
    11  	"fmt"
    12  	"os"
    13  	"os/signal"
    14  	"runtime"
    15  	"strconv"
    16  
    17  	"gitee.com/aurawing/surguard-go/conn"
    18  	"gitee.com/aurawing/surguard-go/device"
    19  	"gitee.com/aurawing/surguard-go/ipc"
    20  	"gitee.com/aurawing/surguard-go/tun"
    21  	"golang.org/x/sys/unix"
    22  )
    23  
    24  const (
    25  	ExitSetupSuccess = 0
    26  	ExitSetupFailed  = 1
    27  )
    28  
    29  const (
    30  	ENV_SG_TUN_FD             = "SG_TUN_FD"
    31  	ENV_SG_UAPI_FD            = "SG_UAPI_FD"
    32  	ENV_SG_PROCESS_FOREGROUND = "SG_PROCESS_FOREGROUND"
    33  	ENV_SG_LEGACY_MODE        = "SG_LEGACY_MODE"
    34  )
    35  
    36  func printUsage() {
    37  	fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
    38  }
    39  
    40  func warning() {
    41  	switch runtime.GOOS {
    42  	case "linux", "freebsd", "openbsd":
    43  		if os.Getenv(ENV_SG_PROCESS_FOREGROUND) == "1" {
    44  			return
    45  		}
    46  	default:
    47  		return
    48  	}
    49  
    50  	fmt.Fprintln(os.Stdin, "surguard is running in background")
    51  
    52  	// fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
    53  	// fmt.Fprintln(os.Stderr, "│                                                      │")
    54  	// fmt.Fprintln(os.Stderr, "│   Running wireguard-go is not required because this  │")
    55  	// fmt.Fprintln(os.Stderr, "│   kernel has first class support for WireGuard. For  │")
    56  	// fmt.Fprintln(os.Stderr, "│   information on installing the kernel module,       │")
    57  	// fmt.Fprintln(os.Stderr, "│   please visit:                                      │")
    58  	// fmt.Fprintln(os.Stderr, "│         https://www.wireguard.com/install/           │")
    59  	// fmt.Fprintln(os.Stderr, "│                                                      │")
    60  	// fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
    61  }
    62  
    63  func main() {
    64  	if len(os.Args) == 2 && os.Args[1] == "--version" {
    65  		fmt.Printf("surguard v%s\n\nUserspace SurGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", Version, runtime.GOOS, runtime.GOARCH)
    66  		return
    67  	}
    68  
    69  	warning()
    70  
    71  	var foreground bool
    72  	var interfaceName string
    73  	if len(os.Args) < 2 || len(os.Args) > 3 {
    74  		printUsage()
    75  		return
    76  	}
    77  
    78  	switch os.Args[1] {
    79  
    80  	case "-f", "--foreground":
    81  		foreground = true
    82  		if len(os.Args) != 3 {
    83  			printUsage()
    84  			return
    85  		}
    86  		interfaceName = os.Args[2]
    87  
    88  	default:
    89  		foreground = false
    90  		if len(os.Args) != 2 {
    91  			printUsage()
    92  			return
    93  		}
    94  		interfaceName = os.Args[1]
    95  	}
    96  
    97  	if !foreground {
    98  		foreground = os.Getenv(ENV_SG_PROCESS_FOREGROUND) == "1"
    99  	}
   100  
   101  	// get log level (default: info)
   102  
   103  	logLevel := func() int {
   104  		switch os.Getenv("LOG_LEVEL") {
   105  		case "verbose", "debug":
   106  			return device.LogLevelVerbose
   107  		case "error":
   108  			return device.LogLevelError
   109  		case "silent":
   110  			return device.LogLevelSilent
   111  		}
   112  		return device.LogLevelError
   113  	}()
   114  
   115  	// open TUN device (or use supplied fd)
   116  
   117  	tdev, err := func() (tun.Device, error) {
   118  		tunFdStr := os.Getenv(ENV_SG_TUN_FD)
   119  		if tunFdStr == "" {
   120  			return tun.CreateTUN(interfaceName, device.DefaultMTU)
   121  		}
   122  
   123  		// construct tun device from supplied fd
   124  
   125  		fd, err := strconv.ParseUint(tunFdStr, 10, 32)
   126  		if err != nil {
   127  			return nil, err
   128  		}
   129  
   130  		err = unix.SetNonblock(int(fd), true)
   131  		if err != nil {
   132  			return nil, err
   133  		}
   134  
   135  		file := os.NewFile(uintptr(fd), "")
   136  		return tun.CreateTUNFromFile(file, device.DefaultMTU)
   137  	}()
   138  
   139  	if err == nil {
   140  		realInterfaceName, err2 := tdev.Name()
   141  		if err2 == nil {
   142  			interfaceName = realInterfaceName
   143  		}
   144  	}
   145  
   146  	logger := device.NewLogger(
   147  		logLevel,
   148  		fmt.Sprintf("(%s) ", interfaceName),
   149  	)
   150  
   151  	logger.Verbosef("Starting surguard version %s", Version)
   152  
   153  	if err != nil {
   154  		logger.Errorf("Failed to create TUN device: %v", err)
   155  		os.Exit(ExitSetupFailed)
   156  	}
   157  
   158  	// open UAPI file (or use supplied fd)
   159  
   160  	fileUAPI, err := func() (*os.File, error) {
   161  		uapiFdStr := os.Getenv(ENV_SG_UAPI_FD)
   162  		if uapiFdStr == "" {
   163  			return ipc.UAPIOpen(interfaceName)
   164  		}
   165  
   166  		// use supplied fd
   167  
   168  		fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
   169  		if err != nil {
   170  			return nil, err
   171  		}
   172  
   173  		return os.NewFile(uintptr(fd), ""), nil
   174  	}()
   175  	if err != nil {
   176  		logger.Errorf("UAPI listen error: %v", err)
   177  		os.Exit(ExitSetupFailed)
   178  		return
   179  	}
   180  	// daemonize the process
   181  
   182  	if !foreground {
   183  		env := os.Environ()
   184  		env = append(env, fmt.Sprintf("%s=3", ENV_SG_TUN_FD))
   185  		env = append(env, fmt.Sprintf("%s=4", ENV_SG_UAPI_FD))
   186  		env = append(env, fmt.Sprintf("%s=1", ENV_SG_PROCESS_FOREGROUND))
   187  		files := [3]*os.File{}
   188  		if os.Getenv("LOG_LEVEL") != "" && logLevel != device.LogLevelSilent {
   189  			files[0], _ = os.Open(os.DevNull)
   190  			files[1] = os.Stdout
   191  			files[2] = os.Stderr
   192  		} else {
   193  			files[0], _ = os.Open(os.DevNull)
   194  			files[1], _ = os.Open(os.DevNull)
   195  			files[2], _ = os.Open(os.DevNull)
   196  		}
   197  		attr := &os.ProcAttr{
   198  			Files: []*os.File{
   199  				files[0], // stdin
   200  				files[1], // stdout
   201  				files[2], // stderr
   202  				tdev.File(),
   203  				fileUAPI,
   204  			},
   205  			Dir: ".",
   206  			Env: env,
   207  		}
   208  
   209  		path, err := os.Executable()
   210  		if err != nil {
   211  			logger.Errorf("Failed to determine executable: %v", err)
   212  			os.Exit(ExitSetupFailed)
   213  		}
   214  
   215  		process, err := os.StartProcess(
   216  			path,
   217  			os.Args,
   218  			attr,
   219  		)
   220  		if err != nil {
   221  			logger.Errorf("Failed to daemonize: %v", err)
   222  			os.Exit(ExitSetupFailed)
   223  		}
   224  		process.Release()
   225  		return
   226  	}
   227  
   228  	dev := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
   229  
   230  	logger.Verbosef("Device started")
   231  
   232  	errs := make(chan error)
   233  	term := make(chan os.Signal, 1)
   234  
   235  	uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)
   236  	if err != nil {
   237  		logger.Errorf("Failed to listen on uapi socket: %v", err)
   238  		os.Exit(ExitSetupFailed)
   239  	}
   240  
   241  	go func() {
   242  		for {
   243  			conn, err := uapi.Accept()
   244  			if err != nil {
   245  				errs <- err
   246  				return
   247  			}
   248  			go dev.IpcHandle(conn)
   249  		}
   250  	}()
   251  
   252  	logger.Verbosef("UAPI listener started")
   253  	legacymode := os.Getenv(ENV_SG_LEGACY_MODE)
   254  	statPortStr := os.Getenv(device.ENV_SG_LISTEN_PORT)
   255  	if statPortStr == "" {
   256  		statPortStr = strconv.FormatInt(device.DEFAULT_LISTEN_PORT, 10)
   257  	}
   258  	if legacymode != "true" {
   259  		err = startStatServer(interfaceName, statPortStr, logger, dev)
   260  		if err != nil {
   261  			logger.Errorf("start stat server failed: %v", err)
   262  			os.Exit(ExitSetupFailed)
   263  		}
   264  		dev.PostConfig()
   265  	}
   266  
   267  	// wait for program to terminate
   268  
   269  	signal.Notify(term, unix.SIGTERM)
   270  	signal.Notify(term, os.Interrupt)
   271  
   272  	select {
   273  	case <-term:
   274  	case <-errs:
   275  	case <-dev.Wait():
   276  	}
   277  
   278  	// clean up
   279  	if legacymode != "true" {
   280  		dev.ClearConfig()
   281  		stopStatServer(logger)
   282  	}
   283  	uapi.Close()
   284  	dev.Close()
   285  
   286  	logger.Verbosef("Shutting down")
   287  }