github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/aria2/rpc/call.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"log"
     7  	"net"
     8  	"net/http"
     9  	"net/url"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/gorilla/websocket"
    15  )
    16  
    // caller abstracts the transport used to talk to the aria2 JSON-RPC daemon.
type caller interface {
	// Call issues the RPC named method with the given params and decodes the
	// result into reply.
	Call(method string, params interface{}, reply interface{}) error

	// Close releases whatever the transport holds (connections, goroutines).
	Close() error
}
    22  
    // httpCaller sends JSON-RPC requests to aria2 as HTTP POSTs. When a Notifier
// is passed to newHTTPCaller it may additionally hold a websocket used only
// for receiving notifications.
type httpCaller struct {
	uri string       // JSON-RPC endpoint every request is POSTed to
	c   *http.Client // client shared by all requests

	// Shutdown bookkeeping for the notifier goroutines started by setNotifier.
	cancel context.CancelFunc // stops the notifier goroutines
	wg     *sync.WaitGroup    // lets Close wait for the notifier goroutines
	once   sync.Once          // makes Close run its shutdown only once
}
    30  
    31  func newHTTPCaller(ctx context.Context, u *url.URL, timeout time.Duration, notifer Notifier) *httpCaller {
    32  	c := &http.Client{
    33  		Transport: &http.Transport{
    34  			MaxIdleConnsPerHost: 1,
    35  			MaxConnsPerHost:     1,
    36  			// TLSClientConfig:     tlsConfig,
    37  			Dial: (&net.Dialer{
    38  				Timeout:   timeout,
    39  				KeepAlive: 60 * time.Second,
    40  			}).Dial,
    41  			TLSHandshakeTimeout:   3 * time.Second,
    42  			ResponseHeaderTimeout: timeout,
    43  		},
    44  	}
    45  	var wg sync.WaitGroup
    46  	ctx, cancel := context.WithCancel(ctx)
    47  	h := &httpCaller{uri: u.String(), c: c, cancel: cancel, wg: &wg}
    48  	if notifer != nil {
    49  		h.setNotifier(ctx, *u, notifer)
    50  	}
    51  	return h
    52  }
    53  
    54  func (h *httpCaller) Close() (err error) {
    55  	h.once.Do(func() {
    56  		h.cancel()
    57  		h.wg.Wait()
    58  	})
    59  	return
    60  }
    61  
    62  func (h *httpCaller) setNotifier(ctx context.Context, u url.URL, notifer Notifier) (err error) {
    63  	u.Scheme = "ws"
    64  	conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
    65  	if err != nil {
    66  		return
    67  	}
    68  	h.wg.Add(1)
    69  	go func() {
    70  		defer h.wg.Done()
    71  		defer conn.Close()
    72  		select {
    73  		case <-ctx.Done():
    74  			conn.SetWriteDeadline(time.Now().Add(time.Second))
    75  			if err := conn.WriteMessage(websocket.CloseMessage,
    76  				websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
    77  				log.Printf("sending websocket close message: %v", err)
    78  			}
    79  			return
    80  		}
    81  	}()
    82  	h.wg.Add(1)
    83  	go func() {
    84  		defer h.wg.Done()
    85  		var request websocketResponse
    86  		var err error
    87  		for {
    88  			select {
    89  			case <-ctx.Done():
    90  				return
    91  			default:
    92  			}
    93  			if err = conn.ReadJSON(&request); err != nil {
    94  				select {
    95  				case <-ctx.Done():
    96  					return
    97  				default:
    98  				}
    99  				log.Printf("conn.ReadJSON|err:%v", err.Error())
   100  				return
   101  			}
   102  			switch request.Method {
   103  			case "aria2.onDownloadStart":
   104  				notifer.OnDownloadStart(request.Params)
   105  			case "aria2.onDownloadPause":
   106  				notifer.OnDownloadPause(request.Params)
   107  			case "aria2.onDownloadStop":
   108  				notifer.OnDownloadStop(request.Params)
   109  			case "aria2.onDownloadComplete":
   110  				notifer.OnDownloadComplete(request.Params)
   111  			case "aria2.onDownloadError":
   112  				notifer.OnDownloadError(request.Params)
   113  			case "aria2.onBtDownloadComplete":
   114  				notifer.OnBtDownloadComplete(request.Params)
   115  			default:
   116  				log.Printf("unexpected notification: %s", request.Method)
   117  			}
   118  		}
   119  	}()
   120  	return
   121  }
   122  
   123  func (h httpCaller) Call(method string, params, reply interface{}) (err error) {
   124  	payload, err := EncodeClientRequest(method, params)
   125  	if err != nil {
   126  		return
   127  	}
   128  	r, err := h.c.Post(h.uri, "application/json", payload)
   129  	if err != nil {
   130  		return
   131  	}
   132  	err = DecodeClientResponse(r.Body, &reply)
   133  	r.Body.Close()
   134  	return
   135  }
   136  
   137  type websocketCaller struct {
   138  	conn     *websocket.Conn
   139  	sendChan chan *sendRequest
   140  	cancel   context.CancelFunc
   141  	wg       *sync.WaitGroup
   142  	once     sync.Once
   143  	timeout  time.Duration
   144  }
   145  
   146  func newWebsocketCaller(ctx context.Context, uri string, timeout time.Duration, notifier Notifier) (*websocketCaller, error) {
   147  	var header = http.Header{}
   148  	conn, _, err := websocket.DefaultDialer.Dial(uri, header)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	sendChan := make(chan *sendRequest, 16)
   154  	var wg sync.WaitGroup
   155  	ctx, cancel := context.WithCancel(ctx)
   156  	w := &websocketCaller{conn: conn, wg: &wg, cancel: cancel, sendChan: sendChan, timeout: timeout}
   157  	processor := NewResponseProcessor()
   158  	wg.Add(1)
   159  	go func() { // routine:recv
   160  		defer wg.Done()
   161  		defer cancel()
   162  		for {
   163  			select {
   164  			case <-ctx.Done():
   165  				return
   166  			default:
   167  			}
   168  			var resp websocketResponse
   169  			if err := conn.ReadJSON(&resp); err != nil {
   170  				select {
   171  				case <-ctx.Done():
   172  					return
   173  				default:
   174  				}
   175  				log.Printf("conn.ReadJSON|err:%v", err.Error())
   176  				return
   177  			}
   178  			if resp.Id == nil { // RPC notifications
   179  				if notifier != nil {
   180  					switch resp.Method {
   181  					case "aria2.onDownloadStart":
   182  						notifier.OnDownloadStart(resp.Params)
   183  					case "aria2.onDownloadPause":
   184  						notifier.OnDownloadPause(resp.Params)
   185  					case "aria2.onDownloadStop":
   186  						notifier.OnDownloadStop(resp.Params)
   187  					case "aria2.onDownloadComplete":
   188  						notifier.OnDownloadComplete(resp.Params)
   189  					case "aria2.onDownloadError":
   190  						notifier.OnDownloadError(resp.Params)
   191  					case "aria2.onBtDownloadComplete":
   192  						notifier.OnBtDownloadComplete(resp.Params)
   193  					default:
   194  						log.Printf("unexpected notification: %s", resp.Method)
   195  					}
   196  				}
   197  				continue
   198  			}
   199  			processor.Process(resp.clientResponse)
   200  		}
   201  	}()
   202  	wg.Add(1)
   203  	go func() { // routine:send
   204  		defer wg.Done()
   205  		defer cancel()
   206  		defer w.conn.Close()
   207  
   208  		for {
   209  			select {
   210  			case <-ctx.Done():
   211  				if err := w.conn.WriteMessage(websocket.CloseMessage,
   212  					websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
   213  					log.Printf("sending websocket close message: %v", err)
   214  				}
   215  				return
   216  			case req := <-sendChan:
   217  				processor.Add(req.request.Id, func(resp clientResponse) error {
   218  					err := resp.decode(req.reply)
   219  					req.cancel()
   220  					return err
   221  				})
   222  				w.conn.SetWriteDeadline(time.Now().Add(timeout))
   223  				w.conn.WriteJSON(req.request)
   224  			}
   225  		}
   226  	}()
   227  
   228  	return w, nil
   229  }
   230  
   231  func (w *websocketCaller) Close() (err error) {
   232  	w.once.Do(func() {
   233  		w.cancel()
   234  		w.wg.Wait()
   235  	})
   236  	return
   237  }
   238  
   239  func (w websocketCaller) Call(method string, params, reply interface{}) (err error) {
   240  	ctx, cancel := context.WithTimeout(context.Background(), w.timeout)
   241  	defer cancel()
   242  	select {
   243  	case w.sendChan <- &sendRequest{cancel: cancel, request: &clientRequest{
   244  		Version: "2.0",
   245  		Method:  method,
   246  		Params:  params,
   247  		Id:      reqid(),
   248  	}, reply: reply}:
   249  
   250  	default:
   251  		return errors.New("sending channel blocking")
   252  	}
   253  
   254  	select {
   255  	case <-ctx.Done():
   256  		if err := ctx.Err(); err == context.DeadlineExceeded {
   257  			return err
   258  		}
   259  	}
   260  	return
   261  }
   262  
   263  type sendRequest struct {
   264  	cancel  context.CancelFunc
   265  	request *clientRequest
   266  	reply   interface{}
   267  }
   268  
   // reqid hands out JSON-RPC request ids. The counter starts at the current Unix
// time in nanoseconds, read once at package initialization. It is advanced
// atomically, so concurrent callers each get a distinct, increasing value.
var reqid = func() func() uint64 {
	counter := uint64(time.Now().UnixNano())
	next := func() uint64 {
		return atomic.AddUint64(&counter, 1)
	}
	return next
}()