github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/cluster/slave.go (about)

     1  package cluster
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	model "github.com/cloudreve/Cloudreve/v3/models"
     9  	"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
    10  	"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
    11  	"github.com/cloudreve/Cloudreve/v3/pkg/auth"
    12  	"github.com/cloudreve/Cloudreve/v3/pkg/conf"
    13  	"github.com/cloudreve/Cloudreve/v3/pkg/request"
    14  	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
    15  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
    16  	"io"
    17  	"net/url"
    18  	"strings"
    19  	"sync"
    20  	"time"
    21  )
    22  
    23  type SlaveNode struct {
    24  	Model  *model.Node
    25  	Active bool
    26  
    27  	caller   slaveCaller
    28  	callback func(bool, uint)
    29  	close    chan bool
    30  	lock     sync.RWMutex
    31  }
    32  
    33  type slaveCaller struct {
    34  	parent *SlaveNode
    35  	Client request.Client
    36  }
    37  
    38  // Init 初始化节点
    39  func (node *SlaveNode) Init(nodeModel *model.Node) {
    40  	node.lock.Lock()
    41  	node.Model = nodeModel
    42  
    43  	// Init http request client
    44  	var endpoint *url.URL
    45  	if serverURL, err := url.Parse(node.Model.Server); err == nil {
    46  		var controller *url.URL
    47  		controller, _ = url.Parse("/api/v3/slave/")
    48  		endpoint = serverURL.ResolveReference(controller)
    49  	}
    50  
    51  	signTTL := model.GetIntSetting("slave_api_timeout", 60)
    52  	node.caller.Client = request.NewClient(
    53  		request.WithMasterMeta(),
    54  		request.WithTimeout(time.Duration(signTTL)*time.Second),
    55  		request.WithCredential(auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}, int64(signTTL)),
    56  		request.WithEndpoint(endpoint.String()),
    57  	)
    58  
    59  	node.caller.parent = node
    60  	if node.close != nil {
    61  		node.lock.Unlock()
    62  		node.close <- true
    63  		go node.StartPingLoop()
    64  	} else {
    65  		node.Active = true
    66  		node.lock.Unlock()
    67  		go node.StartPingLoop()
    68  	}
    69  }
    70  
    71  // IsFeatureEnabled 查询节点的某项功能是否启用
    72  func (node *SlaveNode) IsFeatureEnabled(feature string) bool {
    73  	node.lock.RLock()
    74  	defer node.lock.RUnlock()
    75  
    76  	switch feature {
    77  	case "aria2":
    78  		return node.Model.Aria2Enabled
    79  	default:
    80  		return false
    81  	}
    82  }
    83  
    84  // SubscribeStatusChange 订阅节点状态更改
    85  func (node *SlaveNode) SubscribeStatusChange(callback func(bool, uint)) {
    86  	node.lock.Lock()
    87  	node.callback = callback
    88  	node.lock.Unlock()
    89  }
    90  
    91  // Ping 从机节点,返回从机负载
    92  func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
    93  	node.lock.RLock()
    94  	defer node.lock.RUnlock()
    95  
    96  	reqBodyEncoded, err := json.Marshal(req)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	bodyReader := strings.NewReader(string(reqBodyEncoded))
   102  
   103  	resp, err := node.caller.Client.Request(
   104  		"POST",
   105  		"heartbeat",
   106  		bodyReader,
   107  	).CheckHTTPResponse(200).DecodeResponse()
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	// 处理列取结果
   113  	if resp.Code != 0 {
   114  		return nil, serializer.NewErrorFromResponse(resp)
   115  	}
   116  
   117  	var res serializer.NodePingResp
   118  
   119  	if resStr, ok := resp.Data.(string); ok {
   120  		err = json.Unmarshal([]byte(resStr), &res)
   121  		if err != nil {
   122  			return nil, err
   123  		}
   124  	}
   125  
   126  	return &res, nil
   127  }
   128  
   129  // IsActive 返回节点是否在线
   130  func (node *SlaveNode) IsActive() bool {
   131  	node.lock.RLock()
   132  	defer node.lock.RUnlock()
   133  
   134  	return node.Active
   135  }
   136  
   137  // Kill 结束节点内相关循环
   138  func (node *SlaveNode) Kill() {
   139  	node.lock.RLock()
   140  	defer node.lock.RUnlock()
   141  
   142  	if node.close != nil {
   143  		close(node.close)
   144  	}
   145  }
   146  
   147  // GetAria2Instance 获取从机Aria2实例
   148  func (node *SlaveNode) GetAria2Instance() common.Aria2 {
   149  	node.lock.RLock()
   150  	defer node.lock.RUnlock()
   151  
   152  	if !node.Model.Aria2Enabled {
   153  		return &common.DummyAria2{}
   154  	}
   155  
   156  	return &node.caller
   157  }
   158  
   159  func (node *SlaveNode) ID() uint {
   160  	node.lock.RLock()
   161  	defer node.lock.RUnlock()
   162  
   163  	return node.Model.ID
   164  }
   165  
   166  func (node *SlaveNode) StartPingLoop() {
   167  	node.lock.Lock()
   168  	node.close = make(chan bool)
   169  	node.lock.Unlock()
   170  
   171  	tickDuration := time.Duration(model.GetIntSetting("slave_ping_interval", 300)) * time.Second
   172  	recoverDuration := time.Duration(model.GetIntSetting("slave_recover_interval", 600)) * time.Second
   173  	pingTicker := time.Duration(0)
   174  
   175  	util.Log().Debug("Slave node %q heartbeat loop started.", node.Model.Name)
   176  	retry := 0
   177  	recoverMode := false
   178  	isFirstLoop := true
   179  
   180  loop:
   181  	for {
   182  		select {
   183  		case <-time.After(pingTicker):
   184  			if pingTicker == 0 {
   185  				pingTicker = tickDuration
   186  			}
   187  
   188  			util.Log().Debug("Slave node %q send ping.", node.Model.Name)
   189  			res, err := node.Ping(node.getHeartbeatContent(isFirstLoop))
   190  			isFirstLoop = false
   191  
   192  			if err != nil {
   193  				util.Log().Debug("Error while ping slave node %q: %s", node.Model.Name, err)
   194  				retry++
   195  				if retry >= model.GetIntSetting("slave_node_retry", 3) {
   196  					util.Log().Debug("Retry threshold for pinging slave node %q exceeded, mark it as offline.", node.Model.Name)
   197  					node.changeStatus(false)
   198  
   199  					if !recoverMode {
   200  						// 启动恢复监控循环
   201  						util.Log().Debug("Slave node %q entered recovery mode.", node.Model.Name)
   202  						pingTicker = recoverDuration
   203  						recoverMode = true
   204  					}
   205  				}
   206  			} else {
   207  				if recoverMode {
   208  					util.Log().Debug("Slave node %q recovered.", node.Model.Name)
   209  					pingTicker = tickDuration
   210  					recoverMode = false
   211  					isFirstLoop = true
   212  				}
   213  
   214  				util.Log().Debug("Status of slave node %q: %s", node.Model.Name, res)
   215  				node.changeStatus(true)
   216  				retry = 0
   217  			}
   218  
   219  		case <-node.close:
   220  			util.Log().Debug("Slave node %q received shutdown signal.", node.Model.Name)
   221  			break loop
   222  		}
   223  	}
   224  }
   225  
   226  func (node *SlaveNode) IsMater() bool {
   227  	return false
   228  }
   229  
   230  func (node *SlaveNode) MasterAuthInstance() auth.Auth {
   231  	node.lock.RLock()
   232  	defer node.lock.RUnlock()
   233  
   234  	return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
   235  }
   236  
   237  func (node *SlaveNode) SlaveAuthInstance() auth.Auth {
   238  	node.lock.RLock()
   239  	defer node.lock.RUnlock()
   240  
   241  	return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
   242  }
   243  
   244  func (node *SlaveNode) DBModel() *model.Node {
   245  	node.lock.RLock()
   246  	defer node.lock.RUnlock()
   247  
   248  	return node.Model
   249  }
   250  
   251  // getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
   252  func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
   253  	return &serializer.NodePingReq{
   254  		SiteURL:       model.GetSiteURL().String(),
   255  		IsUpdate:      isUpdate,
   256  		SiteID:        model.GetSettingByName("siteID"),
   257  		Node:          node.Model,
   258  		CredentialTTL: model.GetIntSetting("slave_api_timeout", 60),
   259  	}
   260  }
   261  
   262  func (node *SlaveNode) changeStatus(isActive bool) {
   263  	node.lock.RLock()
   264  	id := node.Model.ID
   265  	if isActive != node.Active {
   266  		node.lock.RUnlock()
   267  		node.lock.Lock()
   268  		node.Active = isActive
   269  		node.lock.Unlock()
   270  		node.callback(isActive, id)
   271  	} else {
   272  		node.lock.RUnlock()
   273  	}
   274  
   275  }
   276  
   277  func (s *slaveCaller) Init() error {
   278  	return nil
   279  }
   280  
   281  // SendAria2Call send remote aria2 call to slave node
   282  func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) {
   283  	reqReader, err := getAria2RequestBody(body)
   284  	if err != nil {
   285  		return nil, err
   286  	}
   287  
   288  	return s.Client.Request(
   289  		"POST",
   290  		"aria2/"+scope,
   291  		reqReader,
   292  	).CheckHTTPResponse(200).DecodeResponse()
   293  }
   294  
   295  func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
   296  	s.parent.lock.RLock()
   297  	defer s.parent.lock.RUnlock()
   298  
   299  	req := &serializer.SlaveAria2Call{
   300  		Task:         task,
   301  		GroupOptions: options,
   302  	}
   303  
   304  	res, err := s.SendAria2Call(req, "task")
   305  	if err != nil {
   306  		return "", err
   307  	}
   308  
   309  	if res.Code != 0 {
   310  		return "", serializer.NewErrorFromResponse(res)
   311  	}
   312  
   313  	return res.Data.(string), err
   314  }
   315  
   316  func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) {
   317  	s.parent.lock.RLock()
   318  	defer s.parent.lock.RUnlock()
   319  
   320  	req := &serializer.SlaveAria2Call{
   321  		Task: task,
   322  	}
   323  
   324  	res, err := s.SendAria2Call(req, "status")
   325  	if err != nil {
   326  		return rpc.StatusInfo{}, err
   327  	}
   328  
   329  	if res.Code != 0 {
   330  		return rpc.StatusInfo{}, serializer.NewErrorFromResponse(res)
   331  	}
   332  
   333  	var status rpc.StatusInfo
   334  	res.GobDecode(&status)
   335  
   336  	return status, err
   337  }
   338  
   339  func (s *slaveCaller) Cancel(task *model.Download) error {
   340  	s.parent.lock.RLock()
   341  	defer s.parent.lock.RUnlock()
   342  
   343  	req := &serializer.SlaveAria2Call{
   344  		Task: task,
   345  	}
   346  
   347  	res, err := s.SendAria2Call(req, "cancel")
   348  	if err != nil {
   349  		return err
   350  	}
   351  
   352  	if res.Code != 0 {
   353  		return serializer.NewErrorFromResponse(res)
   354  	}
   355  
   356  	return nil
   357  }
   358  
   359  func (s *slaveCaller) Select(task *model.Download, files []int) error {
   360  	s.parent.lock.RLock()
   361  	defer s.parent.lock.RUnlock()
   362  
   363  	req := &serializer.SlaveAria2Call{
   364  		Task:  task,
   365  		Files: files,
   366  	}
   367  
   368  	res, err := s.SendAria2Call(req, "select")
   369  	if err != nil {
   370  		return err
   371  	}
   372  
   373  	if res.Code != 0 {
   374  		return serializer.NewErrorFromResponse(res)
   375  	}
   376  
   377  	return nil
   378  }
   379  
   380  func (s *slaveCaller) GetConfig() model.Aria2Option {
   381  	s.parent.lock.RLock()
   382  	defer s.parent.lock.RUnlock()
   383  
   384  	return s.parent.Model.Aria2OptionsSerialized
   385  }
   386  
   387  func (s *slaveCaller) DeleteTempFile(task *model.Download) error {
   388  	s.parent.lock.RLock()
   389  	defer s.parent.lock.RUnlock()
   390  
   391  	req := &serializer.SlaveAria2Call{
   392  		Task: task,
   393  	}
   394  
   395  	res, err := s.SendAria2Call(req, "delete")
   396  	if err != nil {
   397  		return err
   398  	}
   399  
   400  	if res.Code != 0 {
   401  		return serializer.NewErrorFromResponse(res)
   402  	}
   403  
   404  	return nil
   405  }
   406  
   407  func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
   408  	reqBodyEncoded, err := json.Marshal(body)
   409  	if err != nil {
   410  		return nil, err
   411  	}
   412  
   413  	return strings.NewReader(string(reqBodyEncoded)), nil
   414  }
   415  
   416  // RemoteCallback 发送远程存储策略上传回调请求
   417  func RemoteCallback(url string, body serializer.UploadCallback) error {
   418  	callbackBody, err := json.Marshal(struct {
   419  		Data serializer.UploadCallback `json:"data"`
   420  	}{
   421  		Data: body,
   422  	})
   423  	if err != nil {
   424  		return serializer.NewError(serializer.CodeCallbackError, "Failed to encode callback content", err)
   425  	}
   426  
   427  	resp := request.GeneralClient.Request(
   428  		"POST",
   429  		url,
   430  		bytes.NewReader(callbackBody),
   431  		request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
   432  		request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
   433  	)
   434  
   435  	if resp.Err != nil {
   436  		return serializer.NewError(serializer.CodeCallbackError, "Slave cannot send callback request", resp.Err)
   437  	}
   438  
   439  	// 解析回调服务端响应
   440  	response, err := resp.DecodeResponse()
   441  	if err != nil {
   442  		msg := fmt.Sprintf("Slave cannot parse callback response from master (StatusCode=%d).", resp.Response.StatusCode)
   443  		return serializer.NewError(serializer.CodeCallbackError, msg, err)
   444  	}
   445  
   446  	if response.Code != 0 {
   447  		return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
   448  	}
   449  
   450  	return nil
   451  }