github.com/fumiama/NanoBot@v0.0.0-20231122134259-c22d8183efca/bot.go (about)

     1  package nano
     2  
     3  import (
     4  	"encoding/base64"
     5  	"encoding/json"
     6  	"errors"
     7  	"net"
     8  	"net/http"
     9  	"reflect"
    10  	"strconv"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  	"unsafe"
    15  
    16  	"github.com/RomiChan/syncx"
    17  	"github.com/RomiChan/websocket"
    18  	log "github.com/sirupsen/logrus"
    19  )
    20  
    21  var (
    22  	clients   = syncx.Map[string, *Bot]{}
    23  	isrunning uintptr
    24  )
    25  
    26  const (
    27  	// SuperUserAllQQUsers 使所有 QQ 用户成为超级用户
    28  	SuperUserAllQQUsers = "AllQQUsers"
    29  )
    30  
// Bot holds the configuration and the runtime state of one bot instance,
// i.e. one gateway connection (one shard).
type Bot struct {
	AppID string `yaml:"AppID"` // AppID is the BotAppID (developer ID)
	// Token is the bot token; when Secret is also set the newer (V2) API is used.
	Token string `yaml:"Token"`
	// token is the access token obtained with Secret; Authorization sends it
	// ("QQBot <token>") when IsV2 reports true.
	token string
	// Secret is the V2 bot credential (AppSecret/ClientSecret). Original note:
	// the sandbox currently accepts the login but cannot send or receive messages.
	Secret string `yaml:"Secret"`
	// SuperUsers lists super users. The special value AllQQUsers is described
	// as making every QQ user a super user.
	SuperUsers []string `yaml:"SuperUsers"`
	// Timeout is the API call timeout; Init defaults it to one minute when zero.
	Timeout time.Duration `yaml:"Timeout"`
	Handler *Handler      `yaml:"-"`       // Handler registers the callbacks for the various events
	Intents uint32        `yaml:"Intents"` // Intents selects the events to receive
	// ShardIndex is the index of the shard this connection serves; 0 means
	// sharding is not used (getinitinfo then queries the general gateway).
	// NOTE(review): the original note said "default 1", but the Go zero value
	// is 0 (no sharding) — confirm the intended default.
	ShardIndex uint8 `yaml:"ShardIndex"`
	// ShardCount is the total number of shards; when 0, getinitinfo uses the
	// count recommended by the gateway.
	ShardCount uint8 `yaml:"ShardCount"`
	// shard is the [index, count] pair sent in the Identify payload.
	shard [2]byte
	// Properties is sent as-is in the Identify payload. The original note
	// says it is "some environment variables, currently unused".
	Properties json.RawMessage `yaml:"Properties"`

	gateway   string                      // gateway is the gateway URL obtained at init
	seq       uint32                      // seq is the latest sequence number "s" seen
	heartbeat uint32                      // heartbeat is the heartbeat interval in milliseconds; 0 while disconnected
	expiresec int64                       // expiresec is the token validity period in seconds
	handlers  map[string]eventHandlerType // handlers maps event type to handler, for convenient dispatch
	// mu is the write lock; doheartbeat holds it around websocket writes.
	// NOTE(review): presumably SendPayload uses it too — confirm.
	mu     sync.Mutex
	conn   *websocket.Conn // conn is the current wss connection
	hbonce sync.Once       // hbonce ensures the heartbeat goroutine is started only once
	exonce sync.Once       // exonce ensures the token refresher is started only once
	client *http.Client    // client is the HTTP client, configured mainly for Timeout

	ready EventReady // ready is the bot information delivered after a successful connection
}
    59  
    60  // GetReady 获得 bot 基本信息
    61  func (ctx *Ctx) GetReady() *EventReady {
    62  	return &ctx.caller.ready
    63  }
    64  
    65  // getinitinfo 获得 gateway 和 shard
    66  func (bot *Bot) getinitinfo() (secret, gw string, shard [2]byte, err error) {
    67  	shard[1] = 1
    68  	if bot.client == nil {
    69  		bot.client = http.DefaultClient
    70  	}
    71  	secret = bot.Secret
    72  	if bot.Secret != "" {
    73  		bot.Secret = ""
    74  	}
    75  	if bot.ShardIndex == 0 {
    76  		gw, err = bot.GetGeneralWSSGatewayNoContext()
    77  		if err != nil {
    78  			return
    79  		}
    80  	} else {
    81  		var sgw *ShardWSSGateway
    82  		sgw, err = bot.GetShardWSSGatewayNoContext()
    83  		if err != nil {
    84  			return
    85  		}
    86  		if bot.ShardCount == 0 {
    87  			log.Infoln(getLogHeader(), "使用网关推荐Shards数:", sgw.Shards)
    88  			bot.ShardCount = uint8(sgw.Shards)
    89  		}
    90  		if bot.ShardCount <= bot.ShardIndex {
    91  			err = errors.New("shard index " + strconv.Itoa(int(bot.ShardIndex)) + " >= suggested size " + strconv.Itoa(sgw.Shards))
    92  			return
    93  		}
    94  		gw = sgw.URL
    95  		shard[0] = byte(bot.ShardIndex)
    96  		shard[1] = byte(bot.ShardCount)
    97  	}
    98  	return
    99  }
   100  
   101  // Start clients without blocking
   102  func Start(bots ...*Bot) error {
   103  	if !atomic.CompareAndSwapUintptr(&isrunning, 0, 1) {
   104  		log.Warnln(getLogHeader(), "已忽略重复调用的", getThisFuncName())
   105  	}
   106  	for _, b := range bots {
   107  		s, gw, shard, err := b.getinitinfo()
   108  		if err != nil {
   109  			return err
   110  		}
   111  		go b.Init(s, gw, shard).Connect().Listen()
   112  	}
   113  	return nil
   114  }
   115  
   116  // Run clients and block self in listening last one
   117  func Run(preblock func(), bots ...*Bot) error {
   118  	if !atomic.CompareAndSwapUintptr(&isrunning, 0, 1) {
   119  		log.Warnln(getLogHeader(), "已忽略重复调用的", getThisFuncName())
   120  	}
   121  	var b *Bot
   122  	switch len(bots) {
   123  	case 0:
   124  		return nil
   125  	case 1:
   126  		b = bots[0]
   127  		s, gw, shard, err := b.getinitinfo()
   128  		if err != nil {
   129  			return err
   130  		}
   131  		b.Init(s, gw, shard)
   132  	default:
   133  		for _, b := range bots[:len(bots)-1] {
   134  			s, gw, shard, err := b.getinitinfo()
   135  			if err != nil {
   136  				return err
   137  			}
   138  			go b.Init(s, gw, shard).Connect().Listen()
   139  		}
   140  		b = bots[len(bots)-1]
   141  		s, gw, shard, err := b.getinitinfo()
   142  		if err != nil {
   143  			return err
   144  		}
   145  		b.Init(s, gw, shard)
   146  	}
   147  	b.Connect()
   148  	if preblock != nil {
   149  		preblock()
   150  	}
   151  	b.Listen()
   152  	return nil
   153  }
   154  
// Init initialises the bot; it only needs to be called once.
//
// It stores the gateway URL and shard pair (as produced by getinitinfo) and
// defaults Timeout to one minute. It replaces client with one using Timeout and
// builds the handlers table from the non-zero fields of Handler. It then sets
// Secret back to secret. Under the V2 API it blocks until an access token is
// obtained (retrying every 3 seconds) and starts the token refresher once.
// It returns bot for chaining.
func (bot *Bot) Init(secret, gateway string, shard [2]byte) *Bot {
	bot.gateway = gateway
	bot.shard = shard
	if bot.Timeout == 0 {
		bot.Timeout = time.Minute
	}
	bot.client = &http.Client{
		Timeout: bot.Timeout,
	}
	if bot.Handler != nil {
		h := reflect.ValueOf(bot.Handler).Elem()
		t := h.Type()
		bot.handlers = make(map[string]eventHandlerType, h.NumField()*4)
		for i := 0; i < h.NumField(); i++ {
			f := h.Field(i)
			if f.IsZero() {
				continue // no handler registered for this event
			}
			tp := t.Field(i).Name[2:] // skip On: the event type is the field name without its "On" prefix
			log.Infoln(getLogHeader(), "注册处理函数", tp)
			handler := f.Interface()
			// NOTE(review): h reads the second machine word of the interface
			// value (its data word) and reinterprets it as a generalHandleType.
			// This relies on Go's internal interface layout and on every Handler
			// field being a func — confirm against generalHandleType.
			// t is the element type of the handler's third parameter (assumed
			// to be a pointer type, since Elem is called on it).
			bot.handlers[tp] = eventHandlerType{
				h: *(*generalHandleType)(unsafe.Add(unsafe.Pointer(&handler), unsafe.Sizeof(uintptr(0)))),
				t: t.Field(i).Type.In(2).Elem(),
			}
		}
	}
	bot.Secret = secret // restore the secret (getinitinfo clears it before querying the gateway)
	if bot.IsV2() {
		// Retry until a token is obtained; Connect needs it for Authorization.
		for {
			err := bot.GetAppAccessTokenNoContext()
			if err == nil {
				log.Infoln(getLogHeader(), "获得 Token: "+bot.token+", 超时:", bot.expiresec, "秒")
				bot.exonce.Do(func() {
					go bot.refreshtoken()
				})
				break
			}
			log.Infoln(getLogHeader(), "获得 Token 失败:", err)
			time.Sleep(time.Second * 3)
		}
	}
	return bot
}
   200  
   201  // IsV2 判断是否运行于 V2 API 下
   202  func (bot *Bot) IsV2() bool {
   203  	return bot.Secret != ""
   204  }
   205  
   206  // Authorization 返回 Authorization Header value
   207  func (bot *Bot) Authorization() string {
   208  	if bot.IsV2() {
   209  		return "QQBot " + bot.token
   210  	}
   211  	return "Bot " + bot.AppID + "." + bot.Token
   212  }
   213  
   214  // receive 收一个 payload
   215  func (bot *Bot) reveive() (payload WebsocketPayload, err error) {
   216  	err = bot.conn.ReadJSON(&payload)
   217  	return
   218  }
   219  
   220  // Connect 连接到 Gateway + 鉴权连接
   221  //
   222  // https://bot.q.qq.com/wiki/develop/api/gateway/reference.html#_1-%E8%BF%9E%E6%8E%A5%E5%88%B0-gateway
   223  func (bot *Bot) Connect() *Bot {
   224  	network, address := resolveURI(bot.gateway)
   225  	log.Infoln(getLogHeader(), "开始尝试连接到网关:", address, ", AppID:", bot.AppID)
   226  	dialer := websocket.Dialer{
   227  		NetDial: func(_, addr string) (net.Conn, error) {
   228  			if network == "unix" {
   229  				host, _, err := net.SplitHostPort(addr)
   230  				if err != nil {
   231  					host = addr
   232  				}
   233  				filepath, err := base64.RawURLEncoding.DecodeString(host)
   234  				if err == nil {
   235  					addr = BytesToString(filepath)
   236  				}
   237  			}
   238  			return net.Dial(network, addr) // support unix socket transport
   239  		},
   240  	}
   241  	for {
   242  		conn, resp, err := dialer.Dial(address, http.Header{})
   243  		if err != nil {
   244  			log.Warnf(getLogHeader(), "连接到网关 %v 时出现错误: %v", bot.gateway, err)
   245  			time.Sleep(2 * time.Second) // 等待两秒后重新连接
   246  			continue
   247  		}
   248  		bot.conn = conn
   249  		_ = resp.Body.Close()
   250  		payload, err := bot.reveive()
   251  		if err != nil {
   252  			log.Warnln(getLogHeader(), "获取心跳间隔时出现错误:", err)
   253  			_ = conn.Close()
   254  			time.Sleep(2 * time.Second) // 等待两秒后重新连接
   255  			continue
   256  		}
   257  		hb, err := payload.GetHeartbeatInterval()
   258  		if err != nil {
   259  			log.Warnln(getLogHeader(), "解析心跳间隔时出现错误:", err)
   260  			_ = conn.Close()
   261  			time.Sleep(2 * time.Second) // 等待两秒后重新连接
   262  			continue
   263  		}
   264  		payload.Op = OpCodeIdentify
   265  		err = payload.WrapData(&OpCodeIdentifyMessage{
   266  			Token:      bot.Authorization(),
   267  			Intents:    bot.Intents,
   268  			Shard:      bot.shard,
   269  			Properties: bot.Properties,
   270  		})
   271  		if err != nil {
   272  			log.Warnln(getLogHeader(), "包装 Identify 时出现错误:", err)
   273  			_ = conn.Close()
   274  			time.Sleep(2 * time.Second) // 等待两秒后重新连接
   275  			continue
   276  		}
   277  		err = bot.SendPayload(&payload)
   278  		if err != nil {
   279  			log.Warnln(getLogHeader(), "发送 Identify 时出现错误:", err)
   280  			_ = conn.Close()
   281  			time.Sleep(2 * time.Second) // 等待两秒后重新连接
   282  			continue
   283  		}
   284  		payload, err = bot.reveive()
   285  		if err != nil {
   286  			log.Warnln(getLogHeader(), "获取 EventReady 时出现错误:", err)
   287  			_ = conn.Close()
   288  			time.Sleep(2 * time.Second) // 等待两秒后重新连接
   289  			continue
   290  		}
   291  		bot.ready, bot.seq, err = payload.GetEventReady()
   292  		if err != nil {
   293  			log.Warnln(getLogHeader(), "解析 EventReady 时出现错误:", err)
   294  			_ = conn.Close()
   295  			time.Sleep(2 * time.Second) // 等待两秒后重新连接
   296  			continue
   297  		}
   298  		atomic.StoreUint32(&bot.heartbeat, hb)
   299  		break
   300  	}
   301  	clients.Store(bot.Token+"_"+strconv.Itoa(int(bot.shard[0])), bot)
   302  	log.Infoln(getLogHeader(), "连接到网关成功, 用户名:", bot.ready.User.Username)
   303  	bot.hbonce.Do(func() {
   304  		go bot.doheartbeat()
   305  	})
   306  	return bot
   307  }
   308  
   309  // refreshtoken 以 Expire 为间隔刷新 Token
   310  func (bot *Bot) refreshtoken() {
   311  	for {
   312  		time.Sleep(time.Second * 10)
   313  		if atomic.LoadUint32(&bot.heartbeat) == 0 {
   314  			log.Warnln(getLogHeader(), "等待服务器建立连接...")
   315  			continue
   316  		}
   317  		time.Sleep(time.Duration(bot.expiresec) * time.Second)
   318  		err := bot.GetAppAccessTokenNoContext()
   319  		if err != nil {
   320  			log.Warnln(getLogHeader(), "刷新 Token 时出现错误:", err)
   321  		} else {
   322  			log.Infoln(getLogHeader(), "刷新 Token: "+bot.token+", 超时:", bot.expiresec, "秒")
   323  		}
   324  	}
   325  }
   326  
   327  // doheartbeat 按指定间隔进行心跳包发送
   328  func (bot *Bot) doheartbeat() {
   329  	payload := struct {
   330  		Op OpCode  `json:"op"`
   331  		D  *uint32 `json:"d"`
   332  	}{Op: OpCodeHeartbeat}
   333  	for {
   334  		if atomic.LoadUint32(&bot.heartbeat) == 0 {
   335  			time.Sleep(time.Second)
   336  			log.Warnln(getLogHeader(), "等待服务器建立连接...")
   337  			continue
   338  		}
   339  		time.Sleep(time.Duration(bot.heartbeat) * time.Millisecond)
   340  		if bot.seq == 0 {
   341  			payload.D = nil
   342  		} else {
   343  			payload.D = &bot.seq
   344  		}
   345  		bot.mu.Lock()
   346  		err := bot.conn.WriteJSON(&payload)
   347  		bot.mu.Unlock()
   348  		if err != nil {
   349  			log.Warnln(getLogHeader(), "发送心跳时出现错误:", err)
   350  		}
   351  	}
   352  }
   353  
   354  // Resume 恢复连接
   355  //
   356  // https://bot.q.qq.com/wiki/develop/api/gateway/reference.html#_4-%E6%81%A2%E5%A4%8D%E8%BF%9E%E6%8E%A5
   357  func (bot *Bot) Resume() error {
   358  	network, address := resolveURI(bot.gateway)
   359  	dialer := websocket.Dialer{
   360  		NetDial: func(_, addr string) (net.Conn, error) {
   361  			if network == "unix" {
   362  				host, _, err := net.SplitHostPort(addr)
   363  				if err != nil {
   364  					host = addr
   365  				}
   366  				filepath, err := base64.RawURLEncoding.DecodeString(host)
   367  				if err == nil {
   368  					addr = BytesToString(filepath)
   369  				}
   370  			}
   371  			return net.Dial(network, addr) // support unix socket transport
   372  		},
   373  	}
   374  	conn, resp, err := dialer.Dial(address, http.Header{})
   375  	if err != nil {
   376  		return err
   377  	}
   378  	bot.conn = conn
   379  	_ = resp.Body.Close()
   380  	payload := WebsocketPayload{Op: OpCodeResume}
   381  	payload.WrapData(&struct {
   382  		T string `json:"token"`
   383  		S string `json:"session_id"`
   384  		Q uint32 `json:"seq"`
   385  	}{bot.Authorization(), bot.ready.SessionID, bot.seq})
   386  	return bot.SendPayload(&payload)
   387  }
   388  
// Listen reads gateway payloads in an endless loop and dispatches them; it
// never returns.
//
// On a read error it marks the connection as down (heartbeat = 0) and removes
// the bot from clients. It then calls Resume once per second until that
// succeeds, and re-registers the bot. Dispatch events whose sequence number
// is not newer than bot.seq are dropped as duplicates; RESUMED is only logged,
// and every other dispatch goes to processEvent. Reconnect and InvalidSession
// trigger a full Connect, and Hello updates the heartbeat interval. After each
// payload, bot.seq is advanced to the newest sequence number seen.
func (bot *Bot) Listen() {
	log.Infoln(getLogHeader(), "开始监听", bot.ready.User.Username, "的事件")
	payload := WebsocketPayload{}
	lastheartbeat := time.Now() // time of the last heartbeat / heartbeat ACK, for debug logging
	for {
		payload.Reset()
		err := bot.conn.ReadJSON(&payload)
		if err != nil { // connection lost: try to resume the session
			atomic.StoreUint32(&bot.heartbeat, 0)
			k := bot.Token + "_" + strconv.Itoa(int(bot.shard[0]))
			clients.Delete(k)
			log.Warnln(getLogHeader(), bot.ready.User.Username, "的网关连接断开, 尝试恢复:", err)
			for {
				time.Sleep(time.Second)
				err = bot.Resume()
				if err == nil {
					break
				}
				log.Warnln(getLogHeader(), bot.ready.User.Username, "的网关连接恢复失败:", err)
			}
			clients.Store(k, bot)
			continue
		}
		log.Debug(getLogHeader(), " 接收到第 ", payload.S, " 个事件: ", payload.Op, ", 类型: ", payload.T, ", 数据: ", BytesToString(payload.D))
		switch payload.Op {
		case OpCodeDispatch: // Receive
			if payload.S <= bot.seq {
				// already seen (e.g. replayed after a resume): skip it
				log.Warn(getLogHeader(), " 忽略重复编号: ", payload.S, ", 事件: ", payload.Op, ", 类型: ", payload.T)
				continue
			}
			switch payload.T {
			case "RESUMED":
				log.Infoln(getLogHeader(), bot.ready.User.Username, "的网关连接恢复完成")
			default:
				bot.processEvent(&payload)
			}
		case OpCodeHeartbeat: // Send/Receive
			log.Debugln(getLogHeader(), "收到服务端推送心跳, 间隔:", time.Since(lastheartbeat))
			lastheartbeat = time.Now()
		case OpCodeReconnect: // Receive: the server asks for a new connection
			log.Warnln(getLogHeader(), "收到服务端通知重连")
			atomic.StoreUint32(&bot.heartbeat, 0)
			bot.Connect()
		case OpCodeInvalidSession: // Receive: the session cannot be resumed, identify again
			log.Warnln(getLogHeader(), bot.ready.User.Username, "的网关连接恢复失败: InvalidSession, 尝试重连...")
			atomic.StoreUint32(&bot.heartbeat, 0)
			bot.Connect()
		case OpCodeHello: // Receive: carries the heartbeat interval (re-enables doheartbeat after Resume)
			intv, err := payload.GetHeartbeatInterval()
			if err != nil {
				log.Warnln(getLogHeader(), "解析心跳间隔时出现错误:", err)
				continue
			}
			atomic.StoreUint32(&bot.heartbeat, intv)
		case OpCodeHeartbeatACK: // Receive/Reply
			log.Debugln(getLogHeader(), "收到心跳返回, 间隔:", time.Since(lastheartbeat))
			lastheartbeat = time.Now()
		case OpCodeHTTPCallbackACK: // Reply; intentionally ignored
		default:
			log.Warn(getLogHeader(), " 忽略未知事件, 序号: ", payload.S, ", Op: ", payload.Op, ", 类型: ", payload.T, ", 数据: ", BytesToString(payload.D))
		}
		if payload.S > bot.seq {
			bot.seq = payload.S
		}
	}
}
   456  
   457  // GetBot 获取指定的bot (Ctx)实例
   458  func GetBot(id string) *Ctx {
   459  	caller, ok := clients.Load(id)
   460  	if !ok {
   461  		return nil
   462  	}
   463  	return &Ctx{caller: caller}
   464  }
   465  
   466  // RangeBot 遍历所有bot (Ctx)实例
   467  //
   468  // 单次操作返回 true 则继续遍历,否则退出
   469  func RangeBot(iter func(id string, ctx *Ctx) bool) {
   470  	clients.Range(func(key string, value *Bot) bool {
   471  		return iter(key, &Ctx{caller: value})
   472  	})
   473  }
   474  
   475  // GetFirstSuperUser 在 ids 中获得 SuperUsers 列表的首个 qq
   476  //
   477  // 找不到返回 nil
   478  func (bot *Bot) GetFirstSuperUser(ids ...string) string {
   479  	m := make(map[string]struct{}, len(ids)*4)
   480  	for _, qq := range ids {
   481  		m[qq] = struct{}{}
   482  	}
   483  	for _, qq := range bot.SuperUsers {
   484  		if _, ok := m[qq]; ok {
   485  			return qq
   486  		}
   487  	}
   488  	return ""
   489  }