github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/cluster/master.go (about) 1 package cluster 2 3 import ( 4 "context" 5 "encoding/json" 6 model "github.com/cloudreve/Cloudreve/v3/models" 7 "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" 8 "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" 9 "github.com/cloudreve/Cloudreve/v3/pkg/auth" 10 "github.com/cloudreve/Cloudreve/v3/pkg/mq" 11 "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 12 "github.com/cloudreve/Cloudreve/v3/pkg/util" 13 "github.com/gofrs/uuid" 14 "net/url" 15 "os" 16 "path/filepath" 17 "strconv" 18 "strings" 19 "sync" 20 "time" 21 ) 22 23 const ( 24 deleteTempFileDuration = 60 * time.Second 25 statusRetryDuration = 10 * time.Second 26 ) 27 28 type MasterNode struct { 29 Model *model.Node 30 aria2RPC rpcService 31 lock sync.RWMutex 32 } 33 34 // RPCService 通过RPC服务的Aria2任务管理器 35 type rpcService struct { 36 Caller rpc.Client 37 Initialized bool 38 39 retryDuration time.Duration 40 deletePaddingDuration time.Duration 41 parent *MasterNode 42 options *clientOptions 43 } 44 45 type clientOptions struct { 46 Options map[string]interface{} // 创建下载时额外添加的设置 47 } 48 49 // Init 初始化节点 50 func (node *MasterNode) Init(nodeModel *model.Node) { 51 node.lock.Lock() 52 node.Model = nodeModel 53 node.aria2RPC.parent = node 54 node.aria2RPC.retryDuration = statusRetryDuration 55 node.aria2RPC.deletePaddingDuration = deleteTempFileDuration 56 node.lock.Unlock() 57 58 node.lock.RLock() 59 if node.Model.Aria2Enabled { 60 node.lock.RUnlock() 61 node.aria2RPC.Init() 62 return 63 } 64 node.lock.RUnlock() 65 } 66 67 func (node *MasterNode) ID() uint { 68 node.lock.RLock() 69 defer node.lock.RUnlock() 70 71 return node.Model.ID 72 } 73 74 func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { 75 return &serializer.NodePingResp{}, nil 76 } 77 78 // IsFeatureEnabled 查询节点的某项功能是否启用 79 func (node *MasterNode) IsFeatureEnabled(feature string) bool { 80 node.lock.RLock() 81 defer node.lock.RUnlock() 82 83 switch feature { 84 case "aria2": 85 return node.Model.Aria2Enabled 86 default: 87 return false 88 } 89 } 90 91 func (node *MasterNode) MasterAuthInstance() auth.Auth { 92 node.lock.RLock() 93 defer node.lock.RUnlock() 94 95 return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)} 96 } 97 98 func (node *MasterNode) SlaveAuthInstance() auth.Auth { 99 node.lock.RLock() 100 defer node.lock.RUnlock() 101 102 return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)} 103 } 104 105 // SubscribeStatusChange 订阅节点状态更改 106 func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) { 107 } 108 109 // IsActive 返回节点是否在线 110 func (node *MasterNode) IsActive() bool { 111 return true 112 } 113 114 // Kill 结束aria2请求 115 func (node *MasterNode) Kill() { 116 if node.aria2RPC.Caller != nil { 117 node.aria2RPC.Caller.Close() 118 } 119 } 120 121 // GetAria2Instance 获取主机Aria2实例 122 func (node *MasterNode) GetAria2Instance() common.Aria2 { 123 node.lock.RLock() 124 125 if !node.Model.Aria2Enabled { 126 node.lock.RUnlock() 127 return &common.DummyAria2{} 128 } 129 130 if !node.aria2RPC.Initialized { 131 node.lock.RUnlock() 132 node.aria2RPC.Init() 133 return &common.DummyAria2{} 134 } 135 136 defer node.lock.RUnlock() 137 return &node.aria2RPC 138 } 139 140 func (node *MasterNode) IsMater() bool { 141 return true 142 } 143 144 func (node *MasterNode) DBModel() *model.Node { 145 node.lock.RLock() 146 defer node.lock.RUnlock() 147 148 return node.Model 149 } 150 151 func (r *rpcService) Init() error { 152 r.parent.lock.Lock() 153 defer r.parent.lock.Unlock() 154 r.Initialized = false 155 156 // 客户端已存在,则关闭先前连接 157 if r.Caller != nil { 158 r.Caller.Close() 159 } 160 161 // 解析RPC服务地址 162 server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server) 163 if err != nil { 164 util.Log().Warning("Failed to parse Aria2 RPC server URL: %s", err) 165 return err 166 } 167 server.Path = "/jsonrpc" 168 169 // 加载自定义下载配置 170 var globalOptions map[string]interface{} 171 if r.parent.Model.Aria2OptionsSerialized.Options != "" { 172 err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions) 173 if err != nil { 174 util.Log().Warning("Failed to parse aria2 options: %s", err) 175 return err 176 } 177 } 178 179 r.options = &clientOptions{ 180 Options: globalOptions, 181 } 182 timeout := r.parent.Model.Aria2OptionsSerialized.Timeout 183 caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ) 184 185 r.Caller = caller 186 r.Initialized = err == nil 187 return err 188 } 189 190 func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) { 191 r.parent.lock.RLock() 192 // 生成存储路径 193 guid, _ := uuid.NewV4() 194 path := filepath.Join( 195 r.parent.Model.Aria2OptionsSerialized.TempPath, 196 "aria2", 197 guid.String(), 198 ) 199 r.parent.lock.RUnlock() 200 201 // 创建下载任务 202 options := map[string]interface{}{ 203 "dir": path, 204 } 205 for k, v := range r.options.Options { 206 options[k] = v 207 } 208 for k, v := range groupOptions { 209 options[k] = v 210 } 211 212 gid, err := r.Caller.AddURI(task.Source, options) 213 if err != nil || gid == "" { 214 return "", err 215 } 216 217 return gid, nil 218 } 219 220 func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) { 221 res, err := r.Caller.TellStatus(task.GID) 222 if err != nil { 223 // 失败后重试 224 util.Log().Debug("Failed to get download task status, please retry later: %s", err) 225 time.Sleep(r.retryDuration) 226 res, err = r.Caller.TellStatus(task.GID) 227 } 228 229 return res, err 230 } 231 232 func (r *rpcService) Cancel(task *model.Download) error { 233 // 取消下载任务 234 _, err := r.Caller.Remove(task.GID) 235 if err != nil { 236 util.Log().Warning("Failed to cancel task %q: %s", task.GID, err) 237 } 238 239 return err 240 } 241 242 func (r *rpcService) Select(task *model.Download, files []int) error { 243 var selected = make([]string, len(files)) 244 for i := 0; i < len(files); i++ { 245 selected[i] = strconv.Itoa(files[i]) 246 } 247 _, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")}) 248 return err 249 } 250 251 func (r *rpcService) GetConfig() model.Aria2Option { 252 r.parent.lock.RLock() 253 defer r.parent.lock.RUnlock() 254 255 return r.parent.Model.Aria2OptionsSerialized 256 } 257 258 func (s *rpcService) DeleteTempFile(task *model.Download) error { 259 s.parent.lock.RLock() 260 defer s.parent.lock.RUnlock() 261 262 // 避免被aria2占用,异步执行删除 263 go func(d time.Duration, src string) { 264 time.Sleep(d) 265 err := os.RemoveAll(src) 266 if err != nil { 267 util.Log().Warning("Failed to delete temp download folder: %q: %s", src, err) 268 } 269 }(s.deletePaddingDuration, task.Parent) 270 271 return nil 272 }