github.com/Mrs4s/go-cqhttp@v1.2.0/server/http.go (about)

     1  package server
     2  
import (
	"bytes"
	"context"
	"crypto/hmac"
	"crypto/sha1"
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/Mrs4s/MiraiGo/utils"
	log "github.com/sirupsen/logrus"
	"github.com/tidwall/gjson"
	"gopkg.in/yaml.v3"

	"github.com/Mrs4s/go-cqhttp/coolq"
	"github.com/Mrs4s/go-cqhttp/global"
	"github.com/Mrs4s/go-cqhttp/modules/api"
	"github.com/Mrs4s/go-cqhttp/modules/config"
	"github.com/Mrs4s/go-cqhttp/modules/filter"
	"github.com/Mrs4s/go-cqhttp/pkg/onebot"
)
    34  
// HTTPServer is the configuration of the HTTP communication section, decoded
// from YAML by runHTTP.
type HTTPServer struct {
	// Disabled makes runHTTP return without starting the server or any client.
	Disabled    bool   `yaml:"disabled"`
	// Version is the OneBot protocol version: 12 selects V12, any other value V11.
	Version     uint16 `yaml:"version"`
	// Address is the listen address. When it parses as a URL with a scheme,
	// the scheme is used as the listen network and host+path as the address.
	Address     string `yaml:"address"`
	// Host and Port are the deprecated listen address, used only when Address
	// is empty. Port, when set, is also reported to POST receivers as X-API-Port.
	Host        string `yaml:"host"`
	Port        int    `yaml:"port"`
	// Timeout is the reverse-HTTP request timeout in seconds; values below 5
	// are raised to 5 by HTTPClient.Run.
	Timeout     int32  `yaml:"timeout"`
	// LongPolling configures the long-polling extension (per the httpDefault
	// comments, MaxQueueSize 0 means an unbounded queue).
	LongPolling struct {
		Enabled      bool `yaml:"enabled"`
		MaxQueueSize int  `yaml:"max-queue-size"`
	} `yaml:"long-polling"`
	// Post lists the reverse-HTTP POST targets.
	Post []httpServerPost `yaml:"post"`

	// MiddleWares is declared elsewhere in this package; presumably it carries
	// the AccessToken, RateLimit and Filter settings that runHTTP reads from
	// conf — NOTE(review): confirm.
	MiddleWares `yaml:"middlewares"`
}
    51  
// httpServerPost is one reverse-HTTP POST target from the config.
type httpServerPost struct {
	// URL is the endpoint events are posted to; entries with an empty URL are skipped.
	URL             string  `yaml:"url"`
	// Secret, when non-empty, is the HMAC-SHA1 key used to sign each body (X-Signature).
	Secret          string  `yaml:"secret"`
	// MaxRetries and RetriesInterval (milliseconds) are pointers so that an
	// omitted value (nil, which falls back to 3 and 1500) can be told apart
	// from an explicit 0.
	MaxRetries      *uint64 `yaml:"max-retries"`
	RetriesInterval *uint64 `yaml:"retries-interval"`
}
    58  
// httpServer is the http.Handler that serves the OneBot HTTP API.
type httpServer struct {
	// api dispatches an action name to its handler.
	api         *api.Caller
	// accessToken is the token requests must present; empty disables the check.
	accessToken string
	spec        *onebot.Spec // onebot spec
}
    64  
// HTTPClient is a reverse-HTTP reporting client that POSTs bot events to addr.
type HTTPClient struct {
	bot *coolq.CQBot
	// secret is the HMAC-SHA1 key for X-Signature; empty disables signing.
	secret string
	// addr is the target URL; Run replaces it with the resolved address.
	addr string
	// filter is the key passed to filter.Add/filter.Find; empty disables filtering.
	filter string
	// apiPort is sent as X-API-Port when non-zero.
	apiPort int
	// timeout is the request timeout in seconds (Run enforces a minimum of 5).
	timeout int32
	client  *http.Client
	// MaxRetries is the number of retries after the first failed attempt.
	MaxRetries uint64
	// RetriesInterval is the delay between retries in milliseconds.
	RetriesInterval uint64
}
    77  
// httpCtx holds the parameter sources of a single API request; Get consults
// the JSON body first, then the form body, then the URL query.
type httpCtx struct {
	// json is the parsed application/json body (zero Result when absent).
	json     gjson.Result
	// query holds the URL query parameters.
	query    url.Values
	// postForm holds the application/x-www-form-urlencoded body.
	postForm url.Values
}
    83  
// httpDefault is the default YAML snippet for the http server section; init
// registers it via config.AddServer. It merges in the `*default` middleware
// alias, whose anchor is presumably defined by another part of the generated
// config file — NOTE(review): confirm.
const httpDefault = `
  - http: # HTTP 通信设置
      address: 0.0.0.0:5700 # HTTP监听地址
      version: 11     # OneBot协议版本, 支持 11/12
      timeout: 5      # 反向 HTTP 超时时间, 单位秒,<5 时将被忽略
      long-polling:   # 长轮询拓展
        enabled: false       # 是否开启
        max-queue-size: 2000 # 消息队列大小,0 表示不限制队列大小,谨慎使用
      middlewares:
        <<: *default # 引用默认中间件
      post:           # 反向HTTP POST地址列表
      #- url: ''                # 地址
      #  secret: ''             # 密钥
      #  max-retries: 3         # 最大重试,0 时禁用
      #  retries-interval: 1500 # 重试时间,单位毫秒,0 时立即
      #- url: http://127.0.0.1:5701/ # 地址
      #  secret: ''                  # 密钥
      #  max-retries: 10             # 最大重试,0 时禁用
      #  retries-interval: 1000      # 重试时间,单位毫秒,0 时立即
`
   104  
   105  func init() {
   106  	config.AddServer(&config.Server{Brief: "HTTP通信", Default: httpDefault})
   107  }
   108  
// joinQuery matches gjson-style alternation patterns such as "[a,b].0",
// capturing a and b; httpCtx.get uses it to try a first and fall back to b.
var joinQuery = regexp.MustCompile(`\[(.+?),(.+?)]\.0`)
   110  
   111  func mayJSONParam(p string) bool {
   112  	if strings.HasPrefix(p, "{") || strings.HasPrefix(p, "[") {
   113  		return gjson.Valid(p)
   114  	}
   115  	return false
   116  }
   117  
   118  func (h *httpCtx) get(pattern string, join bool) gjson.Result {
   119  	// support gjson advanced syntax:
   120  	// h.Get("[a,b].0") see usage in http_test.go. See issue #1241, #1325.
   121  	if join && strings.HasPrefix(pattern, "[") && joinQuery.MatchString(pattern) {
   122  		matched := joinQuery.FindStringSubmatch(pattern)
   123  		if r := h.get(matched[1], false); r.Exists() {
   124  			return r
   125  		}
   126  		return h.get(matched[2], false)
   127  	}
   128  
   129  	if h.postForm != nil {
   130  		if form := h.postForm.Get(pattern); form != "" {
   131  			if mayJSONParam(form) {
   132  				return gjson.Result{Type: gjson.JSON, Raw: form}
   133  			}
   134  			return gjson.Result{Type: gjson.String, Str: form}
   135  		}
   136  	}
   137  	if h.query != nil {
   138  		if query := h.query.Get(pattern); query != "" {
   139  			if mayJSONParam(query) {
   140  				return gjson.Result{Type: gjson.JSON, Raw: query}
   141  			}
   142  			return gjson.Result{Type: gjson.String, Str: query}
   143  		}
   144  	}
   145  	return gjson.Result{}
   146  }
   147  
   148  func (h *httpCtx) Get(s string) gjson.Result {
   149  	j := h.json.Get(s)
   150  	if j.Exists() {
   151  		return j
   152  	}
   153  	return h.get(s, true)
   154  }
   155  
   156  func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
   157  	var ctx httpCtx
   158  	contentType := request.Header.Get("Content-Type")
   159  	switch request.Method {
   160  	case http.MethodPost:
   161  		// todo: msg pack
   162  		if s.spec.Version == 12 && strings.Contains(contentType, "application/msgpack") {
   163  			log.Warnf("请求 %v 数据类型暂不支持: MsgPack", request.RequestURI)
   164  			writer.WriteHeader(http.StatusUnsupportedMediaType)
   165  			return
   166  		}
   167  
   168  		if strings.Contains(contentType, "application/json") {
   169  			body, err := io.ReadAll(request.Body)
   170  			if err != nil {
   171  				log.Warnf("获取请求 %v 的Body时出现错误: %v", request.RequestURI, err)
   172  				writer.WriteHeader(http.StatusBadRequest)
   173  				return
   174  			}
   175  			if !gjson.ValidBytes(body) {
   176  				log.Warnf("已拒绝客户端 %v 的请求: 非法Json", request.RemoteAddr)
   177  				writer.WriteHeader(http.StatusBadRequest)
   178  				return
   179  			}
   180  			ctx.json = gjson.Parse(utils.B2S(body))
   181  		}
   182  		if strings.Contains(contentType, "application/x-www-form-urlencoded") {
   183  			err := request.ParseForm()
   184  			if err != nil {
   185  				log.Warnf("已拒绝客户端 %v 的请求: %v", request.RemoteAddr, err)
   186  				writer.WriteHeader(http.StatusBadRequest)
   187  			}
   188  			ctx.postForm = request.PostForm
   189  		}
   190  		fallthrough
   191  	case http.MethodGet:
   192  		ctx.query = request.URL.Query()
   193  
   194  	default:
   195  		log.Warnf("已拒绝客户端 %v 的请求: 方法错误", request.RemoteAddr)
   196  		writer.WriteHeader(http.StatusNotFound)
   197  		return
   198  	}
   199  	if status := checkAuth(request, s.accessToken); status != http.StatusOK {
   200  		writer.WriteHeader(status)
   201  		return
   202  	}
   203  
   204  	var response global.MSG
   205  	if request.URL.Path == "/" {
   206  		action := strings.TrimSuffix(ctx.Get("action").Str, "_async")
   207  		log.Debugf("HTTPServer接收到API调用: %v", action)
   208  		response = s.api.Call(action, s.spec, ctx.Get("params"))
   209  	} else {
   210  		action := strings.TrimPrefix(request.URL.Path, "/")
   211  		action = strings.TrimSuffix(action, "_async")
   212  		log.Debugf("HTTPServer接收到API调用: %v", action)
   213  		response = s.api.Call(action, s.spec, &ctx)
   214  	}
   215  
   216  	writer.Header().Set("Content-Type", "application/json; charset=utf-8")
   217  	writer.WriteHeader(http.StatusOK)
   218  	_ = json.NewEncoder(writer).Encode(response)
   219  }
   220  
   221  func checkAuth(req *http.Request, token string) int {
   222  	if token == "" { // quick path
   223  		return http.StatusOK
   224  	}
   225  
   226  	auth := req.Header.Get("Authorization")
   227  	if auth == "" {
   228  		auth = req.URL.Query().Get("access_token")
   229  	} else {
   230  		_, after, ok := strings.Cut(auth, " ")
   231  		if ok {
   232  			auth = after
   233  		}
   234  	}
   235  
   236  	switch auth {
   237  	case token:
   238  		return http.StatusOK
   239  	case "":
   240  		return http.StatusUnauthorized
   241  	default:
   242  		return http.StatusForbidden
   243  	}
   244  }
   245  
   246  func puint64Operator(p *uint64, def uint64) uint64 {
   247  	if p == nil {
   248  		return def
   249  	}
   250  	return *p
   251  }
   252  
   253  // runHTTP 启动HTTP服务器与HTTP上报客户端
   254  func runHTTP(bot *coolq.CQBot, node yaml.Node) {
   255  	var conf HTTPServer
   256  	switch err := node.Decode(&conf); {
   257  	case err != nil:
   258  		log.Warn("读取http配置失败 :", err)
   259  		fallthrough
   260  	case conf.Disabled:
   261  		return
   262  	}
   263  	network, addr := "tcp", conf.Address
   264  	s := &httpServer{accessToken: conf.AccessToken}
   265  	switch conf.Version {
   266  	default:
   267  		// default v11
   268  		s.spec = onebot.V11
   269  	case 12:
   270  		s.spec = onebot.V12
   271  	}
   272  	switch {
   273  	case conf.Address != "":
   274  		uri, err := url.Parse(conf.Address)
   275  		if err == nil && uri.Scheme != "" {
   276  			network = uri.Scheme
   277  			addr = uri.Host + uri.Path
   278  		}
   279  	case conf.Host != "" || conf.Port != 0:
   280  		addr = fmt.Sprintf("%s:%d", conf.Host, conf.Port)
   281  		log.Warnln("HTTP 服务器使用了过时的配置格式,请更新配置文件!")
   282  	default:
   283  		goto client
   284  	}
   285  	s.api = api.NewCaller(bot)
   286  	if conf.RateLimit.Enabled {
   287  		s.api.Use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket))
   288  	}
   289  	if conf.LongPolling.Enabled {
   290  		s.api.Use(longPolling(bot, conf.LongPolling.MaxQueueSize))
   291  	}
   292  	go func() {
   293  		listener, err := net.Listen(network, addr)
   294  		if err != nil {
   295  			log.Infof("HTTP 服务启动失败, 请检查端口是否被占用: %v", err)
   296  			log.Warnf("将在五秒后退出.")
   297  			time.Sleep(time.Second * 5)
   298  			os.Exit(1)
   299  		}
   300  		log.Infof("CQ HTTP 服务器已启动: %v", listener.Addr())
   301  		log.Fatal(http.Serve(listener, s))
   302  	}()
   303  client:
   304  	for _, c := range conf.Post {
   305  		if c.URL != "" {
   306  			go HTTPClient{
   307  				bot:             bot,
   308  				secret:          c.Secret,
   309  				addr:            c.URL,
   310  				apiPort:         conf.Port,
   311  				filter:          conf.Filter,
   312  				timeout:         conf.Timeout,
   313  				MaxRetries:      puint64Operator(c.MaxRetries, 3),
   314  				RetriesInterval: puint64Operator(c.RetriesInterval, 1500),
   315  			}.Run()
   316  		}
   317  	}
   318  }
   319  
   320  // Run 运行反向HTTP服务
   321  func (c HTTPClient) Run() {
   322  	filter.Add(c.filter)
   323  	if c.timeout < 5 {
   324  		c.timeout = 5
   325  	}
   326  	rawAddress := c.addr
   327  	network, address := resolveURI(c.addr)
   328  	client := &http.Client{
   329  		Timeout: time.Second * time.Duration(c.timeout),
   330  		Transport: &http.Transport{
   331  			DialContext: func(_ context.Context, _, addr string) (net.Conn, error) {
   332  				if network == "unix" {
   333  					host, _, err := net.SplitHostPort(addr)
   334  					if err != nil {
   335  						host = addr
   336  					}
   337  					filepath, err := base64.RawURLEncoding.DecodeString(host)
   338  					if err == nil {
   339  						addr = string(filepath)
   340  					}
   341  				}
   342  				return net.Dial(network, addr)
   343  			},
   344  		},
   345  	}
   346  	c.addr = address // clean path
   347  	c.client = client
   348  	log.Infof("HTTP POST上报器已启动: %v", rawAddress)
   349  	c.bot.OnEventPush(c.onBotPushEvent)
   350  }
   351  
   352  func (c *HTTPClient) onBotPushEvent(e *coolq.Event) {
   353  	if c.filter != "" {
   354  		flt := filter.Find(c.filter)
   355  		if flt != nil && !flt.Eval(gjson.Parse(e.JSONString())) {
   356  			log.Debugf("上报Event %v 到 HTTP 服务器 %s 时被过滤.", c.addr, e.JSONBytes())
   357  			return
   358  		}
   359  	}
   360  
   361  	header := make(http.Header)
   362  	header.Set("X-Self-ID", strconv.FormatInt(c.bot.Client.Uin, 10))
   363  	header.Set("User-Agent", "CQHttp/4.15.0")
   364  	header.Set("Content-Type", "application/json")
   365  	if c.secret != "" {
   366  		mac := hmac.New(sha1.New, []byte(c.secret))
   367  		_, _ = mac.Write(e.JSONBytes())
   368  		header.Set("X-Signature", "sha1="+hex.EncodeToString(mac.Sum(nil)))
   369  	}
   370  	if c.apiPort != 0 {
   371  		header.Set("X-API-Port", strconv.FormatInt(int64(c.apiPort), 10))
   372  	}
   373  
   374  	var req *http.Request
   375  	var res *http.Response
   376  	var err error
   377  	for i := uint64(0); i <= c.MaxRetries; i++ {
   378  		// see https://stackoverflow.com/questions/31337891/net-http-http-contentlength-222-with-body-length-0
   379  		// we should create a new request for every single post trial
   380  		req, err = http.NewRequest(http.MethodPost, c.addr, bytes.NewReader(e.JSONBytes()))
   381  		if err != nil {
   382  			log.Warnf("上报 Event 数据到 %v 时创建请求失败: %v", c.addr, err)
   383  			return
   384  		}
   385  		req.Header = header
   386  		res, err = c.client.Do(req) // nolint:bodyclose
   387  		if err == nil {
   388  			break
   389  		}
   390  		if i < c.MaxRetries {
   391  			log.Warnf("上报 Event 数据到 %v 失败: %v 将进行第 %d 次重试", c.addr, err, i+1)
   392  		} else {
   393  			log.Warnf("上报 Event 数据 %s 到 %v 失败: %v 停止上报:已达重试上限", e.JSONBytes(), c.addr, err)
   394  			return
   395  		}
   396  		time.Sleep(time.Millisecond * time.Duration(c.RetriesInterval))
   397  	}
   398  	defer res.Body.Close()
   399  
   400  	log.Debugf("上报Event数据 %s 到 %v", e.JSONBytes(), c.addr)
   401  	r, err := io.ReadAll(res.Body)
   402  	if err != nil {
   403  		return
   404  	}
   405  	if gjson.ValidBytes(r) {
   406  		c.bot.CQHandleQuickOperation(gjson.Parse(e.JSONString()), gjson.ParseBytes(r))
   407  	}
   408  }