github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/cluster/slave.go (about) 1 package cluster 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "errors" 7 "fmt" 8 model "github.com/cloudreve/Cloudreve/v3/models" 9 "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" 10 "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" 11 "github.com/cloudreve/Cloudreve/v3/pkg/auth" 12 "github.com/cloudreve/Cloudreve/v3/pkg/conf" 13 "github.com/cloudreve/Cloudreve/v3/pkg/request" 14 "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 15 "github.com/cloudreve/Cloudreve/v3/pkg/util" 16 "io" 17 "net/url" 18 "strings" 19 "sync" 20 "time" 21 ) 22 23 type SlaveNode struct { 24 Model *model.Node 25 Active bool 26 27 caller slaveCaller 28 callback func(bool, uint) 29 close chan bool 30 lock sync.RWMutex 31 } 32 33 type slaveCaller struct { 34 parent *SlaveNode 35 Client request.Client 36 } 37 38 // Init 初始化节点 39 func (node *SlaveNode) Init(nodeModel *model.Node) { 40 node.lock.Lock() 41 node.Model = nodeModel 42 43 // Init http request client 44 var endpoint *url.URL 45 if serverURL, err := url.Parse(node.Model.Server); err == nil { 46 var controller *url.URL 47 controller, _ = url.Parse("/api/v3/slave/") 48 endpoint = serverURL.ResolveReference(controller) 49 } 50 51 signTTL := model.GetIntSetting("slave_api_timeout", 60) 52 node.caller.Client = request.NewClient( 53 request.WithMasterMeta(), 54 request.WithTimeout(time.Duration(signTTL)*time.Second), 55 request.WithCredential(auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}, int64(signTTL)), 56 request.WithEndpoint(endpoint.String()), 57 ) 58 59 node.caller.parent = node 60 if node.close != nil { 61 node.lock.Unlock() 62 node.close <- true 63 go node.StartPingLoop() 64 } else { 65 node.Active = true 66 node.lock.Unlock() 67 go node.StartPingLoop() 68 } 69 } 70 71 // IsFeatureEnabled 查询节点的某项功能是否启用 72 func (node *SlaveNode) IsFeatureEnabled(feature string) bool { 73 node.lock.RLock() 74 defer node.lock.RUnlock() 75 76 switch feature { 77 case "aria2": 78 return node.Model.Aria2Enabled 79 default: 80 return false 81 } 82 } 83 84 // SubscribeStatusChange 订阅节点状态更改 85 func (node *SlaveNode) SubscribeStatusChange(callback func(bool, uint)) { 86 node.lock.Lock() 87 node.callback = callback 88 node.lock.Unlock() 89 } 90 91 // Ping 从机节点,返回从机负载 92 func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { 93 node.lock.RLock() 94 defer node.lock.RUnlock() 95 96 reqBodyEncoded, err := json.Marshal(req) 97 if err != nil { 98 return nil, err 99 } 100 101 bodyReader := strings.NewReader(string(reqBodyEncoded)) 102 103 resp, err := node.caller.Client.Request( 104 "POST", 105 "heartbeat", 106 bodyReader, 107 ).CheckHTTPResponse(200).DecodeResponse() 108 if err != nil { 109 return nil, err 110 } 111 112 // 处理列取结果 113 if resp.Code != 0 { 114 return nil, serializer.NewErrorFromResponse(resp) 115 } 116 117 var res serializer.NodePingResp 118 119 if resStr, ok := resp.Data.(string); ok { 120 err = json.Unmarshal([]byte(resStr), &res) 121 if err != nil { 122 return nil, err 123 } 124 } 125 126 return &res, nil 127 } 128 129 // IsActive 返回节点是否在线 130 func (node *SlaveNode) IsActive() bool { 131 node.lock.RLock() 132 defer node.lock.RUnlock() 133 134 return node.Active 135 } 136 137 // Kill 结束节点内相关循环 138 func (node *SlaveNode) Kill() { 139 node.lock.RLock() 140 defer node.lock.RUnlock() 141 142 if node.close != nil { 143 close(node.close) 144 } 145 } 146 147 // GetAria2Instance 获取从机Aria2实例 148 func (node *SlaveNode) GetAria2Instance() common.Aria2 { 149 node.lock.RLock() 150 defer node.lock.RUnlock() 151 152 if !node.Model.Aria2Enabled { 153 return &common.DummyAria2{} 154 } 155 156 return &node.caller 157 } 158 159 func (node *SlaveNode) ID() uint { 160 node.lock.RLock() 161 defer node.lock.RUnlock() 162 163 return node.Model.ID 164 } 165 166 func (node *SlaveNode) StartPingLoop() { 167 node.lock.Lock() 168 node.close = make(chan bool) 169 node.lock.Unlock() 170 171 tickDuration := time.Duration(model.GetIntSetting("slave_ping_interval", 300)) * time.Second 172 recoverDuration := time.Duration(model.GetIntSetting("slave_recover_interval", 600)) * time.Second 173 pingTicker := time.Duration(0) 174 175 util.Log().Debug("Slave node %q heartbeat loop started.", node.Model.Name) 176 retry := 0 177 recoverMode := false 178 isFirstLoop := true 179 180 loop: 181 for { 182 select { 183 case <-time.After(pingTicker): 184 if pingTicker == 0 { 185 pingTicker = tickDuration 186 } 187 188 util.Log().Debug("Slave node %q send ping.", node.Model.Name) 189 res, err := node.Ping(node.getHeartbeatContent(isFirstLoop)) 190 isFirstLoop = false 191 192 if err != nil { 193 util.Log().Debug("Error while ping slave node %q: %s", node.Model.Name, err) 194 retry++ 195 if retry >= model.GetIntSetting("slave_node_retry", 3) { 196 util.Log().Debug("Retry threshold for pinging slave node %q exceeded, mark it as offline.", node.Model.Name) 197 node.changeStatus(false) 198 199 if !recoverMode { 200 // 启动恢复监控循环 201 util.Log().Debug("Slave node %q entered recovery mode.", node.Model.Name) 202 pingTicker = recoverDuration 203 recoverMode = true 204 } 205 } 206 } else { 207 if recoverMode { 208 util.Log().Debug("Slave node %q recovered.", node.Model.Name) 209 pingTicker = tickDuration 210 recoverMode = false 211 isFirstLoop = true 212 } 213 214 util.Log().Debug("Status of slave node %q: %s", node.Model.Name, res) 215 node.changeStatus(true) 216 retry = 0 217 } 218 219 case <-node.close: 220 util.Log().Debug("Slave node %q received shutdown signal.", node.Model.Name) 221 break loop 222 } 223 } 224 } 225 226 func (node *SlaveNode) IsMater() bool { 227 return false 228 } 229 230 func (node *SlaveNode) MasterAuthInstance() auth.Auth { 231 node.lock.RLock() 232 defer node.lock.RUnlock() 233 234 return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)} 235 } 236 237 func (node *SlaveNode) SlaveAuthInstance() auth.Auth { 238 node.lock.RLock() 239 defer node.lock.RUnlock() 240 241 return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)} 242 } 243 244 func (node *SlaveNode) DBModel() *model.Node { 245 node.lock.RLock() 246 defer node.lock.RUnlock() 247 248 return node.Model 249 } 250 251 // getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave 252 func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq { 253 return &serializer.NodePingReq{ 254 SiteURL: model.GetSiteURL().String(), 255 IsUpdate: isUpdate, 256 SiteID: model.GetSettingByName("siteID"), 257 Node: node.Model, 258 CredentialTTL: model.GetIntSetting("slave_api_timeout", 60), 259 } 260 } 261 262 func (node *SlaveNode) changeStatus(isActive bool) { 263 node.lock.RLock() 264 id := node.Model.ID 265 if isActive != node.Active { 266 node.lock.RUnlock() 267 node.lock.Lock() 268 node.Active = isActive 269 node.lock.Unlock() 270 node.callback(isActive, id) 271 } else { 272 node.lock.RUnlock() 273 } 274 275 } 276 277 func (s *slaveCaller) Init() error { 278 return nil 279 } 280 281 // SendAria2Call send remote aria2 call to slave node 282 func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) { 283 reqReader, err := getAria2RequestBody(body) 284 if err != nil { 285 return nil, err 286 } 287 288 return s.Client.Request( 289 "POST", 290 "aria2/"+scope, 291 reqReader, 292 ).CheckHTTPResponse(200).DecodeResponse() 293 } 294 295 func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) { 296 s.parent.lock.RLock() 297 defer s.parent.lock.RUnlock() 298 299 req := &serializer.SlaveAria2Call{ 300 Task: task, 301 GroupOptions: options, 302 } 303 304 res, err := s.SendAria2Call(req, "task") 305 if err != nil { 306 return "", err 307 } 308 309 if res.Code != 0 { 310 return "", serializer.NewErrorFromResponse(res) 311 } 312 313 return res.Data.(string), err 314 } 315 316 func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) { 317 s.parent.lock.RLock() 318 defer s.parent.lock.RUnlock() 319 320 req := &serializer.SlaveAria2Call{ 321 Task: task, 322 } 323 324 res, err := s.SendAria2Call(req, "status") 325 if err != nil { 326 return rpc.StatusInfo{}, err 327 } 328 329 if res.Code != 0 { 330 return rpc.StatusInfo{}, serializer.NewErrorFromResponse(res) 331 } 332 333 var status rpc.StatusInfo 334 res.GobDecode(&status) 335 336 return status, err 337 } 338 339 func (s *slaveCaller) Cancel(task *model.Download) error { 340 s.parent.lock.RLock() 341 defer s.parent.lock.RUnlock() 342 343 req := &serializer.SlaveAria2Call{ 344 Task: task, 345 } 346 347 res, err := s.SendAria2Call(req, "cancel") 348 if err != nil { 349 return err 350 } 351 352 if res.Code != 0 { 353 return serializer.NewErrorFromResponse(res) 354 } 355 356 return nil 357 } 358 359 func (s *slaveCaller) Select(task *model.Download, files []int) error { 360 s.parent.lock.RLock() 361 defer s.parent.lock.RUnlock() 362 363 req := &serializer.SlaveAria2Call{ 364 Task: task, 365 Files: files, 366 } 367 368 res, err := s.SendAria2Call(req, "select") 369 if err != nil { 370 return err 371 } 372 373 if res.Code != 0 { 374 return serializer.NewErrorFromResponse(res) 375 } 376 377 return nil 378 } 379 380 func (s *slaveCaller) GetConfig() model.Aria2Option { 381 s.parent.lock.RLock() 382 defer s.parent.lock.RUnlock() 383 384 return s.parent.Model.Aria2OptionsSerialized 385 } 386 387 func (s *slaveCaller) DeleteTempFile(task *model.Download) error { 388 s.parent.lock.RLock() 389 defer s.parent.lock.RUnlock() 390 391 req := &serializer.SlaveAria2Call{ 392 Task: task, 393 } 394 395 res, err := s.SendAria2Call(req, "delete") 396 if err != nil { 397 return err 398 } 399 400 if res.Code != 0 { 401 return serializer.NewErrorFromResponse(res) 402 } 403 404 return nil 405 } 406 407 func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) { 408 reqBodyEncoded, err := json.Marshal(body) 409 if err != nil { 410 return nil, err 411 } 412 413 return strings.NewReader(string(reqBodyEncoded)), nil 414 } 415 416 // RemoteCallback 发送远程存储策略上传回调请求 417 func RemoteCallback(url string, body serializer.UploadCallback) error { 418 callbackBody, err := json.Marshal(struct { 419 Data serializer.UploadCallback `json:"data"` 420 }{ 421 Data: body, 422 }) 423 if err != nil { 424 return serializer.NewError(serializer.CodeCallbackError, "Failed to encode callback content", err) 425 } 426 427 resp := request.GeneralClient.Request( 428 "POST", 429 url, 430 bytes.NewReader(callbackBody), 431 request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second), 432 request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)), 433 ) 434 435 if resp.Err != nil { 436 return serializer.NewError(serializer.CodeCallbackError, "Slave cannot send callback request", resp.Err) 437 } 438 439 // 解析回调服务端响应 440 response, err := resp.DecodeResponse() 441 if err != nil { 442 msg := fmt.Sprintf("Slave cannot parse callback response from master (StatusCode=%d).", resp.Response.StatusCode) 443 return serializer.NewError(serializer.CodeCallbackError, msg, err) 444 } 445 446 if response.Code != 0 { 447 return serializer.NewError(response.Code, response.Msg, errors.New(response.Error)) 448 } 449 450 return nil 451 }