github.com/Mrs4s/go-cqhttp@v1.2.0/server/websocket.go (about) 1 package server 2 3 import ( 4 "bytes" 5 "encoding/base64" 6 "encoding/json" 7 "fmt" 8 "net" 9 "net/http" 10 "net/url" 11 "runtime/debug" 12 "strconv" 13 "strings" 14 "sync" 15 "time" 16 17 "github.com/Mrs4s/MiraiGo/utils" 18 "github.com/RomiChan/websocket" 19 log "github.com/sirupsen/logrus" 20 "github.com/tidwall/gjson" 21 "gopkg.in/yaml.v3" 22 23 "github.com/Mrs4s/go-cqhttp/coolq" 24 "github.com/Mrs4s/go-cqhttp/global" 25 "github.com/Mrs4s/go-cqhttp/modules/api" 26 "github.com/Mrs4s/go-cqhttp/modules/config" 27 "github.com/Mrs4s/go-cqhttp/modules/filter" 28 "github.com/Mrs4s/go-cqhttp/pkg/onebot" 29 ) 30 31 type webSocketServer struct { 32 bot *coolq.CQBot 33 conf *WebsocketServer 34 35 mu sync.Mutex 36 eventConn []*wsConn 37 38 token string 39 handshake string 40 filter string 41 } 42 43 // websocketClient WebSocket客户端实例 44 type websocketClient struct { 45 bot *coolq.CQBot 46 mu sync.Mutex 47 universal *wsConn 48 event *wsConn 49 50 token string 51 filter string 52 reconnectInterval time.Duration 53 limiter api.Handler 54 } 55 56 type wsConn struct { 57 mu sync.Mutex 58 conn *websocket.Conn 59 apiCaller *api.Caller 60 } 61 62 func (c *wsConn) WriteText(b []byte) error { 63 c.mu.Lock() 64 defer c.mu.Unlock() 65 _ = c.conn.SetWriteDeadline(time.Now().Add(time.Second * 15)) 66 return c.conn.WriteMessage(websocket.TextMessage, b) 67 } 68 69 func (c *wsConn) Close() error { 70 return c.conn.Close() 71 } 72 73 var upgrader = websocket.Upgrader{ 74 CheckOrigin: func(r *http.Request) bool { 75 return true 76 }, 77 } 78 79 const wsDefault = ` # 正向WS设置 80 - ws: 81 # 正向WS服务器监听地址 82 address: 0.0.0.0:8080 83 middlewares: 84 <<: *default # 引用默认中间件 85 ` 86 87 const wsReverseDefault = ` # 反向WS设置 88 - ws-reverse: 89 # 反向WS Universal 地址 90 # 注意 设置了此项地址后下面两项将会被忽略 91 universal: ws://your_websocket_universal.server 92 # 反向WS API 地址 93 api: ws://your_websocket_api.server 94 # 反向WS Event 地址 95 event: ws://your_websocket_event.server 96 # 重连间隔 单位毫秒 97 reconnect-interval: 3000 98 middlewares: 99 <<: *default # 引用默认中间件 100 ` 101 102 // WebsocketServer 正向WS相关配置 103 type WebsocketServer struct { 104 Disabled bool `yaml:"disabled"` 105 Address string `yaml:"address"` 106 Host string `yaml:"host"` 107 Port int `yaml:"port"` 108 109 MiddleWares `yaml:"middlewares"` 110 } 111 112 // WebsocketReverse 反向WS相关配置 113 type WebsocketReverse struct { 114 Disabled bool `yaml:"disabled"` 115 Universal string `yaml:"universal"` 116 API string `yaml:"api"` 117 Event string `yaml:"event"` 118 ReconnectInterval int `yaml:"reconnect-interval"` 119 120 MiddleWares `yaml:"middlewares"` 121 } 122 123 func init() { 124 config.AddServer(&config.Server{ 125 Brief: "正向 Websocket 通信", 126 Default: wsDefault, 127 }) 128 config.AddServer(&config.Server{ 129 Brief: "反向 Websocket 通信", 130 Default: wsReverseDefault, 131 }) 132 } 133 134 // runWSServer 运行一个正向WS server 135 func runWSServer(b *coolq.CQBot, node yaml.Node) { 136 var conf WebsocketServer 137 switch err := node.Decode(&conf); { 138 case err != nil: 139 log.Warn("读取正向Websocket配置失败 :", err) 140 fallthrough 141 case conf.Disabled: 142 return 143 } 144 145 network, address := "tcp", conf.Address 146 if conf.Address == "" && (conf.Host != "" || conf.Port != 0) { 147 log.Warn("正向 Websocket 使用了过时的配置格式,请更新配置文件") 148 address = fmt.Sprintf("%s:%d", conf.Host, conf.Port) 149 } else { 150 uri, err := url.Parse(conf.Address) 151 if err == nil && uri.Scheme != "" { 152 network = uri.Scheme 153 address = uri.Host + uri.Path 154 } 155 } 156 s := &webSocketServer{ 157 bot: b, 158 conf: &conf, 159 token: conf.AccessToken, 160 filter: conf.Filter, 161 } 162 filter.Add(s.filter) 163 s.handshake = fmt.Sprintf(`{"_post_method":2,"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, 164 b.Client.Uin, time.Now().Unix()) 165 b.OnEventPush(s.onBotPushEvent) 166 mux := http.ServeMux{} 167 mux.HandleFunc("/event", s.event) 168 mux.HandleFunc("/api", s.api) 169 mux.HandleFunc("/", s.any) 170 listener, err := net.Listen(network, address) 171 if err != nil { 172 log.Fatal(err) 173 } 174 log.Infof("CQ WebSocket 服务器已启动: %v", listener.Addr()) 175 log.Fatal(http.Serve(listener, &mux)) 176 } 177 178 // runWSClient 运行一个反向向WS client 179 func runWSClient(b *coolq.CQBot, node yaml.Node) { 180 var conf WebsocketReverse 181 switch err := node.Decode(&conf); { 182 case err != nil: 183 log.Warn("读取反向Websocket配置失败 :", err) 184 fallthrough 185 case conf.Disabled: 186 return 187 } 188 189 c := &websocketClient{ 190 bot: b, 191 token: conf.AccessToken, 192 filter: conf.Filter, 193 } 194 filter.Add(c.filter) 195 196 if conf.ReconnectInterval != 0 { 197 c.reconnectInterval = time.Duration(conf.ReconnectInterval) * time.Millisecond 198 } else { 199 c.reconnectInterval = time.Second * 5 200 } 201 202 if conf.RateLimit.Enabled { 203 c.limiter = rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket) 204 } 205 206 if conf.Universal != "" { 207 c.connect("Universal", conf.Universal, &c.universal) 208 c.bot.OnEventPush(c.onBotPushEvent("Universal", conf.Universal, &c.universal)) 209 return // 连接到 Universal 后, 不再连接其他 210 } 211 if conf.API != "" { 212 c.connect("API", conf.API, nil) 213 } 214 if conf.Event != "" { 215 c.connect("Event", conf.Event, &c.event) 216 c.bot.OnEventPush(c.onBotPushEvent("Event", conf.Event, &c.event)) 217 } 218 } 219 220 func resolveURI(addr string) (network, address string) { 221 network, address = "tcp", addr 222 uri, err := url.Parse(addr) 223 if err == nil && uri.Scheme != "" { 224 scheme, ext, _ := strings.Cut(uri.Scheme, "+") 225 if ext != "" { 226 network = ext 227 uri.Scheme = scheme // remove `+unix`/`+tcp4` 228 if ext == "unix" { 229 uri.Host, uri.Path, _ = strings.Cut(uri.Path, ":") 230 uri.Host = base64.StdEncoding.EncodeToString([]byte(uri.Host)) 231 } 232 address = uri.String() 233 } 234 } 235 return 236 } 237 238 func (c *websocketClient) connect(typ, addr string, conptr **wsConn) { 239 log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, addr) 240 header := http.Header{ 241 "X-Client-Role": []string{typ}, 242 "X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)}, 243 "User-Agent": []string{"CQHttp/4.15.0"}, 244 } 245 if c.token != "" { 246 header["Authorization"] = []string{"Token " + c.token} 247 } 248 249 network, address := resolveURI(addr) 250 dialer := websocket.Dialer{ 251 NetDial: func(_, addr string) (net.Conn, error) { 252 if network == "unix" { 253 host, _, err := net.SplitHostPort(addr) 254 if err != nil { 255 host = addr 256 } 257 filepath, err := base64.RawURLEncoding.DecodeString(host) 258 if err == nil { 259 addr = string(filepath) 260 } 261 } 262 return net.Dial(network, addr) // support unix socket transport 263 }, 264 } 265 266 conn, _, err := dialer.Dial(address, header) // nolint 267 if err != nil { 268 log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, addr, err) 269 if c.reconnectInterval != 0 { 270 time.Sleep(c.reconnectInterval) 271 c.connect(typ, addr, conptr) 272 } 273 return 274 } 275 276 switch typ { 277 case "Event", "Universal": 278 handshake := fmt.Sprintf(`{"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, c.bot.Client.Uin, time.Now().Unix()) 279 err = conn.WriteMessage(websocket.TextMessage, []byte(handshake)) 280 if err != nil { 281 log.Warnf("反向WebSocket 握手时出现错误: %v", err) 282 } 283 } 284 285 log.Infof("已连接到反向WebSocket %s服务器 %v", typ, addr) 286 287 var wrappedConn *wsConn 288 if conptr != nil && *conptr != nil { 289 wrappedConn = *conptr 290 } else { 291 wrappedConn = new(wsConn) 292 if conptr != nil { 293 *conptr = wrappedConn 294 } 295 } 296 297 wrappedConn.conn = conn 298 wrappedConn.apiCaller = api.NewCaller(c.bot) 299 if c.limiter != nil { 300 wrappedConn.apiCaller.Use(c.limiter) 301 } 302 303 if typ != "Event" { 304 go c.listenAPI(typ, addr, wrappedConn) 305 } 306 } 307 308 func (c *websocketClient) listenAPI(typ, url string, conn *wsConn) { 309 defer func() { _ = conn.Close() }() 310 for { 311 buffer := global.NewBuffer() 312 t, reader, err := conn.conn.NextReader() 313 if err != nil { 314 log.Warnf("监听反向WS %s时出现错误: %v", typ, err) 315 break 316 } 317 _, err = buffer.ReadFrom(reader) 318 if err != nil { 319 log.Warnf("监听反向WS %s时出现错误: %v", typ, err) 320 break 321 } 322 if t == websocket.TextMessage { 323 go func(buffer *bytes.Buffer) { 324 defer global.PutBuffer(buffer) 325 conn.handleRequest(c.bot, buffer.Bytes()) 326 }(buffer) 327 } else { 328 global.PutBuffer(buffer) 329 } 330 } 331 if c.reconnectInterval != 0 { 332 time.Sleep(c.reconnectInterval) 333 if typ == "API" { // Universal 不重连,避免多次重连 334 go c.connect(typ, url, nil) 335 } 336 } 337 } 338 339 func (c *websocketClient) onBotPushEvent(typ, url string, conn **wsConn) func(e *coolq.Event) { 340 return func(e *coolq.Event) { 341 c.mu.Lock() 342 defer c.mu.Unlock() 343 344 flt := filter.Find(c.filter) 345 if flt != nil && !flt.Eval(gjson.Parse(e.JSONString())) { 346 log.Debugf("上报Event %s 到 WS服务器 时被过滤.", e.JSONBytes()) 347 return 348 } 349 350 log.Debugf("向反向WS %s服务器推送Event: %s", typ, e.JSONBytes()) 351 if err := (*conn).WriteText(e.JSONBytes()); err != nil { 352 log.Warnf("向反向WS %s服务器推送 Event 时出现错误: %v", typ, err) 353 _ = (*conn).Close() 354 if c.reconnectInterval != 0 { 355 time.Sleep(c.reconnectInterval) 356 c.connect(typ, url, conn) 357 } 358 } 359 } 360 } 361 362 func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { 363 status := checkAuth(r, s.token) 364 if status != http.StatusOK { 365 log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status) 366 w.WriteHeader(status) 367 return 368 } 369 370 c, err := upgrader.Upgrade(w, r, nil) 371 if err != nil { 372 log.Warnf("处理 WebSocket 请求时出现错误: %v", err) 373 return 374 } 375 376 err = c.WriteMessage(websocket.TextMessage, []byte(s.handshake)) 377 if err != nil { 378 log.Warnf("WebSocket 握手时出现错误: %v", err) 379 _ = c.Close() 380 return 381 } 382 383 log.Infof("接受 WebSocket 连接: %v (/event)", r.RemoteAddr) 384 conn := &wsConn{conn: c, apiCaller: api.NewCaller(s.bot)} 385 s.mu.Lock() 386 s.eventConn = append(s.eventConn, conn) 387 s.mu.Unlock() 388 } 389 390 func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) { 391 status := checkAuth(r, s.token) 392 if status != http.StatusOK { 393 log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status) 394 w.WriteHeader(status) 395 return 396 } 397 398 c, err := upgrader.Upgrade(w, r, nil) 399 if err != nil { 400 log.Warnf("处理 WebSocket 请求时出现错误: %v", err) 401 return 402 } 403 404 log.Infof("接受 WebSocket 连接: %v (/api)", r.RemoteAddr) 405 conn := &wsConn{conn: c, apiCaller: api.NewCaller(s.bot)} 406 if s.conf.RateLimit.Enabled { 407 conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) 408 } 409 s.listenAPI(conn) 410 } 411 412 func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) { 413 status := checkAuth(r, s.token) 414 if status != http.StatusOK { 415 log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status) 416 w.WriteHeader(status) 417 return 418 } 419 420 c, err := upgrader.Upgrade(w, r, nil) 421 if err != nil { 422 log.Warnf("处理 WebSocket 请求时出现错误: %v", err) 423 return 424 } 425 426 err = c.WriteMessage(websocket.TextMessage, []byte(s.handshake)) 427 if err != nil { 428 log.Warnf("WebSocket 握手时出现错误: %v", err) 429 _ = c.Close() 430 return 431 } 432 433 log.Infof("接受 WebSocket 连接: %v (/)", r.RemoteAddr) 434 conn := &wsConn{conn: c, apiCaller: api.NewCaller(s.bot)} 435 if s.conf.RateLimit.Enabled { 436 conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) 437 } 438 s.mu.Lock() 439 s.eventConn = append(s.eventConn, conn) 440 s.mu.Unlock() 441 s.listenAPI(conn) 442 } 443 444 func (s *webSocketServer) listenAPI(c *wsConn) { 445 defer func() { _ = c.Close() }() 446 for { 447 buffer := global.NewBuffer() 448 t, reader, err := c.conn.NextReader() 449 if err != nil { 450 break 451 } 452 _, err = buffer.ReadFrom(reader) 453 if err != nil { 454 break 455 } 456 457 if t == websocket.TextMessage { 458 go func(buffer *bytes.Buffer) { 459 defer global.PutBuffer(buffer) 460 c.handleRequest(s.bot, buffer.Bytes()) 461 }(buffer) 462 } else { 463 global.PutBuffer(buffer) 464 } 465 } 466 } 467 468 func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) { 469 defer func() { 470 if err := recover(); err != nil { 471 log.Errorf("处置WS命令时发生无法恢复的异常:%v\n%s", err, debug.Stack()) 472 _ = c.Close() 473 } 474 }() 475 476 j := gjson.Parse(utils.B2S(payload)) 477 t := strings.TrimSuffix(j.Get("action").Str, "_async") 478 params := j.Get("params") 479 log.Debugf("WS接收到API调用: %v 参数: %v", t, params.Raw) 480 ret := c.apiCaller.Call(t, onebot.V11, params) 481 if j.Get("echo").Exists() { 482 ret["echo"] = j.Get("echo").Value() 483 } 484 485 c.mu.Lock() 486 defer c.mu.Unlock() 487 _ = c.conn.SetWriteDeadline(time.Now().Add(time.Second * 15)) 488 writer, err := c.conn.NextWriter(websocket.TextMessage) 489 if err != nil { 490 log.Errorf("无法响应API调用(连接已断开?): %v", err) 491 return 492 } 493 _ = json.NewEncoder(writer).Encode(ret) 494 _ = writer.Close() 495 } 496 497 func (s *webSocketServer) onBotPushEvent(e *coolq.Event) { 498 flt := filter.Find(s.filter) 499 if flt != nil && !flt.Eval(gjson.Parse(e.JSONString())) { 500 log.Debugf("上报Event %s 到 WS客户端 时被过滤.", e.JSONBytes()) 501 return 502 } 503 504 s.mu.Lock() 505 defer s.mu.Unlock() 506 507 j := 0 508 for i := 0; i < len(s.eventConn); i++ { 509 conn := s.eventConn[i] 510 log.Debugf("向WS客户端推送Event: %s", e.JSONBytes()) 511 if err := conn.WriteText(e.JSONBytes()); err != nil { 512 _ = conn.Close() 513 conn = nil 514 continue 515 } 516 if i != j { 517 // i != j means that some connection has been closed. 518 // use an in-place removal to avoid copying. 519 s.eventConn[j] = conn 520 } 521 j++ 522 } 523 s.eventConn = s.eventConn[:j] 524 }