gitee.com/aurawing/surguard-go@v0.3.1-0.20240409071558-96509a61ecf3/device/postconfig.go (about)

     1  package device
     2  
     3  import (
     4  	"encoding/hex"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"net/netip"
     9  	"os"
    10  	"sort"
    11  	"strconv"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/jackpal/gateway"
    16  )
    17  
    18  const (
    19  	ENV_SG_GROUP_NAME          = "SG_GROUP_NAME"
    20  	ENV_SG_CONFIG_ROOT         = "SG_CONFIG_ROOT"
    21  	ENV_SG_KMS_URL             = "SG_KMS_URL"
    22  	ENV_SG_ZK_URL              = "SG_ZK_URL"
    23  	ENV_SG_DEFAULT_IFACE       = "SG_DEFAULT_IFACE"
    24  	ENV_SG_LISTEN_PORT         = "SG_LISTEN_PORT"
    25  	ENV_SG_KEEPALIVE_INTERVAL  = "SG_KEEPALIVE_INTERVAL"
    26  	ENV_SG_MONITOR_IP_INTERVAL = "SG_MONITOR_IP_INTERVAL"
    27  	ENV_SG_PRIVATEKEY          = "SG_PRIVATEKEY"
    28  	//ENV_SG_IF_BIND_INTERFACE   = "SG_IF_BIND_INTERFACE"
    29  )
    30  
    31  const DEFAULT_GROUP_NAME = "SURGUARD"
    32  const DEFAULT_CONFIG_ROOT = "/etc/surguard"
    33  const DEFAULT_LISTEN_PORT = 50001
    34  const DEFAULT_KEEPALIVE_INTERVAL = 30
    35  const DEFAULT_MONITOR_IP_INTERVAL = 5
    36  
    37  var groupName string
    38  var zkCli *ZkDiscovery
    39  var interfaceIP string
    40  var interfaceIPArr [][4]byte
    41  var sgIPArr [][4]byte
    42  var interfaceIndex byte
    43  var finishCh chan chan bool
    44  var keepaliveInterval int
    45  var ifBindInterface bool
    46  
    47  func IfBindInterface() bool {
    48  	return ifBindInterface
    49  }
    50  
    51  func (device *Device) PostConfig() {
    52  	//parse environment variables
    53  	groupNameStr := os.Getenv(ENV_SG_GROUP_NAME)
    54  	if groupNameStr == "" {
    55  		groupName = DEFAULT_GROUP_NAME
    56  	} else {
    57  		groupName = groupNameStr
    58  	}
    59  	configRoot := os.Getenv(ENV_SG_CONFIG_ROOT)
    60  	if configRoot == "" {
    61  		configRoot = DEFAULT_CONFIG_ROOT
    62  	}
    63  	kmsURL := os.Getenv(ENV_SG_KMS_URL)
    64  	if kmsURL == "" {
    65  		device.log.Errorf("PostConfig: KMS URL is empty")
    66  		panic(errors.New("invalid KMS URL"))
    67  	}
    68  	zkUrl := os.Getenv(ENV_SG_ZK_URL)
    69  	if zkUrl == "" {
    70  		device.log.Errorf("PostConfig: zookeeper URL is invalid: %s", zkUrl)
    71  		panic(errors.New("invalid zookeeper URL"))
    72  	}
    73  	keepaliveIntervalStr := os.Getenv(ENV_SG_KEEPALIVE_INTERVAL)
    74  	if keepaliveIntervalStr != "" {
    75  		var err error
    76  		keepaliveInterval, err = strconv.Atoi(keepaliveIntervalStr)
    77  		if err != nil {
    78  			device.log.Errorf("PostConfig: parsing keepalive interval failed: %s", err)
    79  			keepaliveInterval = DEFAULT_KEEPALIVE_INTERVAL
    80  		}
    81  	} else {
    82  		keepaliveInterval = DEFAULT_KEEPALIVE_INTERVAL
    83  	}
    84  
    85  	monitorIPInterval := DEFAULT_MONITOR_IP_INTERVAL
    86  	monitorIPIntervalStr := os.Getenv(ENV_SG_MONITOR_IP_INTERVAL)
    87  	if monitorIPIntervalStr != "" {
    88  		var err error
    89  		monitorIPInterval, err = strconv.Atoi(monitorIPIntervalStr)
    90  		if err != nil {
    91  			device.log.Errorf("PostConfig: parse monitor IP interval failed: %s", err)
    92  			monitorIPInterval = DEFAULT_MONITOR_IP_INTERVAL
    93  		}
    94  	}
    95  
    96  	// ifBindInterfaceStr := os.Getenv(ENV_SG_IF_BIND_INTERFACE)
    97  	// if ifBindInterfaceStr == "" {
    98  	// 	ifBindInterface = false
    99  	// } else {
   100  	// 	var err error
   101  	// 	ifBindInterface, err = strconv.ParseBool(ifBindInterfaceStr)
   102  	// 	if err != nil {
   103  	// 		device.log.Errorf("PostConfig: ifBindInterface is invalid: %s -> %s", ifBindInterfaceStr, err)
   104  	// 		panic(err)
   105  	// 	}
   106  	// }
   107  
   108  	var err error
   109  	finishCh = make(chan chan bool)
   110  
   111  	// ipaddrs, err := device.getIP()
   112  	// if err != nil {
   113  	// 	device.log.Errorf("PostConfig: failed to get IP address: %s", err)
   114  	// 	panic(err)
   115  	// }
   116  	// interfaceIP = strings.Join(ipaddrs, ",")
   117  	// interfaceIPArr = make([][4]byte, 0)
   118  	// for _, ipaddr := range ipaddrs {
   119  	// 	tmp := net.ParseIP(ipaddr)
   120  	// 	interfaceIPArr[0] = tmp[12]
   121  	// 	interfaceIPArr[1] = tmp[13]
   122  	// 	interfaceIPArr[2] = tmp[14]
   123  	// 	interfaceIPArr[3] = tmp[15]
   124  	// }
   125  
   126  	err = device.configTunDevice()
   127  	if err != nil {
   128  		device.log.Errorf("PostConfig: config TUN device failed: %s", err)
   129  		panic(err)
   130  	}
   131  	sk, err := device.loginKMS(kmsURL, configRoot)
   132  	if err != nil {
   133  		device.log.Errorf("PostConfig: login KMS failed: %s", err)
   134  		panic(err)
   135  	}
   136  	err = device.setSKandPort(sk)
   137  	if err != nil {
   138  		device.log.Errorf("PostConfig: set private key and listen port failed: %s", err)
   139  		panic(err)
   140  	}
   141  	err = device.initRules()
   142  	if err != nil {
   143  		device.log.Errorf("PostConfig: initialize rules failed: %s", err)
   144  		panic(err)
   145  	}
   146  	// interfaceIPStr, err := getIP()
   147  	// if err != nil {
   148  	// 	panic(err)
   149  	// }
   150  	// interfaceIP = interfaceIPStr
   151  
   152  	zkCli, err = CreateZkDiscovery(zkUrl, device, fmt.Sprintf("/%s", groupName))
   153  	if err != nil {
   154  		device.log.Errorf("PostConfig: create zookeeper client failed: %s", err)
   155  		panic(err)
   156  	}
   157  	device.monitorIP(true)
   158  	go func() {
   159  		var ch chan bool
   160  	OUTER:
   161  		for {
   162  			select {
   163  			case ch = <-finishCh:
   164  				break OUTER
   165  			case <-time.After(time.Second * time.Duration(monitorIPInterval)):
   166  				device.monitorIP(false)
   167  			}
   168  		}
   169  		ch <- true
   170  	}()
   171  }
   172  
   173  func (device *Device) monitorIP(init bool) {
   174  	zkCli.Lock()
   175  	pathExist, err := zkCli.ExistPeer(device.staticIdentity.publicKey)
   176  	if err != nil {
   177  		device.log.Errorf("PostConfig: error when get path from zk: %s", err)
   178  	}
   179  	if !pathExist && interfaceIP != "" && !init {
   180  		ifAddrs := strings.Split(interfaceIP, ",")
   181  		rpAddrs := make([]string, 0)
   182  		for _, rpAddr := range ifAddrs {
   183  			rpAddrs = append(rpAddrs, fmt.Sprintf("%s:%d", rpAddr, device.net.port))
   184  		}
   185  		zkCli.AddPeer(device.staticIdentity.publicKey, strings.Join(rpAddrs, ","))
   186  	}
   187  	zkCli.Unlock()
   188  	//TODO: 检查zookeeper上的IP地址是否为空
   189  	data := ""
   190  	data, _ = zkCli.GetData(device.staticIdentity.publicKey)
   191  	ipaddrs, err := device.getIP()
   192  	if err != nil {
   193  		device.log.Errorf("PostConfig: error when get interface IP: %s", err)
   194  	} else {
   195  		sort.Strings(ipaddrs)
   196  		if data == "" || strings.Join(ipaddrs, ",") != interfaceIP || (ifBindInterface && init) {
   197  			zkCli.Lock()
   198  			if data == "" || !(interfaceIP == "" || (ifBindInterface && init)) {
   199  				zkCli.RemovePeer(device.staticIdentity.publicKey)
   200  			}
   201  			reportAddrs := make([]string, 0)
   202  			for _, reportAddr := range ipaddrs {
   203  				reportAddrs = append(reportAddrs, fmt.Sprintf("%s:%d", reportAddr, device.net.port))
   204  			}
   205  			zkCli.AddPeer(device.staticIdentity.publicKey, strings.Join(reportAddrs, ","))
   206  			interfaceIP = strings.Join(ipaddrs, ",")
   207  			// tmp := net.ParseIP(ipaddrs[0])
   208  			// interfaceIPArr[0] = tmp[12]
   209  			// interfaceIPArr[1] = tmp[13]
   210  			// interfaceIPArr[2] = tmp[14]
   211  			// interfaceIPArr[3] = tmp[15]
   212  			if ifBindInterface {
   213  				oldsgLen := len(sgIPArr)
   214  				interfaceIPArr = make([][4]byte, 0)
   215  				sgIPArr = make([][4]byte, 0)
   216  
   217  				for i, ipaddr := range ipaddrs {
   218  					tmp := net.ParseIP(ipaddr)
   219  					tmp2 := [4]byte{tmp[12], tmp[13], tmp[14], tmp[15]}
   220  					interfaceIPArr = append(interfaceIPArr, tmp2)
   221  					sgIPArr = append(sgIPArr, [4]byte{169, 254, interfaceIndex, byte(i + 1)})
   222  				}
   223  				if oldsgLen < len(sgIPArr) {
   224  					for i := oldsgLen + 1; i <= len(sgIPArr); i++ {
   225  						staticip := fmt.Sprintf("169.254.%d.%d", interfaceIndex, i)
   226  						device.runCmd(false, "netsh", "interface", "ipv4", "add", "address", "name=\""+deviceName+"\"", staticip, "255.255.255.0")
   227  						device.log.Verbosef("configTunDevice: add address %s/24 for device %s\n", staticip, deviceName)
   228  					}
   229  				} else if oldsgLen > len(sgIPArr) {
   230  					for i := oldsgLen; i > len(sgIPArr); i-- {
   231  						staticip := fmt.Sprintf("169.254.%d.%d", interfaceIndex, i)
   232  						device.runCmd(false, "netsh", "interface", "ipv4", "delete", "address", "name=\""+deviceName+"\"", staticip)
   233  						device.log.Verbosef("configTunDevice: delete address %s/24 for device %s\n", staticip, deviceName)
   234  					}
   235  				}
   236  			}
   237  
   238  			zkCli.Unlock()
   239  			if ifBindInterface {
   240  				device.net.Lock()
   241  				device.net.ipv4Addr = ipaddrs[0]
   242  				device.net.Unlock()
   243  				if err := device.BindUpdate(); err != nil {
   244  					device.log.Errorf("PostConfig: failed to set ip address %s: %s", ipaddrs[0], err)
   245  				}
   246  			}
   247  		}
   248  	}
   249  }
   250  
   251  // login to KMS and get private key
   252  func (device *Device) loginKMS(kmsURL, configRoot string) (NoisePrivateKey, error) {
   253  	skstr := os.Getenv(ENV_SG_PRIVATEKEY)
   254  	if skstr != "" {
   255  		return device.parsePrivateKey(skstr)
   256  	}
   257  	yidRoot := fmt.Sprintf("%s/%s", configRoot, groupName)
   258  	err := os.MkdirAll(yidRoot, os.ModePerm)
   259  	if err != nil {
   260  		device.log.Errorf("PostConfig: create config path %s failed: %s", yidRoot, err)
   261  		return NoisePrivateKey{}, err
   262  	}
   263  	yidPath := fmt.Sprintf("%s/%s", yidRoot, "kms.conf")
   264  
   265  	//login and get private key
   266  	var vui, yid, yentity string
   267  	ystr, err := os.ReadFile(yidPath)
   268  	if err != nil {
   269  		//create entity
   270  		vuin, yidn, yentityn, err := device.createEntity(kmsURL)
   271  		if err != nil {
   272  			device.log.Errorf("PostConfig: create entity failed: %s", err)
   273  			return NoisePrivateKey{}, err
   274  		}
   275  		err = os.WriteFile(yidPath, []byte(fmt.Sprintf("%s|%s|%s", vuin, yidn, yentityn)), 0666)
   276  		if err != nil {
   277  			device.log.Errorf("PostConfig: write YID and entity to file failed: %s", err)
   278  			return NoisePrivateKey{}, err
   279  		}
   280  		vui = vuin
   281  		yid = yidn
   282  		yentity = yentityn
   283  	} else {
   284  		ystrs := strings.Split(string(ystr), "|")
   285  		if len(ystrs) != 3 {
   286  			err := fmt.Errorf("format of config file is corrupt: %s", yidPath)
   287  			device.log.Errorf("PostConfig: %s", err)
   288  			return NoisePrivateKey{}, err
   289  		}
   290  		vui = strings.TrimSpace(ystrs[0])
   291  		yid = strings.TrimSpace(ystrs[1])
   292  		yentity = strings.TrimSpace(ystrs[2])
   293  	}
   294  	err = device.loginEntity(kmsURL, vui, yid)
   295  	if err != nil {
   296  		device.log.Errorf("PostConfig: login entity failed: %s", err)
   297  		return NoisePrivateKey{}, err
   298  	}
   299  	yentityid, err := strconv.Atoi(yentity)
   300  	if err != nil {
   301  		device.log.Errorf("PostConfig: parse entity to number failed: %s", err)
   302  		return NoisePrivateKey{}, err
   303  	}
   304  	skstr, err = device.getPrivateKey(kmsURL, yentityid)
   305  	if err != nil {
   306  		device.log.Errorf("PostConfig: get private key from KMS failed: %s", err)
   307  		return NoisePrivateKey{}, err
   308  	}
   309  	return device.parsePrivateKey(skstr)
   310  }
   311  
   312  func (device *Device) getIP() ([]string, error) {
   313  	interfaceName := os.Getenv(ENV_SG_DEFAULT_IFACE)
   314  	if interfaceName == "" {
   315  		ip, err := gateway.DiscoverInterface()
   316  		if err != nil {
   317  			device.log.Errorf("PostConfig: error when discovery interface IP: %s", err)
   318  			return nil, err
   319  		}
   320  		if ip.To4() != nil {
   321  			return []string{ip.String()}, nil
   322  		} else {
   323  			err := errors.New("no valid IPv4 address")
   324  			device.log.Errorf("PostConfig: error when convert interface IP: %s", err)
   325  			return nil, err
   326  		}
   327  	}
   328  
   329  	iface, err := net.InterfaceByName(interfaceName)
   330  	if err != nil {
   331  		device.log.Errorf("PostConfig: error when get interface IP by name: %s", err)
   332  		return nil, err
   333  	}
   334  
   335  	addrs, err := iface.Addrs()
   336  	if err != nil {
   337  		device.log.Errorf("PostConfig: error when get IP address of interface: %s", err)
   338  		return nil, err
   339  	}
   340  	addrstrs := make([]string, 0)
   341  	for _, addr := range addrs {
   342  		ipNet, ok := addr.(*net.IPNet)
   343  		if ok && !ipNet.IP.IsLoopback() {
   344  			if ipNet.IP.To4() != nil {
   345  				addrstrs = append(addrstrs, ipNet.IP.String())
   346  				//return ipNet.IP.String(), nil
   347  			}
   348  		}
   349  	}
   350  	//TODO:增加指定IP地址功能
   351  	if len(addrstrs) == 0 {
   352  		err = errors.New("no valid IPv4 address")
   353  		device.log.Errorf("PostConfig: error when get IP address of interface: %s", err)
   354  		return nil, err
   355  	}
   356  	sort.Strings(addrstrs)
   357  	return addrstrs, nil
   358  }
   359  
   360  // set up private key and UDP listen port
   361  func (device *Device) setSKandPort(sk NoisePrivateKey) error {
   362  	listenPort := DEFAULT_LISTEN_PORT
   363  	portStr := os.Getenv(ENV_SG_LISTEN_PORT)
   364  	var err error
   365  	if portStr != "" {
   366  		listenPort, err = strconv.Atoi(portStr)
   367  		if err != nil {
   368  			device.log.Errorf("PostConfig: failed to parse listen port %s: %s", portStr, err)
   369  		}
   370  	}
   371  	//ischanged := false
   372  	device.ipcMutex.Lock()
   373  	//oldPK := device.staticIdentity.publicKey
   374  	defer device.ipcMutex.Unlock()
   375  	if !device.staticIdentity.privateKey.Equals(sk) {
   376  		device.SetPrivateKey(sk)
   377  	}
   378  	// ipv4 := ""
   379  	// ipv4s, err := device.getIP()
   380  	// if err != nil {
   381  	// 	device.log.Errorf("PostConfig: failed to get IP address: %s", err)
   382  	// } else if len(ipv4s) > 0 {
   383  	// 	ipv4 = ipv4s[0]
   384  	// }
   385  	var ipv4 string
   386  	if ifBindInterface && len(interfaceIPArr) > 0 {
   387  		ipv4 = fmt.Sprintf("%d.%d.%d.%d", interfaceIPArr[0][0], interfaceIPArr[0][1], interfaceIPArr[0][2], interfaceIPArr[0][3])
   388  	} else {
   389  		ipv4s, err := device.getIP()
   390  		if err != nil {
   391  			device.log.Errorf("PostConfig: failed to get IP address: %s", err)
   392  		} else if len(ipv4s) > 0 {
   393  			sort.Strings(ipv4s)
   394  			ipv4 = ipv4s[0]
   395  		}
   396  	}
   397  	if device.net.port != uint16(listenPort) {
   398  		device.net.Lock()
   399  		if ifBindInterface && ipv4 != "" {
   400  			device.net.ipv4Addr = ipv4
   401  		}
   402  		device.net.port = uint16(listenPort)
   403  		device.net.Unlock()
   404  		if err := device.BindUpdate(); err != nil {
   405  			device.log.Errorf("PostConfig: failed to set listen_port %d: %s", listenPort, err)
   406  			return err
   407  		}
   408  		//ischanged = true
   409  	}
   410  	// if ischanged {
   411  	// 	//TODO: update zk
   412  	// 	zkCli.Lock()
   413  	// 	zkCli.RemovePeer(oldPK)
   414  	// 	zkCli.AddPeer(device.staticIdentity.publicKey)
   415  	// 	zkCli.Unlock()
   416  	// }
   417  	return nil
   418  }
   419  
   420  func (device *Device) AddPeer(pk NoisePublicKey, endPoints string) error {
   421  	endPointArr := strings.Split(endPoints, ",")
   422  	if len(endPointArr) == 0 {
   423  		device.log.Verbosef("PostConfig: no endpoints under public key %s", hex.EncodeToString(pk[:]))
   424  		return errors.New("no endpoints")
   425  	}
   426  	device.ipcMutex.Lock()
   427  	defer device.ipcMutex.Unlock()
   428  	peer := device.LookupPeer(pk)
   429  	if peer != nil {
   430  		//err := errors.New("peer exists")
   431  		device.log.Verbosef("PostConfig: peer %v exists", peer)
   432  		return nil
   433  	}
   434  	peer, err := device.NewPeer(pk)
   435  	if err != nil {
   436  		device.log.Errorf("PostConfig: create peer failed: %s", err)
   437  		return err
   438  	}
   439  	device.log.Verbosef("%v - PostConfig: Created", peer)
   440  	endpoint, err := device.net.bind.ParseEndpoint(endPointArr[0])
   441  	if err != nil {
   442  		device.log.Errorf("%v - PostConfig: parse endpoint failed: %s", peer, err)
   443  		return err
   444  	}
   445  	peer.endpoint.Lock()
   446  	peer.endpoint.val = endpoint
   447  	peer.endpoint.Unlock()
   448  	device.log.Verbosef("%v - PostConfig: Updating endpoint %s", peer, endPointArr[0])
   449  
   450  	device.allowedips.RemoveByPeer(peer)
   451  	for _, endPoint := range endPointArr {
   452  		allowedIP := fmt.Sprintf("%s/32", strings.Split(endPoint, ":")[0])
   453  		prefix, err := netip.ParsePrefix(allowedIP)
   454  		if err != nil {
   455  			device.log.Errorf("%v - PostConfig: parse allowedIP failed: %s", peer, err)
   456  			return err
   457  		}
   458  		device.allowedips.Insert(prefix, peer)
   459  		device.log.Verbosef("%v - PostConfig: Adding allowedip: %s", peer, allowedIP)
   460  	}
   461  
   462  	old := peer.persistentKeepaliveInterval.Swap(uint32(keepaliveInterval))
   463  	device.log.Verbosef("%v - PostConfig: Updating persistent keepalive interval %d", peer, keepaliveInterval)
   464  
   465  	ipcPeer := new(ipcSetPeer)
   466  	ipcPeer.Peer = peer
   467  	ipcPeer.dummy = false
   468  	ipcPeer.created = true
   469  	ipcPeer.pkaOn = old == 0 && keepaliveInterval != 0
   470  
   471  	if ipcPeer.created {
   472  		ipcPeer.endpoint.disableRoaming = ipcPeer.device.net.brokenRoaming && ipcPeer.endpoint.val != nil
   473  	}
   474  	if ipcPeer.device.isUp() {
   475  		ipcPeer.Start()
   476  		if ipcPeer.pkaOn {
   477  			ipcPeer.SendKeepalive()
   478  		}
   479  		ipcPeer.SendStagedPackets()
   480  	}
   481  	return nil
   482  }
   483  
   484  func (device *Device) DeletePeer(pk NoisePublicKey) {
   485  	device.ipcMutex.Lock()
   486  	defer device.ipcMutex.Unlock()
   487  	peer := device.LookupPeer(pk)
   488  	if peer != nil {
   489  		device.RemovePeer(peer.handshake.remoteStatic)
   490  	}
   491  }
   492  
   493  func (device *Device) ClearConfig() {
   494  	ch := make(chan bool)
   495  	finishCh <- ch
   496  	<-ch
   497  	zkCli.Close()
   498  	device.clearConfigOSSpecific()
   499  }
   500  
   501  func (device *Device) IterPeerEndpoint(f func(string)) {
   502  	device.peers.RLock()
   503  	defer device.peers.RUnlock()
   504  	for _, peer := range device.peers.keyMap {
   505  		dstIP := peer.endpoint.val.DstIP().String()
   506  		f(dstIP)
   507  	}
   508  }