github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/cluster/master.go (about)

     1  package cluster
     2  
import (
	"context"
	"encoding/json"
	"fmt"
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	model "github.com/cloudreve/Cloudreve/v3/models"
	"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
	"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
	"github.com/cloudreve/Cloudreve/v3/pkg/auth"
	"github.com/cloudreve/Cloudreve/v3/pkg/mq"
	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
	"github.com/cloudreve/Cloudreve/v3/pkg/util"
	"github.com/gofrs/uuid"
)
    22  
    23  const (
    24  	deleteTempFileDuration = 60 * time.Second
    25  	statusRetryDuration    = 10 * time.Second
    26  )
    27  
// MasterNode is the cluster node implementation for the master server.
// lock is used to guard Model and the state of the embedded aria2RPC service.
type MasterNode struct {
	// Model is the node's database record; read it while holding lock.
	Model    *model.Node
	// aria2RPC is this node's aria2 JSON-RPC service, (re)initialized on demand.
	aria2RPC rpcService
	// lock guards Model and aria2RPC (rpcService methods lock it via their parent pointer).
	lock     sync.RWMutex
}
    33  
// rpcService is an Aria2 task manager that drives aria2 through its RPC interface.
// Its fields are written under the parent MasterNode's lock.
type rpcService struct {
	// Caller is the RPC client; Init closes and replaces it on every call.
	Caller      rpc.Client
	// Initialized reports whether the most recent Init call succeeded.
	Initialized bool

	// retryDuration is the pause before Status retries a failed status query.
	retryDuration         time.Duration
	// deletePaddingDuration is the delay before DeleteTempFile removes a task folder.
	deletePaddingDuration time.Duration
	// parent is the owning node; its lock and Model are used by every method.
	parent                *MasterNode
	// options holds extra aria2 options merged into every new download.
	options               *clientOptions
}
    44  
// clientOptions holds client-wide settings decoded from the node's aria2 configuration.
type clientOptions struct {
	Options map[string]interface{} // Extra options added to every download when it is created.
}
    48  
    49  // Init 初始化节点
    50  func (node *MasterNode) Init(nodeModel *model.Node) {
    51  	node.lock.Lock()
    52  	node.Model = nodeModel
    53  	node.aria2RPC.parent = node
    54  	node.aria2RPC.retryDuration = statusRetryDuration
    55  	node.aria2RPC.deletePaddingDuration = deleteTempFileDuration
    56  	node.lock.Unlock()
    57  
    58  	node.lock.RLock()
    59  	if node.Model.Aria2Enabled {
    60  		node.lock.RUnlock()
    61  		node.aria2RPC.Init()
    62  		return
    63  	}
    64  	node.lock.RUnlock()
    65  }
    66  
    67  func (node *MasterNode) ID() uint {
    68  	node.lock.RLock()
    69  	defer node.lock.RUnlock()
    70  
    71  	return node.Model.ID
    72  }
    73  
    74  func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
    75  	return &serializer.NodePingResp{}, nil
    76  }
    77  
    78  // IsFeatureEnabled 查询节点的某项功能是否启用
    79  func (node *MasterNode) IsFeatureEnabled(feature string) bool {
    80  	node.lock.RLock()
    81  	defer node.lock.RUnlock()
    82  
    83  	switch feature {
    84  	case "aria2":
    85  		return node.Model.Aria2Enabled
    86  	default:
    87  		return false
    88  	}
    89  }
    90  
    91  func (node *MasterNode) MasterAuthInstance() auth.Auth {
    92  	node.lock.RLock()
    93  	defer node.lock.RUnlock()
    94  
    95  	return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
    96  }
    97  
    98  func (node *MasterNode) SlaveAuthInstance() auth.Auth {
    99  	node.lock.RLock()
   100  	defer node.lock.RUnlock()
   101  
   102  	return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
   103  }
   104  
   105  // SubscribeStatusChange 订阅节点状态更改
   106  func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) {
   107  }
   108  
   109  // IsActive 返回节点是否在线
   110  func (node *MasterNode) IsActive() bool {
   111  	return true
   112  }
   113  
   114  // Kill 结束aria2请求
   115  func (node *MasterNode) Kill() {
   116  	if node.aria2RPC.Caller != nil {
   117  		node.aria2RPC.Caller.Close()
   118  	}
   119  }
   120  
   121  // GetAria2Instance 获取主机Aria2实例
   122  func (node *MasterNode) GetAria2Instance() common.Aria2 {
   123  	node.lock.RLock()
   124  
   125  	if !node.Model.Aria2Enabled {
   126  		node.lock.RUnlock()
   127  		return &common.DummyAria2{}
   128  	}
   129  
   130  	if !node.aria2RPC.Initialized {
   131  		node.lock.RUnlock()
   132  		node.aria2RPC.Init()
   133  		return &common.DummyAria2{}
   134  	}
   135  
   136  	defer node.lock.RUnlock()
   137  	return &node.aria2RPC
   138  }
   139  
   140  func (node *MasterNode) IsMater() bool {
   141  	return true
   142  }
   143  
   144  func (node *MasterNode) DBModel() *model.Node {
   145  	node.lock.RLock()
   146  	defer node.lock.RUnlock()
   147  
   148  	return node.Model
   149  }
   150  
   151  func (r *rpcService) Init() error {
   152  	r.parent.lock.Lock()
   153  	defer r.parent.lock.Unlock()
   154  	r.Initialized = false
   155  
   156  	// 客户端已存在,则关闭先前连接
   157  	if r.Caller != nil {
   158  		r.Caller.Close()
   159  	}
   160  
   161  	// 解析RPC服务地址
   162  	server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server)
   163  	if err != nil {
   164  		util.Log().Warning("Failed to parse Aria2 RPC server URL: %s", err)
   165  		return err
   166  	}
   167  	server.Path = "/jsonrpc"
   168  
   169  	// 加载自定义下载配置
   170  	var globalOptions map[string]interface{}
   171  	if r.parent.Model.Aria2OptionsSerialized.Options != "" {
   172  		err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions)
   173  		if err != nil {
   174  			util.Log().Warning("Failed to parse aria2 options: %s", err)
   175  			return err
   176  		}
   177  	}
   178  
   179  	r.options = &clientOptions{
   180  		Options: globalOptions,
   181  	}
   182  	timeout := r.parent.Model.Aria2OptionsSerialized.Timeout
   183  	caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ)
   184  
   185  	r.Caller = caller
   186  	r.Initialized = err == nil
   187  	return err
   188  }
   189  
   190  func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
   191  	r.parent.lock.RLock()
   192  	// 生成存储路径
   193  	guid, _ := uuid.NewV4()
   194  	path := filepath.Join(
   195  		r.parent.Model.Aria2OptionsSerialized.TempPath,
   196  		"aria2",
   197  		guid.String(),
   198  	)
   199  	r.parent.lock.RUnlock()
   200  
   201  	// 创建下载任务
   202  	options := map[string]interface{}{
   203  		"dir": path,
   204  	}
   205  	for k, v := range r.options.Options {
   206  		options[k] = v
   207  	}
   208  	for k, v := range groupOptions {
   209  		options[k] = v
   210  	}
   211  
   212  	gid, err := r.Caller.AddURI(task.Source, options)
   213  	if err != nil || gid == "" {
   214  		return "", err
   215  	}
   216  
   217  	return gid, nil
   218  }
   219  
   220  func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) {
   221  	res, err := r.Caller.TellStatus(task.GID)
   222  	if err != nil {
   223  		// 失败后重试
   224  		util.Log().Debug("Failed to get download task status, please retry later: %s", err)
   225  		time.Sleep(r.retryDuration)
   226  		res, err = r.Caller.TellStatus(task.GID)
   227  	}
   228  
   229  	return res, err
   230  }
   231  
   232  func (r *rpcService) Cancel(task *model.Download) error {
   233  	// 取消下载任务
   234  	_, err := r.Caller.Remove(task.GID)
   235  	if err != nil {
   236  		util.Log().Warning("Failed to cancel task %q: %s", task.GID, err)
   237  	}
   238  
   239  	return err
   240  }
   241  
   242  func (r *rpcService) Select(task *model.Download, files []int) error {
   243  	var selected = make([]string, len(files))
   244  	for i := 0; i < len(files); i++ {
   245  		selected[i] = strconv.Itoa(files[i])
   246  	}
   247  	_, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
   248  	return err
   249  }
   250  
   251  func (r *rpcService) GetConfig() model.Aria2Option {
   252  	r.parent.lock.RLock()
   253  	defer r.parent.lock.RUnlock()
   254  
   255  	return r.parent.Model.Aria2OptionsSerialized
   256  }
   257  
   258  func (s *rpcService) DeleteTempFile(task *model.Download) error {
   259  	s.parent.lock.RLock()
   260  	defer s.parent.lock.RUnlock()
   261  
   262  	// 避免被aria2占用,异步执行删除
   263  	go func(d time.Duration, src string) {
   264  		time.Sleep(d)
   265  		err := os.RemoveAll(src)
   266  		if err != nil {
   267  			util.Log().Warning("Failed to delete temp download folder: %q: %s", src, err)
   268  		}
   269  	}(s.deletePaddingDuration, task.Parent)
   270  
   271  	return nil
   272  }