github.com/Mrs4s/go-cqhttp@v1.2.0/server/websocket.go (about)

     1  package server
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"fmt"
     8  	"net"
     9  	"net/http"
    10  	"net/url"
    11  	"runtime/debug"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/Mrs4s/MiraiGo/utils"
    18  	"github.com/RomiChan/websocket"
    19  	log "github.com/sirupsen/logrus"
    20  	"github.com/tidwall/gjson"
    21  	"gopkg.in/yaml.v3"
    22  
    23  	"github.com/Mrs4s/go-cqhttp/coolq"
    24  	"github.com/Mrs4s/go-cqhttp/global"
    25  	"github.com/Mrs4s/go-cqhttp/modules/api"
    26  	"github.com/Mrs4s/go-cqhttp/modules/config"
    27  	"github.com/Mrs4s/go-cqhttp/modules/filter"
    28  	"github.com/Mrs4s/go-cqhttp/pkg/onebot"
    29  )
    30  
    31  type webSocketServer struct {
    32  	bot  *coolq.CQBot
    33  	conf *WebsocketServer
    34  
    35  	mu        sync.Mutex
    36  	eventConn []*wsConn
    37  
    38  	token     string
    39  	handshake string
    40  	filter    string
    41  }
    42  
    43  // websocketClient WebSocket客户端实例
    44  type websocketClient struct {
    45  	bot       *coolq.CQBot
    46  	mu        sync.Mutex
    47  	universal *wsConn
    48  	event     *wsConn
    49  
    50  	token             string
    51  	filter            string
    52  	reconnectInterval time.Duration
    53  	limiter           api.Handler
    54  }
    55  
    56  type wsConn struct {
    57  	mu        sync.Mutex
    58  	conn      *websocket.Conn
    59  	apiCaller *api.Caller
    60  }
    61  
    62  func (c *wsConn) WriteText(b []byte) error {
    63  	c.mu.Lock()
    64  	defer c.mu.Unlock()
    65  	_ = c.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
    66  	return c.conn.WriteMessage(websocket.TextMessage, b)
    67  }
    68  
    69  func (c *wsConn) Close() error {
    70  	return c.conn.Close()
    71  }
    72  
    73  var upgrader = websocket.Upgrader{
    74  	CheckOrigin: func(r *http.Request) bool {
    75  		return true
    76  	},
    77  }
    78  
// wsDefault is the YAML snippet for a forward WebSocket server entry. init
// registers it as the Default of a config.Server; presumably it is emitted
// when a default configuration file is generated — confirm in config.
// It is runtime text, so its indentation and content must stay intact. The
// Chinese YAML comments read: "forward WS settings", "forward WS server
// listen address", "reference the default middlewares".
const wsDefault = `  # 正向WS设置
  - ws:
      # 正向WS服务器监听地址
      address: 0.0.0.0:8080
      middlewares:
        <<: *default # 引用默认中间件
`
    86  
// wsReverseDefault is the YAML snippet for a reverse WebSocket client entry.
// init registers it as the Default of a config.Server; presumably it is
// emitted when a default configuration file is generated — confirm in config.
// It is runtime text, so its indentation and content must stay intact. The
// Chinese YAML comments explain that setting "universal" makes the "api" and
// "event" addresses ignored, and that reconnect-interval is in milliseconds.
const wsReverseDefault = `  # 反向WS设置
  - ws-reverse:
      # 反向WS Universal 地址
      # 注意 设置了此项地址后下面两项将会被忽略
      universal: ws://your_websocket_universal.server
      # 反向WS API 地址
      api: ws://your_websocket_api.server
      # 反向WS Event 地址
      event: ws://your_websocket_event.server
      # 重连间隔 单位毫秒
      reconnect-interval: 3000
      middlewares:
        <<: *default # 引用默认中间件
`
   101  
// WebsocketServer is the configuration of a forward WebSocket server, decoded
// from YAML by runWSServer.
//
// Disabled prevents the server from being started.
//
// Address is the listen address. When it parses as a URL with a scheme, the
// scheme selects the listen network (for example "unix") and host+path form
// the address. Otherwise the address is listened on over TCP.
//
// Host and Port are the deprecated listen settings. They are used only when
// Address is empty and at least one of them is set.
//
// The embedded MiddleWares supply the access token, the event filter and the
// rate-limit settings.
type WebsocketServer struct {
	Disabled bool   `yaml:"disabled"`
	Address  string `yaml:"address"`
	Host     string `yaml:"host"`
	Port     int    `yaml:"port"`

	MiddleWares `yaml:"middlewares"`
}
   111  
// WebsocketReverse is the configuration of a reverse WebSocket client,
// decoded from YAML by runWSClient.
//
// Disabled prevents the client from being started.
//
// Universal, when set, is the only endpoint used, and API and Event are
// ignored. Otherwise API and Event are connected independently. An address
// may carry a "+network" scheme suffix (for example "ws+unix://..."); see
// resolveURI.
//
// ReconnectInterval is in milliseconds. A value of 0 selects a 5 second
// default.
//
// The embedded MiddleWares supply the access token, the event filter and the
// rate-limit settings.
type WebsocketReverse struct {
	Disabled          bool   `yaml:"disabled"`
	Universal         string `yaml:"universal"`
	API               string `yaml:"api"`
	Event             string `yaml:"event"`
	ReconnectInterval int    `yaml:"reconnect-interval"`

	MiddleWares `yaml:"middlewares"`
}
   122  
   123  func init() {
   124  	config.AddServer(&config.Server{
   125  		Brief:   "正向 Websocket 通信",
   126  		Default: wsDefault,
   127  	})
   128  	config.AddServer(&config.Server{
   129  		Brief:   "反向 Websocket 通信",
   130  		Default: wsReverseDefault,
   131  	})
   132  }
   133  
   134  // runWSServer 运行一个正向WS server
   135  func runWSServer(b *coolq.CQBot, node yaml.Node) {
   136  	var conf WebsocketServer
   137  	switch err := node.Decode(&conf); {
   138  	case err != nil:
   139  		log.Warn("读取正向Websocket配置失败 :", err)
   140  		fallthrough
   141  	case conf.Disabled:
   142  		return
   143  	}
   144  
   145  	network, address := "tcp", conf.Address
   146  	if conf.Address == "" && (conf.Host != "" || conf.Port != 0) {
   147  		log.Warn("正向 Websocket 使用了过时的配置格式,请更新配置文件")
   148  		address = fmt.Sprintf("%s:%d", conf.Host, conf.Port)
   149  	} else {
   150  		uri, err := url.Parse(conf.Address)
   151  		if err == nil && uri.Scheme != "" {
   152  			network = uri.Scheme
   153  			address = uri.Host + uri.Path
   154  		}
   155  	}
   156  	s := &webSocketServer{
   157  		bot:    b,
   158  		conf:   &conf,
   159  		token:  conf.AccessToken,
   160  		filter: conf.Filter,
   161  	}
   162  	filter.Add(s.filter)
   163  	s.handshake = fmt.Sprintf(`{"_post_method":2,"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`,
   164  		b.Client.Uin, time.Now().Unix())
   165  	b.OnEventPush(s.onBotPushEvent)
   166  	mux := http.ServeMux{}
   167  	mux.HandleFunc("/event", s.event)
   168  	mux.HandleFunc("/api", s.api)
   169  	mux.HandleFunc("/", s.any)
   170  	listener, err := net.Listen(network, address)
   171  	if err != nil {
   172  		log.Fatal(err)
   173  	}
   174  	log.Infof("CQ WebSocket 服务器已启动: %v", listener.Addr())
   175  	log.Fatal(http.Serve(listener, &mux))
   176  }
   177  
   178  // runWSClient 运行一个反向向WS client
   179  func runWSClient(b *coolq.CQBot, node yaml.Node) {
   180  	var conf WebsocketReverse
   181  	switch err := node.Decode(&conf); {
   182  	case err != nil:
   183  		log.Warn("读取反向Websocket配置失败 :", err)
   184  		fallthrough
   185  	case conf.Disabled:
   186  		return
   187  	}
   188  
   189  	c := &websocketClient{
   190  		bot:    b,
   191  		token:  conf.AccessToken,
   192  		filter: conf.Filter,
   193  	}
   194  	filter.Add(c.filter)
   195  
   196  	if conf.ReconnectInterval != 0 {
   197  		c.reconnectInterval = time.Duration(conf.ReconnectInterval) * time.Millisecond
   198  	} else {
   199  		c.reconnectInterval = time.Second * 5
   200  	}
   201  
   202  	if conf.RateLimit.Enabled {
   203  		c.limiter = rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket)
   204  	}
   205  
   206  	if conf.Universal != "" {
   207  		c.connect("Universal", conf.Universal, &c.universal)
   208  		c.bot.OnEventPush(c.onBotPushEvent("Universal", conf.Universal, &c.universal))
   209  		return // 连接到 Universal 后, 不再连接其他
   210  	}
   211  	if conf.API != "" {
   212  		c.connect("API", conf.API, nil)
   213  	}
   214  	if conf.Event != "" {
   215  		c.connect("Event", conf.Event, &c.event)
   216  		c.bot.OnEventPush(c.onBotPushEvent("Event", conf.Event, &c.event))
   217  	}
   218  }
   219  
   220  func resolveURI(addr string) (network, address string) {
   221  	network, address = "tcp", addr
   222  	uri, err := url.Parse(addr)
   223  	if err == nil && uri.Scheme != "" {
   224  		scheme, ext, _ := strings.Cut(uri.Scheme, "+")
   225  		if ext != "" {
   226  			network = ext
   227  			uri.Scheme = scheme // remove `+unix`/`+tcp4`
   228  			if ext == "unix" {
   229  				uri.Host, uri.Path, _ = strings.Cut(uri.Path, ":")
   230  				uri.Host = base64.StdEncoding.EncodeToString([]byte(uri.Host))
   231  			}
   232  			address = uri.String()
   233  		}
   234  	}
   235  	return
   236  }
   237  
   238  func (c *websocketClient) connect(typ, addr string, conptr **wsConn) {
   239  	log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, addr)
   240  	header := http.Header{
   241  		"X-Client-Role": []string{typ},
   242  		"X-Self-ID":     []string{strconv.FormatInt(c.bot.Client.Uin, 10)},
   243  		"User-Agent":    []string{"CQHttp/4.15.0"},
   244  	}
   245  	if c.token != "" {
   246  		header["Authorization"] = []string{"Token " + c.token}
   247  	}
   248  
   249  	network, address := resolveURI(addr)
   250  	dialer := websocket.Dialer{
   251  		NetDial: func(_, addr string) (net.Conn, error) {
   252  			if network == "unix" {
   253  				host, _, err := net.SplitHostPort(addr)
   254  				if err != nil {
   255  					host = addr
   256  				}
   257  				filepath, err := base64.RawURLEncoding.DecodeString(host)
   258  				if err == nil {
   259  					addr = string(filepath)
   260  				}
   261  			}
   262  			return net.Dial(network, addr) // support unix socket transport
   263  		},
   264  	}
   265  
   266  	conn, _, err := dialer.Dial(address, header) // nolint
   267  	if err != nil {
   268  		log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, addr, err)
   269  		if c.reconnectInterval != 0 {
   270  			time.Sleep(c.reconnectInterval)
   271  			c.connect(typ, addr, conptr)
   272  		}
   273  		return
   274  	}
   275  
   276  	switch typ {
   277  	case "Event", "Universal":
   278  		handshake := fmt.Sprintf(`{"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, c.bot.Client.Uin, time.Now().Unix())
   279  		err = conn.WriteMessage(websocket.TextMessage, []byte(handshake))
   280  		if err != nil {
   281  			log.Warnf("反向WebSocket 握手时出现错误: %v", err)
   282  		}
   283  	}
   284  
   285  	log.Infof("已连接到反向WebSocket %s服务器 %v", typ, addr)
   286  
   287  	var wrappedConn *wsConn
   288  	if conptr != nil && *conptr != nil {
   289  		wrappedConn = *conptr
   290  	} else {
   291  		wrappedConn = new(wsConn)
   292  		if conptr != nil {
   293  			*conptr = wrappedConn
   294  		}
   295  	}
   296  
   297  	wrappedConn.conn = conn
   298  	wrappedConn.apiCaller = api.NewCaller(c.bot)
   299  	if c.limiter != nil {
   300  		wrappedConn.apiCaller.Use(c.limiter)
   301  	}
   302  
   303  	if typ != "Event" {
   304  		go c.listenAPI(typ, addr, wrappedConn)
   305  	}
   306  }
   307  
   308  func (c *websocketClient) listenAPI(typ, url string, conn *wsConn) {
   309  	defer func() { _ = conn.Close() }()
   310  	for {
   311  		buffer := global.NewBuffer()
   312  		t, reader, err := conn.conn.NextReader()
   313  		if err != nil {
   314  			log.Warnf("监听反向WS %s时出现错误: %v", typ, err)
   315  			break
   316  		}
   317  		_, err = buffer.ReadFrom(reader)
   318  		if err != nil {
   319  			log.Warnf("监听反向WS %s时出现错误: %v", typ, err)
   320  			break
   321  		}
   322  		if t == websocket.TextMessage {
   323  			go func(buffer *bytes.Buffer) {
   324  				defer global.PutBuffer(buffer)
   325  				conn.handleRequest(c.bot, buffer.Bytes())
   326  			}(buffer)
   327  		} else {
   328  			global.PutBuffer(buffer)
   329  		}
   330  	}
   331  	if c.reconnectInterval != 0 {
   332  		time.Sleep(c.reconnectInterval)
   333  		if typ == "API" { // Universal 不重连,避免多次重连
   334  			go c.connect(typ, url, nil)
   335  		}
   336  	}
   337  }
   338  
   339  func (c *websocketClient) onBotPushEvent(typ, url string, conn **wsConn) func(e *coolq.Event) {
   340  	return func(e *coolq.Event) {
   341  		c.mu.Lock()
   342  		defer c.mu.Unlock()
   343  
   344  		flt := filter.Find(c.filter)
   345  		if flt != nil && !flt.Eval(gjson.Parse(e.JSONString())) {
   346  			log.Debugf("上报Event %s 到 WS服务器 时被过滤.", e.JSONBytes())
   347  			return
   348  		}
   349  
   350  		log.Debugf("向反向WS %s服务器推送Event: %s", typ, e.JSONBytes())
   351  		if err := (*conn).WriteText(e.JSONBytes()); err != nil {
   352  			log.Warnf("向反向WS %s服务器推送 Event 时出现错误: %v", typ, err)
   353  			_ = (*conn).Close()
   354  			if c.reconnectInterval != 0 {
   355  				time.Sleep(c.reconnectInterval)
   356  				c.connect(typ, url, conn)
   357  			}
   358  		}
   359  	}
   360  }
   361  
   362  func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) {
   363  	status := checkAuth(r, s.token)
   364  	if status != http.StatusOK {
   365  		log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status)
   366  		w.WriteHeader(status)
   367  		return
   368  	}
   369  
   370  	c, err := upgrader.Upgrade(w, r, nil)
   371  	if err != nil {
   372  		log.Warnf("处理 WebSocket 请求时出现错误: %v", err)
   373  		return
   374  	}
   375  
   376  	err = c.WriteMessage(websocket.TextMessage, []byte(s.handshake))
   377  	if err != nil {
   378  		log.Warnf("WebSocket 握手时出现错误: %v", err)
   379  		_ = c.Close()
   380  		return
   381  	}
   382  
   383  	log.Infof("接受 WebSocket 连接: %v (/event)", r.RemoteAddr)
   384  	conn := &wsConn{conn: c, apiCaller: api.NewCaller(s.bot)}
   385  	s.mu.Lock()
   386  	s.eventConn = append(s.eventConn, conn)
   387  	s.mu.Unlock()
   388  }
   389  
   390  func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) {
   391  	status := checkAuth(r, s.token)
   392  	if status != http.StatusOK {
   393  		log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status)
   394  		w.WriteHeader(status)
   395  		return
   396  	}
   397  
   398  	c, err := upgrader.Upgrade(w, r, nil)
   399  	if err != nil {
   400  		log.Warnf("处理 WebSocket 请求时出现错误: %v", err)
   401  		return
   402  	}
   403  
   404  	log.Infof("接受 WebSocket 连接: %v (/api)", r.RemoteAddr)
   405  	conn := &wsConn{conn: c, apiCaller: api.NewCaller(s.bot)}
   406  	if s.conf.RateLimit.Enabled {
   407  		conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket))
   408  	}
   409  	s.listenAPI(conn)
   410  }
   411  
   412  func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) {
   413  	status := checkAuth(r, s.token)
   414  	if status != http.StatusOK {
   415  		log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status)
   416  		w.WriteHeader(status)
   417  		return
   418  	}
   419  
   420  	c, err := upgrader.Upgrade(w, r, nil)
   421  	if err != nil {
   422  		log.Warnf("处理 WebSocket 请求时出现错误: %v", err)
   423  		return
   424  	}
   425  
   426  	err = c.WriteMessage(websocket.TextMessage, []byte(s.handshake))
   427  	if err != nil {
   428  		log.Warnf("WebSocket 握手时出现错误: %v", err)
   429  		_ = c.Close()
   430  		return
   431  	}
   432  
   433  	log.Infof("接受 WebSocket 连接: %v (/)", r.RemoteAddr)
   434  	conn := &wsConn{conn: c, apiCaller: api.NewCaller(s.bot)}
   435  	if s.conf.RateLimit.Enabled {
   436  		conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket))
   437  	}
   438  	s.mu.Lock()
   439  	s.eventConn = append(s.eventConn, conn)
   440  	s.mu.Unlock()
   441  	s.listenAPI(conn)
   442  }
   443  
   444  func (s *webSocketServer) listenAPI(c *wsConn) {
   445  	defer func() { _ = c.Close() }()
   446  	for {
   447  		buffer := global.NewBuffer()
   448  		t, reader, err := c.conn.NextReader()
   449  		if err != nil {
   450  			break
   451  		}
   452  		_, err = buffer.ReadFrom(reader)
   453  		if err != nil {
   454  			break
   455  		}
   456  
   457  		if t == websocket.TextMessage {
   458  			go func(buffer *bytes.Buffer) {
   459  				defer global.PutBuffer(buffer)
   460  				c.handleRequest(s.bot, buffer.Bytes())
   461  			}(buffer)
   462  		} else {
   463  			global.PutBuffer(buffer)
   464  		}
   465  	}
   466  }
   467  
   468  func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) {
   469  	defer func() {
   470  		if err := recover(); err != nil {
   471  			log.Errorf("处置WS命令时发生无法恢复的异常:%v\n%s", err, debug.Stack())
   472  			_ = c.Close()
   473  		}
   474  	}()
   475  
   476  	j := gjson.Parse(utils.B2S(payload))
   477  	t := strings.TrimSuffix(j.Get("action").Str, "_async")
   478  	params := j.Get("params")
   479  	log.Debugf("WS接收到API调用: %v 参数: %v", t, params.Raw)
   480  	ret := c.apiCaller.Call(t, onebot.V11, params)
   481  	if j.Get("echo").Exists() {
   482  		ret["echo"] = j.Get("echo").Value()
   483  	}
   484  
   485  	c.mu.Lock()
   486  	defer c.mu.Unlock()
   487  	_ = c.conn.SetWriteDeadline(time.Now().Add(time.Second * 15))
   488  	writer, err := c.conn.NextWriter(websocket.TextMessage)
   489  	if err != nil {
   490  		log.Errorf("无法响应API调用(连接已断开?): %v", err)
   491  		return
   492  	}
   493  	_ = json.NewEncoder(writer).Encode(ret)
   494  	_ = writer.Close()
   495  }
   496  
   497  func (s *webSocketServer) onBotPushEvent(e *coolq.Event) {
   498  	flt := filter.Find(s.filter)
   499  	if flt != nil && !flt.Eval(gjson.Parse(e.JSONString())) {
   500  		log.Debugf("上报Event %s 到 WS客户端 时被过滤.", e.JSONBytes())
   501  		return
   502  	}
   503  
   504  	s.mu.Lock()
   505  	defer s.mu.Unlock()
   506  
   507  	j := 0
   508  	for i := 0; i < len(s.eventConn); i++ {
   509  		conn := s.eventConn[i]
   510  		log.Debugf("向WS客户端推送Event: %s", e.JSONBytes())
   511  		if err := conn.WriteText(e.JSONBytes()); err != nil {
   512  			_ = conn.Close()
   513  			conn = nil
   514  			continue
   515  		}
   516  		if i != j {
   517  			// i != j means that some connection has been closed.
   518  			// use an in-place removal to avoid copying.
   519  			s.eventConn[j] = conn
   520  		}
   521  		j++
   522  	}
   523  	s.eventConn = s.eventConn[:j]
   524  }