github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/cluster/controller.go (about) 1 package cluster 2 3 import ( 4 "bytes" 5 "encoding/gob" 6 "fmt" 7 model "github.com/cloudreve/Cloudreve/v3/models" 8 "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" 9 "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" 10 "github.com/cloudreve/Cloudreve/v3/pkg/auth" 11 "github.com/cloudreve/Cloudreve/v3/pkg/mq" 12 "github.com/cloudreve/Cloudreve/v3/pkg/request" 13 "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 14 "github.com/jinzhu/gorm" 15 "net/url" 16 "sync" 17 ) 18 19 var DefaultController Controller 20 21 // Controller controls communications between master and slave 22 type Controller interface { 23 // Handle heartbeat sent from master 24 HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error) 25 26 // Get Aria2 Instance by master node ID 27 GetAria2Instance(string) (common.Aria2, error) 28 29 // Send event change message to master node 30 SendNotification(string, string, mq.Message) error 31 32 // Submit async task into task pool 33 SubmitTask(string, interface{}, string, func(interface{})) error 34 35 // Get master node info 36 GetMasterInfo(string) (*MasterInfo, error) 37 38 // Get master Oauth based policy credential 39 GetPolicyOauthToken(string, uint) (string, error) 40 } 41 42 type slaveController struct { 43 masters map[string]MasterInfo 44 lock sync.RWMutex 45 } 46 47 // info of master node 48 type MasterInfo struct { 49 ID string 50 TTL int 51 URL *url.URL 52 // used to invoke aria2 rpc calls 53 Instance Node 54 Client request.Client 55 56 jobTracker map[string]bool 57 } 58 59 func InitController() { 60 DefaultController = &slaveController{ 61 masters: make(map[string]MasterInfo), 62 } 63 gob.Register(rpc.StatusInfo{}) 64 } 65 66 func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) { 67 c.lock.Lock() 68 defer c.lock.Unlock() 69 70 req.Node.AfterFind() 71 72 // close old node if exist 73 origin, ok := c.masters[req.SiteID] 74 75 if (ok && req.IsUpdate) || !ok { 76 if ok { 77 origin.Instance.Kill() 78 } 79 80 masterUrl, err := url.Parse(req.SiteURL) 81 if err != nil { 82 return serializer.NodePingResp{}, err 83 } 84 85 c.masters[req.SiteID] = MasterInfo{ 86 ID: req.SiteID, 87 URL: masterUrl, 88 TTL: req.CredentialTTL, 89 Client: request.NewClient( 90 request.WithEndpoint(masterUrl.String()), 91 request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)), 92 request.WithCredential(auth.HMACAuth{ 93 SecretKey: []byte(req.Node.MasterKey), 94 }, int64(req.CredentialTTL)), 95 ), 96 jobTracker: make(map[string]bool), 97 Instance: NewNodeFromDBModel(&model.Node{ 98 Model: gorm.Model{ID: req.Node.ID}, 99 MasterKey: req.Node.MasterKey, 100 Type: model.MasterNodeType, 101 Aria2Enabled: req.Node.Aria2Enabled, 102 Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized, 103 }), 104 } 105 } 106 107 return serializer.NodePingResp{}, nil 108 } 109 110 func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) { 111 c.lock.RLock() 112 defer c.lock.RUnlock() 113 114 if node, ok := c.masters[id]; ok { 115 return node.Instance.GetAria2Instance(), nil 116 } 117 118 return nil, ErrMasterNotFound 119 } 120 121 func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error { 122 c.lock.RLock() 123 124 if node, ok := c.masters[id]; ok { 125 c.lock.RUnlock() 126 127 body := bytes.Buffer{} 128 enc := gob.NewEncoder(&body) 129 if err := enc.Encode(&msg); err != nil { 130 return err 131 } 132 133 res, err := node.Client.Request( 134 "PUT", 135 fmt.Sprintf("/api/v3/slave/notification/%s", subject), 136 &body, 137 ).CheckHTTPResponse(200).DecodeResponse() 138 if err != nil { 139 return err 140 } 141 142 if res.Code != 0 { 143 return serializer.NewErrorFromResponse(res) 144 } 145 146 return nil 147 } 148 149 c.lock.RUnlock() 150 return ErrMasterNotFound 151 } 152 153 // SubmitTask 提交异步任务 154 func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error { 155 c.lock.RLock() 156 defer c.lock.RUnlock() 157 158 if node, ok := c.masters[id]; ok { 159 if _, ok := node.jobTracker[hash]; ok { 160 // 任务已存在,直接返回 161 return nil 162 } 163 164 node.jobTracker[hash] = true 165 submitter(job) 166 return nil 167 } 168 169 return ErrMasterNotFound 170 } 171 172 // GetMasterInfo 获取主机节点信息 173 func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) { 174 c.lock.RLock() 175 defer c.lock.RUnlock() 176 177 if node, ok := c.masters[id]; ok { 178 return &node, nil 179 } 180 181 return nil, ErrMasterNotFound 182 } 183 184 // GetPolicyOauthToken 获取主机存储策略 Oauth 凭证 185 func (c *slaveController) GetPolicyOauthToken(id string, policyID uint) (string, error) { 186 c.lock.RLock() 187 188 if node, ok := c.masters[id]; ok { 189 c.lock.RUnlock() 190 191 res, err := node.Client.Request( 192 "GET", 193 fmt.Sprintf("/api/v3/slave/credential/%d", policyID), 194 nil, 195 ).CheckHTTPResponse(200).DecodeResponse() 196 if err != nil { 197 return "", err 198 } 199 200 if res.Code != 0 { 201 return "", serializer.NewErrorFromResponse(res) 202 } 203 204 return res.Data.(string), nil 205 } 206 207 c.lock.RUnlock() 208 return "", ErrMasterNotFound 209 }