gitee.com/aurawing/surguard-go@v0.3.1-0.20240409071558-96509a61ecf3/device/postconfig.go (about) 1 package device 2 3 import ( 4 "encoding/hex" 5 "errors" 6 "fmt" 7 "net" 8 "net/netip" 9 "os" 10 "sort" 11 "strconv" 12 "strings" 13 "time" 14 15 "github.com/jackpal/gateway" 16 ) 17 18 const ( 19 ENV_SG_GROUP_NAME = "SG_GROUP_NAME" 20 ENV_SG_CONFIG_ROOT = "SG_CONFIG_ROOT" 21 ENV_SG_KMS_URL = "SG_KMS_URL" 22 ENV_SG_ZK_URL = "SG_ZK_URL" 23 ENV_SG_DEFAULT_IFACE = "SG_DEFAULT_IFACE" 24 ENV_SG_LISTEN_PORT = "SG_LISTEN_PORT" 25 ENV_SG_KEEPALIVE_INTERVAL = "SG_KEEPALIVE_INTERVAL" 26 ENV_SG_MONITOR_IP_INTERVAL = "SG_MONITOR_IP_INTERVAL" 27 ENV_SG_PRIVATEKEY = "SG_PRIVATEKEY" 28 //ENV_SG_IF_BIND_INTERFACE = "SG_IF_BIND_INTERFACE" 29 ) 30 31 const DEFAULT_GROUP_NAME = "SURGUARD" 32 const DEFAULT_CONFIG_ROOT = "/etc/surguard" 33 const DEFAULT_LISTEN_PORT = 50001 34 const DEFAULT_KEEPALIVE_INTERVAL = 30 35 const DEFAULT_MONITOR_IP_INTERVAL = 5 36 37 var groupName string 38 var zkCli *ZkDiscovery 39 var interfaceIP string 40 var interfaceIPArr [][4]byte 41 var sgIPArr [][4]byte 42 var interfaceIndex byte 43 var finishCh chan chan bool 44 var keepaliveInterval int 45 var ifBindInterface bool 46 47 func IfBindInterface() bool { 48 return ifBindInterface 49 } 50 51 func (device *Device) PostConfig() { 52 //parse environment variables 53 groupNameStr := os.Getenv(ENV_SG_GROUP_NAME) 54 if groupNameStr == "" { 55 groupName = DEFAULT_GROUP_NAME 56 } else { 57 groupName = groupNameStr 58 } 59 configRoot := os.Getenv(ENV_SG_CONFIG_ROOT) 60 if configRoot == "" { 61 configRoot = DEFAULT_CONFIG_ROOT 62 } 63 kmsURL := os.Getenv(ENV_SG_KMS_URL) 64 if kmsURL == "" { 65 device.log.Errorf("PostConfig: KMS URL is empty") 66 panic(errors.New("invalid KMS URL")) 67 } 68 zkUrl := os.Getenv(ENV_SG_ZK_URL) 69 if zkUrl == "" { 70 device.log.Errorf("PostConfig: zookeeper URL is invalid: %s", zkUrl) 71 panic(errors.New("invalid zookeeper URL")) 72 } 73 keepaliveIntervalStr := os.Getenv(ENV_SG_KEEPALIVE_INTERVAL) 74 if keepaliveIntervalStr != "" { 75 var err error 76 keepaliveInterval, err = strconv.Atoi(keepaliveIntervalStr) 77 if err != nil { 78 device.log.Errorf("PostConfig: parsing keepalive interval failed: %s", err) 79 keepaliveInterval = DEFAULT_KEEPALIVE_INTERVAL 80 } 81 } else { 82 keepaliveInterval = DEFAULT_KEEPALIVE_INTERVAL 83 } 84 85 monitorIPInterval := DEFAULT_MONITOR_IP_INTERVAL 86 monitorIPIntervalStr := os.Getenv(ENV_SG_MONITOR_IP_INTERVAL) 87 if monitorIPIntervalStr != "" { 88 var err error 89 monitorIPInterval, err = strconv.Atoi(monitorIPIntervalStr) 90 if err != nil { 91 device.log.Errorf("PostConfig: parse monitor IP interval failed: %s", err) 92 monitorIPInterval = DEFAULT_MONITOR_IP_INTERVAL 93 } 94 } 95 96 // ifBindInterfaceStr := os.Getenv(ENV_SG_IF_BIND_INTERFACE) 97 // if ifBindInterfaceStr == "" { 98 // ifBindInterface = false 99 // } else { 100 // var err error 101 // ifBindInterface, err = strconv.ParseBool(ifBindInterfaceStr) 102 // if err != nil { 103 // device.log.Errorf("PostConfig: ifBindInterface is invalid: %s -> %s", ifBindInterfaceStr, err) 104 // panic(err) 105 // } 106 // } 107 108 var err error 109 finishCh = make(chan chan bool) 110 111 // ipaddrs, err := device.getIP() 112 // if err != nil { 113 // device.log.Errorf("PostConfig: failed to get IP address: %s", err) 114 // panic(err) 115 // } 116 // interfaceIP = strings.Join(ipaddrs, ",") 117 // interfaceIPArr = make([][4]byte, 0) 118 // for _, ipaddr := range ipaddrs { 119 // tmp := net.ParseIP(ipaddr) 120 // interfaceIPArr[0] = tmp[12] 121 // interfaceIPArr[1] = tmp[13] 122 // interfaceIPArr[2] = tmp[14] 123 // interfaceIPArr[3] = tmp[15] 124 // } 125 126 err = device.configTunDevice() 127 if err != nil { 128 device.log.Errorf("PostConfig: config TUN device failed: %s", err) 129 panic(err) 130 } 131 sk, err := device.loginKMS(kmsURL, configRoot) 132 if err != nil { 133 device.log.Errorf("PostConfig: login KMS failed: %s", err) 134 panic(err) 135 } 136 err = device.setSKandPort(sk) 137 if err != nil { 138 device.log.Errorf("PostConfig: set private key and listen port failed: %s", err) 139 panic(err) 140 } 141 err = device.initRules() 142 if err != nil { 143 device.log.Errorf("PostConfig: initialize rules failed: %s", err) 144 panic(err) 145 } 146 // interfaceIPStr, err := getIP() 147 // if err != nil { 148 // panic(err) 149 // } 150 // interfaceIP = interfaceIPStr 151 152 zkCli, err = CreateZkDiscovery(zkUrl, device, fmt.Sprintf("/%s", groupName)) 153 if err != nil { 154 device.log.Errorf("PostConfig: create zookeeper client failed: %s", err) 155 panic(err) 156 } 157 device.monitorIP(true) 158 go func() { 159 var ch chan bool 160 OUTER: 161 for { 162 select { 163 case ch = <-finishCh: 164 break OUTER 165 case <-time.After(time.Second * time.Duration(monitorIPInterval)): 166 device.monitorIP(false) 167 } 168 } 169 ch <- true 170 }() 171 } 172 173 func (device *Device) monitorIP(init bool) { 174 zkCli.Lock() 175 pathExist, err := zkCli.ExistPeer(device.staticIdentity.publicKey) 176 if err != nil { 177 device.log.Errorf("PostConfig: error when get path from zk: %s", err) 178 } 179 if !pathExist && interfaceIP != "" && !init { 180 ifAddrs := strings.Split(interfaceIP, ",") 181 rpAddrs := make([]string, 0) 182 for _, rpAddr := range ifAddrs { 183 rpAddrs = append(rpAddrs, fmt.Sprintf("%s:%d", rpAddr, device.net.port)) 184 } 185 zkCli.AddPeer(device.staticIdentity.publicKey, strings.Join(rpAddrs, ",")) 186 } 187 zkCli.Unlock() 188 //TODO: 检查zookeeper上的IP地址是否为空 189 data := "" 190 data, _ = zkCli.GetData(device.staticIdentity.publicKey) 191 ipaddrs, err := device.getIP() 192 if err != nil { 193 device.log.Errorf("PostConfig: error when get interface IP: %s", err) 194 } else { 195 sort.Strings(ipaddrs) 196 if data == "" || strings.Join(ipaddrs, ",") != interfaceIP || (ifBindInterface && init) { 197 zkCli.Lock() 198 if data == "" || !(interfaceIP == "" || (ifBindInterface && init)) { 199 zkCli.RemovePeer(device.staticIdentity.publicKey) 200 } 201 reportAddrs := make([]string, 0) 202 for _, reportAddr := range ipaddrs { 203 reportAddrs = append(reportAddrs, fmt.Sprintf("%s:%d", reportAddr, device.net.port)) 204 } 205 zkCli.AddPeer(device.staticIdentity.publicKey, strings.Join(reportAddrs, ",")) 206 interfaceIP = strings.Join(ipaddrs, ",") 207 // tmp := net.ParseIP(ipaddrs[0]) 208 // interfaceIPArr[0] = tmp[12] 209 // interfaceIPArr[1] = tmp[13] 210 // interfaceIPArr[2] = tmp[14] 211 // interfaceIPArr[3] = tmp[15] 212 if ifBindInterface { 213 oldsgLen := len(sgIPArr) 214 interfaceIPArr = make([][4]byte, 0) 215 sgIPArr = make([][4]byte, 0) 216 217 for i, ipaddr := range ipaddrs { 218 tmp := net.ParseIP(ipaddr) 219 tmp2 := [4]byte{tmp[12], tmp[13], tmp[14], tmp[15]} 220 interfaceIPArr = append(interfaceIPArr, tmp2) 221 sgIPArr = append(sgIPArr, [4]byte{169, 254, interfaceIndex, byte(i + 1)}) 222 } 223 if oldsgLen < len(sgIPArr) { 224 for i := oldsgLen + 1; i <= len(sgIPArr); i++ { 225 staticip := fmt.Sprintf("169.254.%d.%d", interfaceIndex, i) 226 device.runCmd(false, "netsh", "interface", "ipv4", "add", "address", "name=\""+deviceName+"\"", staticip, "255.255.255.0") 227 device.log.Verbosef("configTunDevice: add address %s/24 for device %s\n", staticip, deviceName) 228 } 229 } else if oldsgLen > len(sgIPArr) { 230 for i := oldsgLen; i > len(sgIPArr); i-- { 231 staticip := fmt.Sprintf("169.254.%d.%d", interfaceIndex, i) 232 device.runCmd(false, "netsh", "interface", "ipv4", "delete", "address", "name=\""+deviceName+"\"", staticip) 233 device.log.Verbosef("configTunDevice: delete address %s/24 for device %s\n", staticip, deviceName) 234 } 235 } 236 } 237 238 zkCli.Unlock() 239 if ifBindInterface { 240 device.net.Lock() 241 device.net.ipv4Addr = ipaddrs[0] 242 device.net.Unlock() 243 if err := device.BindUpdate(); err != nil { 244 device.log.Errorf("PostConfig: failed to set ip address %s: %s", ipaddrs[0], err) 245 } 246 } 247 } 248 } 249 } 250 251 // login to KMS and get private key 252 func (device *Device) loginKMS(kmsURL, configRoot string) (NoisePrivateKey, error) { 253 skstr := os.Getenv(ENV_SG_PRIVATEKEY) 254 if skstr != "" { 255 return device.parsePrivateKey(skstr) 256 } 257 yidRoot := fmt.Sprintf("%s/%s", configRoot, groupName) 258 err := os.MkdirAll(yidRoot, os.ModePerm) 259 if err != nil { 260 device.log.Errorf("PostConfig: create config path %s failed: %s", yidRoot, err) 261 return NoisePrivateKey{}, err 262 } 263 yidPath := fmt.Sprintf("%s/%s", yidRoot, "kms.conf") 264 265 //login and get private key 266 var vui, yid, yentity string 267 ystr, err := os.ReadFile(yidPath) 268 if err != nil { 269 //create entity 270 vuin, yidn, yentityn, err := device.createEntity(kmsURL) 271 if err != nil { 272 device.log.Errorf("PostConfig: create entity failed: %s", err) 273 return NoisePrivateKey{}, err 274 } 275 err = os.WriteFile(yidPath, []byte(fmt.Sprintf("%s|%s|%s", vuin, yidn, yentityn)), 0666) 276 if err != nil { 277 device.log.Errorf("PostConfig: write YID and entity to file failed: %s", err) 278 return NoisePrivateKey{}, err 279 } 280 vui = vuin 281 yid = yidn 282 yentity = yentityn 283 } else { 284 ystrs := strings.Split(string(ystr), "|") 285 if len(ystrs) != 3 { 286 err := fmt.Errorf("format of config file is corrupt: %s", yidPath) 287 device.log.Errorf("PostConfig: %s", err) 288 return NoisePrivateKey{}, err 289 } 290 vui = strings.TrimSpace(ystrs[0]) 291 yid = strings.TrimSpace(ystrs[1]) 292 yentity = strings.TrimSpace(ystrs[2]) 293 } 294 err = device.loginEntity(kmsURL, vui, yid) 295 if err != nil { 296 device.log.Errorf("PostConfig: login entity failed: %s", err) 297 return NoisePrivateKey{}, err 298 } 299 yentityid, err := strconv.Atoi(yentity) 300 if err != nil { 301 device.log.Errorf("PostConfig: parse entity to number failed: %s", err) 302 return NoisePrivateKey{}, err 303 } 304 skstr, err = device.getPrivateKey(kmsURL, yentityid) 305 if err != nil { 306 device.log.Errorf("PostConfig: get private key from KMS failed: %s", err) 307 return NoisePrivateKey{}, err 308 } 309 return device.parsePrivateKey(skstr) 310 } 311 312 func (device *Device) getIP() ([]string, error) { 313 interfaceName := os.Getenv(ENV_SG_DEFAULT_IFACE) 314 if interfaceName == "" { 315 ip, err := gateway.DiscoverInterface() 316 if err != nil { 317 device.log.Errorf("PostConfig: error when discovery interface IP: %s", err) 318 return nil, err 319 } 320 if ip.To4() != nil { 321 return []string{ip.String()}, nil 322 } else { 323 err := errors.New("no valid IPv4 address") 324 device.log.Errorf("PostConfig: error when convert interface IP: %s", err) 325 return nil, err 326 } 327 } 328 329 iface, err := net.InterfaceByName(interfaceName) 330 if err != nil { 331 device.log.Errorf("PostConfig: error when get interface IP by name: %s", err) 332 return nil, err 333 } 334 335 addrs, err := iface.Addrs() 336 if err != nil { 337 device.log.Errorf("PostConfig: error when get IP address of interface: %s", err) 338 return nil, err 339 } 340 addrstrs := make([]string, 0) 341 for _, addr := range addrs { 342 ipNet, ok := addr.(*net.IPNet) 343 if ok && !ipNet.IP.IsLoopback() { 344 if ipNet.IP.To4() != nil { 345 addrstrs = append(addrstrs, ipNet.IP.String()) 346 //return ipNet.IP.String(), nil 347 } 348 } 349 } 350 //TODO:增加指定IP地址功能 351 if len(addrstrs) == 0 { 352 err = errors.New("no valid IPv4 address") 353 device.log.Errorf("PostConfig: error when get IP address of interface: %s", err) 354 return nil, err 355 } 356 sort.Strings(addrstrs) 357 return addrstrs, nil 358 } 359 360 // set up private key and UDP listen port 361 func (device *Device) setSKandPort(sk NoisePrivateKey) error { 362 listenPort := DEFAULT_LISTEN_PORT 363 portStr := os.Getenv(ENV_SG_LISTEN_PORT) 364 var err error 365 if portStr != "" { 366 listenPort, err = strconv.Atoi(portStr) 367 if err != nil { 368 device.log.Errorf("PostConfig: failed to parse listen port %s: %s", portStr, err) 369 } 370 } 371 //ischanged := false 372 device.ipcMutex.Lock() 373 //oldPK := device.staticIdentity.publicKey 374 defer device.ipcMutex.Unlock() 375 if !device.staticIdentity.privateKey.Equals(sk) { 376 device.SetPrivateKey(sk) 377 } 378 // ipv4 := "" 379 // ipv4s, err := device.getIP() 380 // if err != nil { 381 // device.log.Errorf("PostConfig: failed to get IP address: %s", err) 382 // } else if len(ipv4s) > 0 { 383 // ipv4 = ipv4s[0] 384 // } 385 var ipv4 string 386 if ifBindInterface && len(interfaceIPArr) > 0 { 387 ipv4 = fmt.Sprintf("%d.%d.%d.%d", interfaceIPArr[0][0], interfaceIPArr[0][1], interfaceIPArr[0][2], interfaceIPArr[0][3]) 388 } else { 389 ipv4s, err := device.getIP() 390 if err != nil { 391 device.log.Errorf("PostConfig: failed to get IP address: %s", err) 392 } else if len(ipv4s) > 0 { 393 sort.Strings(ipv4s) 394 ipv4 = ipv4s[0] 395 } 396 } 397 if device.net.port != uint16(listenPort) { 398 device.net.Lock() 399 if ifBindInterface && ipv4 != "" { 400 device.net.ipv4Addr = ipv4 401 } 402 device.net.port = uint16(listenPort) 403 device.net.Unlock() 404 if err := device.BindUpdate(); err != nil { 405 device.log.Errorf("PostConfig: failed to set listen_port %d: %s", listenPort, err) 406 return err 407 } 408 //ischanged = true 409 } 410 // if ischanged { 411 // //TODO: update zk 412 // zkCli.Lock() 413 // zkCli.RemovePeer(oldPK) 414 // zkCli.AddPeer(device.staticIdentity.publicKey) 415 // zkCli.Unlock() 416 // } 417 return nil 418 } 419 420 func (device *Device) AddPeer(pk NoisePublicKey, endPoints string) error { 421 endPointArr := strings.Split(endPoints, ",") 422 if len(endPointArr) == 0 { 423 device.log.Verbosef("PostConfig: no endpoints under public key %s", hex.EncodeToString(pk[:])) 424 return errors.New("no endpoints") 425 } 426 device.ipcMutex.Lock() 427 defer device.ipcMutex.Unlock() 428 peer := device.LookupPeer(pk) 429 if peer != nil { 430 //err := errors.New("peer exists") 431 device.log.Verbosef("PostConfig: peer %v exists", peer) 432 return nil 433 } 434 peer, err := device.NewPeer(pk) 435 if err != nil { 436 device.log.Errorf("PostConfig: create peer failed: %s", err) 437 return err 438 } 439 device.log.Verbosef("%v - PostConfig: Created", peer) 440 endpoint, err := device.net.bind.ParseEndpoint(endPointArr[0]) 441 if err != nil { 442 device.log.Errorf("%v - PostConfig: parse endpoint failed: %s", peer, err) 443 return err 444 } 445 peer.endpoint.Lock() 446 peer.endpoint.val = endpoint 447 peer.endpoint.Unlock() 448 device.log.Verbosef("%v - PostConfig: Updating endpoint %s", peer, endPointArr[0]) 449 450 device.allowedips.RemoveByPeer(peer) 451 for _, endPoint := range endPointArr { 452 allowedIP := fmt.Sprintf("%s/32", strings.Split(endPoint, ":")[0]) 453 prefix, err := netip.ParsePrefix(allowedIP) 454 if err != nil { 455 device.log.Errorf("%v - PostConfig: parse allowedIP failed: %s", peer, err) 456 return err 457 } 458 device.allowedips.Insert(prefix, peer) 459 device.log.Verbosef("%v - PostConfig: Adding allowedip: %s", peer, allowedIP) 460 } 461 462 old := peer.persistentKeepaliveInterval.Swap(uint32(keepaliveInterval)) 463 device.log.Verbosef("%v - PostConfig: Updating persistent keepalive interval %d", peer, keepaliveInterval) 464 465 ipcPeer := new(ipcSetPeer) 466 ipcPeer.Peer = peer 467 ipcPeer.dummy = false 468 ipcPeer.created = true 469 ipcPeer.pkaOn = old == 0 && keepaliveInterval != 0 470 471 if ipcPeer.created { 472 ipcPeer.endpoint.disableRoaming = ipcPeer.device.net.brokenRoaming && ipcPeer.endpoint.val != nil 473 } 474 if ipcPeer.device.isUp() { 475 ipcPeer.Start() 476 if ipcPeer.pkaOn { 477 ipcPeer.SendKeepalive() 478 } 479 ipcPeer.SendStagedPackets() 480 } 481 return nil 482 } 483 484 func (device *Device) DeletePeer(pk NoisePublicKey) { 485 device.ipcMutex.Lock() 486 defer device.ipcMutex.Unlock() 487 peer := device.LookupPeer(pk) 488 if peer != nil { 489 device.RemovePeer(peer.handshake.remoteStatic) 490 } 491 } 492 493 func (device *Device) ClearConfig() { 494 ch := make(chan bool) 495 finishCh <- ch 496 <-ch 497 zkCli.Close() 498 device.clearConfigOSSpecific() 499 } 500 501 func (device *Device) IterPeerEndpoint(f func(string)) { 502 device.peers.RLock() 503 defer device.peers.RUnlock() 504 for _, peer := range device.peers.keyMap { 505 dstIP := peer.endpoint.val.DstIP().String() 506 f(dstIP) 507 } 508 }