github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/aria2/rpc/call.go (about) 1 package rpc 2 3 import ( 4 "context" 5 "errors" 6 "log" 7 "net" 8 "net/http" 9 "net/url" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/gorilla/websocket" 15 ) 16 17 type caller interface { 18 // Call sends a request of rpc to aria2 daemon 19 Call(method string, params, reply interface{}) (err error) 20 Close() error 21 } 22 23 type httpCaller struct { 24 uri string 25 c *http.Client 26 cancel context.CancelFunc 27 wg *sync.WaitGroup 28 once sync.Once 29 } 30 31 func newHTTPCaller(ctx context.Context, u *url.URL, timeout time.Duration, notifer Notifier) *httpCaller { 32 c := &http.Client{ 33 Transport: &http.Transport{ 34 MaxIdleConnsPerHost: 1, 35 MaxConnsPerHost: 1, 36 // TLSClientConfig: tlsConfig, 37 Dial: (&net.Dialer{ 38 Timeout: timeout, 39 KeepAlive: 60 * time.Second, 40 }).Dial, 41 TLSHandshakeTimeout: 3 * time.Second, 42 ResponseHeaderTimeout: timeout, 43 }, 44 } 45 var wg sync.WaitGroup 46 ctx, cancel := context.WithCancel(ctx) 47 h := &httpCaller{uri: u.String(), c: c, cancel: cancel, wg: &wg} 48 if notifer != nil { 49 h.setNotifier(ctx, *u, notifer) 50 } 51 return h 52 } 53 54 func (h *httpCaller) Close() (err error) { 55 h.once.Do(func() { 56 h.cancel() 57 h.wg.Wait() 58 }) 59 return 60 } 61 62 func (h *httpCaller) setNotifier(ctx context.Context, u url.URL, notifer Notifier) (err error) { 63 u.Scheme = "ws" 64 conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) 65 if err != nil { 66 return 67 } 68 h.wg.Add(1) 69 go func() { 70 defer h.wg.Done() 71 defer conn.Close() 72 select { 73 case <-ctx.Done(): 74 conn.SetWriteDeadline(time.Now().Add(time.Second)) 75 if err := conn.WriteMessage(websocket.CloseMessage, 76 websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil { 77 log.Printf("sending websocket close message: %v", err) 78 } 79 return 80 } 81 }() 82 h.wg.Add(1) 83 go func() { 84 defer h.wg.Done() 85 var request websocketResponse 86 var err error 87 for { 88 select { 89 case <-ctx.Done(): 90 return 91 default: 92 } 93 if err = conn.ReadJSON(&request); err != nil { 94 select { 95 case <-ctx.Done(): 96 return 97 default: 98 } 99 log.Printf("conn.ReadJSON|err:%v", err.Error()) 100 return 101 } 102 switch request.Method { 103 case "aria2.onDownloadStart": 104 notifer.OnDownloadStart(request.Params) 105 case "aria2.onDownloadPause": 106 notifer.OnDownloadPause(request.Params) 107 case "aria2.onDownloadStop": 108 notifer.OnDownloadStop(request.Params) 109 case "aria2.onDownloadComplete": 110 notifer.OnDownloadComplete(request.Params) 111 case "aria2.onDownloadError": 112 notifer.OnDownloadError(request.Params) 113 case "aria2.onBtDownloadComplete": 114 notifer.OnBtDownloadComplete(request.Params) 115 default: 116 log.Printf("unexpected notification: %s", request.Method) 117 } 118 } 119 }() 120 return 121 } 122 123 func (h httpCaller) Call(method string, params, reply interface{}) (err error) { 124 payload, err := EncodeClientRequest(method, params) 125 if err != nil { 126 return 127 } 128 r, err := h.c.Post(h.uri, "application/json", payload) 129 if err != nil { 130 return 131 } 132 err = DecodeClientResponse(r.Body, &reply) 133 r.Body.Close() 134 return 135 } 136 137 type websocketCaller struct { 138 conn *websocket.Conn 139 sendChan chan *sendRequest 140 cancel context.CancelFunc 141 wg *sync.WaitGroup 142 once sync.Once 143 timeout time.Duration 144 } 145 146 func newWebsocketCaller(ctx context.Context, uri string, timeout time.Duration, notifier Notifier) (*websocketCaller, error) { 147 var header = http.Header{} 148 conn, _, err := websocket.DefaultDialer.Dial(uri, header) 149 if err != nil { 150 return nil, err 151 } 152 153 sendChan := make(chan *sendRequest, 16) 154 var wg sync.WaitGroup 155 ctx, cancel := context.WithCancel(ctx) 156 w := &websocketCaller{conn: conn, wg: &wg, cancel: cancel, sendChan: sendChan, timeout: timeout} 157 processor := NewResponseProcessor() 158 wg.Add(1) 159 go func() { // routine:recv 160 defer wg.Done() 161 defer cancel() 162 for { 163 select { 164 case <-ctx.Done(): 165 return 166 default: 167 } 168 var resp websocketResponse 169 if err := conn.ReadJSON(&resp); err != nil { 170 select { 171 case <-ctx.Done(): 172 return 173 default: 174 } 175 log.Printf("conn.ReadJSON|err:%v", err.Error()) 176 return 177 } 178 if resp.Id == nil { // RPC notifications 179 if notifier != nil { 180 switch resp.Method { 181 case "aria2.onDownloadStart": 182 notifier.OnDownloadStart(resp.Params) 183 case "aria2.onDownloadPause": 184 notifier.OnDownloadPause(resp.Params) 185 case "aria2.onDownloadStop": 186 notifier.OnDownloadStop(resp.Params) 187 case "aria2.onDownloadComplete": 188 notifier.OnDownloadComplete(resp.Params) 189 case "aria2.onDownloadError": 190 notifier.OnDownloadError(resp.Params) 191 case "aria2.onBtDownloadComplete": 192 notifier.OnBtDownloadComplete(resp.Params) 193 default: 194 log.Printf("unexpected notification: %s", resp.Method) 195 } 196 } 197 continue 198 } 199 processor.Process(resp.clientResponse) 200 } 201 }() 202 wg.Add(1) 203 go func() { // routine:send 204 defer wg.Done() 205 defer cancel() 206 defer w.conn.Close() 207 208 for { 209 select { 210 case <-ctx.Done(): 211 if err := w.conn.WriteMessage(websocket.CloseMessage, 212 websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil { 213 log.Printf("sending websocket close message: %v", err) 214 } 215 return 216 case req := <-sendChan: 217 processor.Add(req.request.Id, func(resp clientResponse) error { 218 err := resp.decode(req.reply) 219 req.cancel() 220 return err 221 }) 222 w.conn.SetWriteDeadline(time.Now().Add(timeout)) 223 w.conn.WriteJSON(req.request) 224 } 225 } 226 }() 227 228 return w, nil 229 } 230 231 func (w *websocketCaller) Close() (err error) { 232 w.once.Do(func() { 233 w.cancel() 234 w.wg.Wait() 235 }) 236 return 237 } 238 239 func (w websocketCaller) Call(method string, params, reply interface{}) (err error) { 240 ctx, cancel := context.WithTimeout(context.Background(), w.timeout) 241 defer cancel() 242 select { 243 case w.sendChan <- &sendRequest{cancel: cancel, request: &clientRequest{ 244 Version: "2.0", 245 Method: method, 246 Params: params, 247 Id: reqid(), 248 }, reply: reply}: 249 250 default: 251 return errors.New("sending channel blocking") 252 } 253 254 select { 255 case <-ctx.Done(): 256 if err := ctx.Err(); err == context.DeadlineExceeded { 257 return err 258 } 259 } 260 return 261 } 262 263 type sendRequest struct { 264 cancel context.CancelFunc 265 request *clientRequest 266 reply interface{} 267 } 268 269 var reqid = func() func() uint64 { 270 var id = uint64(time.Now().UnixNano()) 271 return func() uint64 { 272 return atomic.AddUint64(&id, 1) 273 } 274 }()