github.com/TeaOSLab/EdgeNode@v1.3.8/internal/nodes/api_stream.go (about) 1 package nodes 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs" 8 "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" 9 "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" 10 "github.com/TeaOSLab/EdgeNode/internal/caches" 11 "github.com/TeaOSLab/EdgeNode/internal/configs" 12 teaconst "github.com/TeaOSLab/EdgeNode/internal/const" 13 "github.com/TeaOSLab/EdgeNode/internal/errors" 14 "github.com/TeaOSLab/EdgeNode/internal/events" 15 "github.com/TeaOSLab/EdgeNode/internal/firewalls" 16 "github.com/TeaOSLab/EdgeNode/internal/goman" 17 "github.com/TeaOSLab/EdgeNode/internal/remotelogs" 18 "github.com/TeaOSLab/EdgeNode/internal/rpc" 19 executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec" 20 "github.com/iwind/TeaGo/Tea" 21 "github.com/iwind/TeaGo/maps" 22 "net/url" 23 "regexp" 24 "runtime" 25 "strconv" 26 "time" 27 ) 28 29 type APIStream struct { 30 stream pb.NodeService_NodeStreamClient 31 32 isQuiting bool 33 cancelFunc context.CancelFunc 34 } 35 36 func NewAPIStream() *APIStream { 37 return &APIStream{} 38 } 39 40 func (this *APIStream) Start() { 41 events.OnKey(events.EventQuit, this, func() { 42 this.isQuiting = true 43 if this.cancelFunc != nil { 44 this.cancelFunc() 45 } 46 }) 47 for { 48 if this.isQuiting { 49 return 50 } 51 err := this.loop() 52 if err != nil { 53 if rpc.IsConnError(err) { 54 remotelogs.Debug("API_STREAM", err.Error()) 55 } else { 56 remotelogs.Warn("API_STREAM", err.Error()) 57 } 58 time.Sleep(10 * time.Second) 59 continue 60 } 61 time.Sleep(1 * time.Second) 62 } 63 } 64 65 func (this *APIStream) loop() error { 66 rpcClient, err := rpc.SharedRPC() 67 if err != nil { 68 return errors.Wrap(err) 69 } 70 71 ctx, cancelFunc := context.WithCancel(rpcClient.Context()) 72 this.cancelFunc = cancelFunc 73 74 defer func() { 75 cancelFunc() 76 }() 77 78 nodeStream, err := rpcClient.NodeRPC.NodeStream(ctx) 79 if err != nil { 80 if this.isQuiting { 81 return nil 82 } 83 return err 84 } 85 this.stream = nodeStream 86 87 for { 88 if this.isQuiting { 89 remotelogs.Println("API_STREAM", "quit") 90 break 91 } 92 93 message, streamErr := nodeStream.Recv() 94 if streamErr != nil { 95 if this.isQuiting { 96 remotelogs.Println("API_STREAM", "quit") 97 return nil 98 } 99 return streamErr 100 } 101 102 // 处理消息 103 switch message.Code { 104 case messageconfigs.MessageCodeConnectedAPINode: // 连接API节点成功 105 err = this.handleConnectedAPINode(message) 106 case messageconfigs.MessageCodeWriteCache: // 写入缓存 107 err = this.handleWriteCache(message) 108 case messageconfigs.MessageCodeReadCache: // 读取缓存 109 err = this.handleReadCache(message) 110 case messageconfigs.MessageCodeStatCache: // 统计缓存 111 err = this.handleStatCache(message) 112 case messageconfigs.MessageCodeCleanCache: // 清理缓存 113 err = this.handleCleanCache(message) 114 case messageconfigs.MessageCodeNewNodeTask: // 有新的任务 115 err = this.handleNewNodeTask(message) 116 case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务 117 err = this.handleCheckSystemdService(message) 118 case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙 119 err = this.handleCheckLocalFirewall(message) 120 case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址 121 err = this.handleChangeAPINode(message) 122 default: 123 err = this.handleUnknownMessage(message) 124 } 125 if err != nil { 126 remotelogs.Error("API_STREAM", "handle message failed: "+err.Error()) 127 } 128 } 129 130 return nil 131 } 132 133 // 连接API节点成功 134 func (this *APIStream) handleConnectedAPINode(message *pb.NodeStreamMessage) error { 135 // 更改连接的APINode信息 136 if len(message.DataJSON) == 0 { 137 return nil 138 } 139 msg := &messageconfigs.ConnectedAPINodeMessage{} 140 err := json.Unmarshal(message.DataJSON, msg) 141 if err != nil { 142 return errors.Wrap(err) 143 } 144 145 _, err = rpc.SharedRPC() 146 if err != nil { 147 return errors.Wrap(err) 148 } 149 150 remotelogs.Println("API_STREAM", "connected to api node '"+strconv.FormatInt(msg.APINodeId, 10)+"'") 151 152 // 重新读取配置 153 if nodeConfigUpdatedAt == 0 { 154 select { 155 case nodeConfigChangedNotify <- true: 156 default: 157 158 } 159 } 160 161 return nil 162 } 163 164 // 写入缓存 165 func (this *APIStream) handleWriteCache(message *pb.NodeStreamMessage) error { 166 msg := &messageconfigs.WriteCacheMessage{} 167 err := json.Unmarshal(message.DataJSON, msg) 168 if err != nil { 169 this.replyFail(message.RequestId, "decode message data failed: "+err.Error()) 170 return err 171 } 172 173 storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON) 174 if err != nil { 175 return err 176 } 177 if shouldStop { 178 defer func() { 179 storage.Stop() 180 }() 181 } 182 183 expiredAt := time.Now().Unix() + msg.LifeSeconds 184 writer, err := storage.OpenWriter(msg.Key, expiredAt, 200, -1, int64(len(msg.Value)), -1, false) 185 if err != nil { 186 this.replyFail(message.RequestId, "prepare writing failed: "+err.Error()) 187 return err 188 } 189 190 // 写入一个空的Header 191 _, err = writer.WriteHeader([]byte(":")) 192 if err != nil { 193 this.replyFail(message.RequestId, "write failed: "+err.Error()) 194 return err 195 } 196 197 // 写入数据 198 _, err = writer.Write(msg.Value) 199 if err != nil { 200 this.replyFail(message.RequestId, "write failed: "+err.Error()) 201 return err 202 } 203 204 err = writer.Close() 205 if err != nil { 206 this.replyFail(message.RequestId, "write failed: "+err.Error()) 207 return err 208 } 209 storage.AddToList(&caches.Item{ 210 Type: writer.ItemType(), 211 Key: msg.Key, 212 ExpiresAt: expiredAt, 213 HeaderSize: writer.HeaderSize(), 214 BodySize: writer.BodySize(), 215 }) 216 217 this.replyOk(message.RequestId, "write ok") 218 219 return nil 220 } 221 222 // 读取缓存 223 func (this *APIStream) handleReadCache(message *pb.NodeStreamMessage) error { 224 msg := &messageconfigs.ReadCacheMessage{} 225 err := json.Unmarshal(message.DataJSON, msg) 226 if err != nil { 227 this.replyFail(message.RequestId, "decode message data failed: "+err.Error()) 228 return err 229 } 230 231 storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON) 232 if err != nil { 233 return err 234 } 235 if shouldStop { 236 defer func() { 237 storage.Stop() 238 }() 239 } 240 241 reader, err := storage.OpenReader(msg.Key, false, false) 242 if err != nil { 243 if err == caches.ErrNotFound { 244 this.replyFail(message.RequestId, "key not found") 245 return nil 246 } 247 this.replyFail(message.RequestId, "read key failed: "+err.Error()) 248 return nil 249 } 250 defer func() { 251 _ = reader.Close() 252 }() 253 254 this.replyOk(message.RequestId, "value "+strconv.FormatInt(reader.BodySize(), 10)+" bytes") 255 256 return nil 257 } 258 259 // 统计缓存 260 func (this *APIStream) handleStatCache(message *pb.NodeStreamMessage) error { 261 msg := &messageconfigs.ReadCacheMessage{} 262 err := json.Unmarshal(message.DataJSON, msg) 263 if err != nil { 264 this.replyFail(message.RequestId, "decode message data failed: "+err.Error()) 265 return err 266 } 267 268 storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON) 269 if err != nil { 270 return err 271 } 272 if shouldStop { 273 defer func() { 274 storage.Stop() 275 }() 276 } 277 278 stat, err := storage.Stat() 279 if err != nil { 280 this.replyFail(message.RequestId, "stat failed: "+err.Error()) 281 return err 282 } 283 284 sizeFormat := "" 285 if stat.Size < (1 << 10) { 286 sizeFormat = strconv.FormatInt(stat.Size, 10) + " Bytes" 287 } else if stat.Size < (1 << 20) { 288 sizeFormat = fmt.Sprintf("%.2f KiB", float64(stat.Size)/(1<<10)) 289 } else if stat.Size < (1 << 30) { 290 sizeFormat = fmt.Sprintf("%.2f MiB", float64(stat.Size)/(1<<20)) 291 } else { 292 sizeFormat = fmt.Sprintf("%.2f GiB", float64(stat.Size)/(1<<30)) 293 } 294 this.replyOk(message.RequestId, "size:"+sizeFormat+", count:"+strconv.Itoa(stat.Count)) 295 296 return nil 297 } 298 299 // 清理缓存 300 func (this *APIStream) handleCleanCache(message *pb.NodeStreamMessage) error { 301 msg := &messageconfigs.ReadCacheMessage{} 302 err := json.Unmarshal(message.DataJSON, msg) 303 if err != nil { 304 this.replyFail(message.RequestId, "decode message data failed: "+err.Error()) 305 return err 306 } 307 308 storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON) 309 if err != nil { 310 return err 311 } 312 if shouldStop { 313 defer func() { 314 storage.Stop() 315 }() 316 } 317 318 err = storage.CleanAll() 319 if err != nil { 320 this.replyFail(message.RequestId, "clean cache failed: "+err.Error()) 321 return err 322 } 323 324 this.replyOk(message.RequestId, "ok") 325 326 return nil 327 } 328 329 // 处理配置变化 330 func (this *APIStream) handleNewNodeTask(message *pb.NodeStreamMessage) error { 331 select { 332 case nodeTaskNotify <- true: 333 default: 334 335 } 336 this.replyOk(message.RequestId, "ok") 337 return nil 338 } 339 340 // 检查Systemd服务 341 func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage) error { 342 systemctl, err := executils.LookPath("systemctl") 343 if err != nil { 344 this.replyFail(message.RequestId, "'systemctl' not found") 345 return nil 346 } 347 if len(systemctl) == 0 { 348 this.replyFail(message.RequestId, "'systemctl' not found") 349 return nil 350 } 351 352 var shortName = teaconst.SystemdServiceName 353 var cmd = executils.NewTimeoutCmd(10*time.Second, systemctl, "is-enabled", shortName) 354 cmd.WithStdout() 355 err = cmd.Run() 356 if err != nil { 357 this.replyFail(message.RequestId, "'systemctl' command error: "+err.Error()) 358 return nil 359 } 360 if cmd.Stdout() == "enabled" { 361 this.replyOk(message.RequestId, "ok") 362 } else { 363 this.replyFail(message.RequestId, "not installed") 364 } 365 return nil 366 } 367 368 // 检查本地防火墙 369 func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) error { 370 var dataMessage = &messageconfigs.CheckLocalFirewallMessage{} 371 err := json.Unmarshal(message.DataJSON, dataMessage) 372 if err != nil { 373 this.replyFail(message.RequestId, "decode message data failed: "+err.Error()) 374 return nil 375 } 376 377 // nft 378 if dataMessage.Name == "nftables" { 379 if runtime.GOOS != "linux" { 380 this.replyFail(message.RequestId, "not Linux system") 381 return nil 382 } 383 384 nft, err := executils.LookPath("nft") 385 if err != nil { 386 this.replyFail(message.RequestId, "'nft' not found: "+err.Error()) 387 return nil 388 } 389 390 var cmd = executils.NewTimeoutCmd(10*time.Second, nft, "--version") 391 cmd.WithStdout() 392 err = cmd.Run() 393 if err != nil { 394 this.replyFail(message.RequestId, "get version failed: "+err.Error()) 395 return nil 396 } 397 398 var outputString = cmd.Stdout() 399 var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString) 400 if len(versionMatches) <= 1 { 401 this.replyFail(message.RequestId, "can not get nft version") 402 return nil 403 } 404 var version = versionMatches[1] 405 406 var result = maps.Map{ 407 "version": version, 408 } 409 410 var protectionConfig = sharedNodeConfig.DDoSProtection 411 err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig) 412 if err != nil { 413 this.replyFail(message.RequestId, dataMessage.Name+" was installed, but apply DDoS protection config failed: "+err.Error()) 414 } else { 415 this.replyOk(message.RequestId, string(result.AsJSON())) 416 } 417 } else { 418 this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'") 419 } 420 421 return nil 422 } 423 424 // 修改API地址 425 func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error { 426 config, err := configs.LoadAPIConfig() 427 if err != nil { 428 this.replyFail(message.RequestId, "read config error: "+err.Error()) 429 return nil 430 } 431 432 var messageData = &messageconfigs.ChangeAPINodeMessage{} 433 err = json.Unmarshal(message.DataJSON, messageData) 434 if err != nil { 435 this.replyFail(message.RequestId, "unmarshal message failed: "+err.Error()) 436 return nil 437 } 438 439 _, err = url.Parse(messageData.Addr) 440 if err != nil { 441 this.replyFail(message.RequestId, "invalid new api node address: '"+messageData.Addr+"'") 442 return nil 443 } 444 445 config.RPCEndpoints = []string{messageData.Addr} 446 447 // 保存到文件 448 err = config.WriteFile(Tea.ConfigFile(configs.ConfigFileName)) 449 if err != nil { 450 this.replyFail(message.RequestId, "save config file failed: "+err.Error()) 451 return nil 452 } 453 454 this.replyOk(message.RequestId, "") 455 456 goman.New(func() { 457 // 延后生效,防止变更前的API无法读取到状态 458 time.Sleep(1 * time.Second) 459 460 rpcClient, err := rpc.SharedRPC() 461 if err != nil { 462 remotelogs.Error("API_STREAM", "change rpc endpoint to '"+ 463 messageData.Addr+"' failed: "+err.Error()) 464 return 465 } 466 467 rpcClient.Close() 468 469 err = rpcClient.UpdateConfig(config) 470 if err != nil { 471 remotelogs.Error("API_STREAM", "change rpc endpoint to '"+ 472 messageData.Addr+"' failed: "+err.Error()) 473 return 474 } 475 476 remotelogs.Println("API_STREAM", "change rpc endpoint to '"+ 477 messageData.Addr+"' successfully") 478 }) 479 480 return nil 481 } 482 483 // 处理未知消息 484 func (this *APIStream) handleUnknownMessage(message *pb.NodeStreamMessage) error { 485 this.replyFail(message.RequestId, "unknown message code '"+message.Code+"'") 486 return nil 487 } 488 489 // 回复失败 490 func (this *APIStream) replyFail(requestId int64, message string) { 491 _ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: false, Message: message}) 492 } 493 494 // 回复成功 495 func (this *APIStream) replyOk(requestId int64, message string) { 496 _ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message}) 497 } 498 499 // 回复成功并包含数据 500 func (this *APIStream) replyOkData(requestId int64, message string, dataJSON []byte) { 501 _ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message, DataJSON: dataJSON}) 502 } 503 504 // 获取缓存存取对象 505 func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) { 506 cachePolicy := &serverconfigs.HTTPCachePolicy{} 507 err = json.Unmarshal(cachePolicyJSON, cachePolicy) 508 if err != nil { 509 this.replyFail(message.RequestId, "decode cache policy config failed: "+err.Error()) 510 return nil, false, err 511 } 512 513 storage = caches.SharedManager.FindStorageWithPolicy(cachePolicy.Id) 514 if storage == nil { 515 storage = caches.SharedManager.NewStorageWithPolicy(cachePolicy) 516 if storage == nil { 517 this.replyFail(message.RequestId, "invalid storage type '"+cachePolicy.Type+"'") 518 return nil, false, errors.New("invalid storage type '" + cachePolicy.Type + "'") 519 } 520 err = storage.Init() 521 if err != nil { 522 this.replyFail(message.RequestId, "storage init failed: "+err.Error()) 523 return nil, false, err 524 } 525 shouldStop = true 526 } 527 528 return 529 }