github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/cluster/controller.go (about)

     1  package cluster
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/gob"
     6  	"fmt"
     7  	model "github.com/cloudreve/Cloudreve/v3/models"
     8  	"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
     9  	"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
    10  	"github.com/cloudreve/Cloudreve/v3/pkg/auth"
    11  	"github.com/cloudreve/Cloudreve/v3/pkg/mq"
    12  	"github.com/cloudreve/Cloudreve/v3/pkg/request"
    13  	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
    14  	"github.com/jinzhu/gorm"
    15  	"net/url"
    16  	"sync"
    17  )
    18  
    19  var DefaultController Controller
    20  
    21  // Controller controls communications between master and slave
    22  type Controller interface {
    23  	// Handle heartbeat sent from master
    24  	HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error)
    25  
    26  	// Get Aria2 Instance by master node ID
    27  	GetAria2Instance(string) (common.Aria2, error)
    28  
    29  	// Send event change message to master node
    30  	SendNotification(string, string, mq.Message) error
    31  
    32  	// Submit async task into task pool
    33  	SubmitTask(string, interface{}, string, func(interface{})) error
    34  
    35  	// Get master node info
    36  	GetMasterInfo(string) (*MasterInfo, error)
    37  
    38  	// Get master Oauth based policy credential
    39  	GetPolicyOauthToken(string, uint) (string, error)
    40  }
    41  
// slaveController is the Controller implementation installed by
// InitController. It keeps one MasterInfo per master, keyed by site ID.
type slaveController struct {
	// masters maps a master's site ID to its registration info; entries are
	// created or replaced by HandleHeartBeat.
	masters map[string]MasterInfo
	// lock guards masters. NOTE(review): the jobTracker maps reachable through
	// masters are presumably meant to be guarded by it as well — confirm.
	lock    sync.RWMutex
}
    46  
// MasterInfo holds what a slave node knows about one master node, as built
// from that master's heartbeat.
type MasterInfo struct {
	// ID is the master's site ID (the key under which it is registered).
	ID  string
	// TTL is the credential TTL reported in the heartbeat; the same value is
	// passed as the signing TTL when Client is built.
	TTL int
	// URL is the master's parsed site URL; Client uses it as its endpoint.
	URL *url.URL
	// Instance is a node handle built from the master's node record. It is
	// used to invoke aria2 rpc calls (see GetAria2Instance).
	Instance Node
	// Client sends requests to the master, signed with HMAC using the node's
	// master key.
	Client   request.Client

	// jobTracker records task hashes already submitted through SubmitTask so
	// duplicates are skipped. Nothing in this file removes entries.
	jobTracker map[string]bool
}
    58  
    59  func InitController() {
    60  	DefaultController = &slaveController{
    61  		masters: make(map[string]MasterInfo),
    62  	}
    63  	gob.Register(rpc.StatusInfo{})
    64  }
    65  
    66  func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) {
    67  	c.lock.Lock()
    68  	defer c.lock.Unlock()
    69  
    70  	req.Node.AfterFind()
    71  
    72  	// close old node if exist
    73  	origin, ok := c.masters[req.SiteID]
    74  
    75  	if (ok && req.IsUpdate) || !ok {
    76  		if ok {
    77  			origin.Instance.Kill()
    78  		}
    79  
    80  		masterUrl, err := url.Parse(req.SiteURL)
    81  		if err != nil {
    82  			return serializer.NodePingResp{}, err
    83  		}
    84  
    85  		c.masters[req.SiteID] = MasterInfo{
    86  			ID:  req.SiteID,
    87  			URL: masterUrl,
    88  			TTL: req.CredentialTTL,
    89  			Client: request.NewClient(
    90  				request.WithEndpoint(masterUrl.String()),
    91  				request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)),
    92  				request.WithCredential(auth.HMACAuth{
    93  					SecretKey: []byte(req.Node.MasterKey),
    94  				}, int64(req.CredentialTTL)),
    95  			),
    96  			jobTracker: make(map[string]bool),
    97  			Instance: NewNodeFromDBModel(&model.Node{
    98  				Model:                  gorm.Model{ID: req.Node.ID},
    99  				MasterKey:              req.Node.MasterKey,
   100  				Type:                   model.MasterNodeType,
   101  				Aria2Enabled:           req.Node.Aria2Enabled,
   102  				Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized,
   103  			}),
   104  		}
   105  	}
   106  
   107  	return serializer.NodePingResp{}, nil
   108  }
   109  
   110  func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
   111  	c.lock.RLock()
   112  	defer c.lock.RUnlock()
   113  
   114  	if node, ok := c.masters[id]; ok {
   115  		return node.Instance.GetAria2Instance(), nil
   116  	}
   117  
   118  	return nil, ErrMasterNotFound
   119  }
   120  
   121  func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error {
   122  	c.lock.RLock()
   123  
   124  	if node, ok := c.masters[id]; ok {
   125  		c.lock.RUnlock()
   126  
   127  		body := bytes.Buffer{}
   128  		enc := gob.NewEncoder(&body)
   129  		if err := enc.Encode(&msg); err != nil {
   130  			return err
   131  		}
   132  
   133  		res, err := node.Client.Request(
   134  			"PUT",
   135  			fmt.Sprintf("/api/v3/slave/notification/%s", subject),
   136  			&body,
   137  		).CheckHTTPResponse(200).DecodeResponse()
   138  		if err != nil {
   139  			return err
   140  		}
   141  
   142  		if res.Code != 0 {
   143  			return serializer.NewErrorFromResponse(res)
   144  		}
   145  
   146  		return nil
   147  	}
   148  
   149  	c.lock.RUnlock()
   150  	return ErrMasterNotFound
   151  }
   152  
   153  // SubmitTask 提交异步任务
   154  func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error {
   155  	c.lock.RLock()
   156  	defer c.lock.RUnlock()
   157  
   158  	if node, ok := c.masters[id]; ok {
   159  		if _, ok := node.jobTracker[hash]; ok {
   160  			// 任务已存在,直接返回
   161  			return nil
   162  		}
   163  
   164  		node.jobTracker[hash] = true
   165  		submitter(job)
   166  		return nil
   167  	}
   168  
   169  	return ErrMasterNotFound
   170  }
   171  
   172  // GetMasterInfo 获取主机节点信息
   173  func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) {
   174  	c.lock.RLock()
   175  	defer c.lock.RUnlock()
   176  
   177  	if node, ok := c.masters[id]; ok {
   178  		return &node, nil
   179  	}
   180  
   181  	return nil, ErrMasterNotFound
   182  }
   183  
   184  // GetPolicyOauthToken 获取主机存储策略 Oauth 凭证
   185  func (c *slaveController) GetPolicyOauthToken(id string, policyID uint) (string, error) {
   186  	c.lock.RLock()
   187  
   188  	if node, ok := c.masters[id]; ok {
   189  		c.lock.RUnlock()
   190  
   191  		res, err := node.Client.Request(
   192  			"GET",
   193  			fmt.Sprintf("/api/v3/slave/credential/%d", policyID),
   194  			nil,
   195  		).CheckHTTPResponse(200).DecodeResponse()
   196  		if err != nil {
   197  			return "", err
   198  		}
   199  
   200  		if res.Code != 0 {
   201  			return "", serializer.NewErrorFromResponse(res)
   202  		}
   203  
   204  		return res.Data.(string), nil
   205  	}
   206  
   207  	c.lock.RUnlock()
   208  	return "", ErrMasterNotFound
   209  }