github.com/Mrs4s/go-cqhttp@v1.2.0/server/http.go (about) 1 package server 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/hmac" 7 "crypto/sha1" 8 "encoding/base64" 9 "encoding/hex" 10 "encoding/json" 11 "fmt" 12 "io" 13 "net" 14 "net/http" 15 "net/url" 16 "os" 17 "regexp" 18 "strconv" 19 "strings" 20 "time" 21 22 "github.com/Mrs4s/MiraiGo/utils" 23 log "github.com/sirupsen/logrus" 24 "github.com/tidwall/gjson" 25 "gopkg.in/yaml.v3" 26 27 "github.com/Mrs4s/go-cqhttp/coolq" 28 "github.com/Mrs4s/go-cqhttp/global" 29 "github.com/Mrs4s/go-cqhttp/modules/api" 30 "github.com/Mrs4s/go-cqhttp/modules/config" 31 "github.com/Mrs4s/go-cqhttp/modules/filter" 32 "github.com/Mrs4s/go-cqhttp/pkg/onebot" 33 ) 34 35 // HTTPServer HTTP通信相关配置 36 type HTTPServer struct { 37 Disabled bool `yaml:"disabled"` 38 Version uint16 `yaml:"version"` 39 Address string `yaml:"address"` 40 Host string `yaml:"host"` 41 Port int `yaml:"port"` 42 Timeout int32 `yaml:"timeout"` 43 LongPolling struct { 44 Enabled bool `yaml:"enabled"` 45 MaxQueueSize int `yaml:"max-queue-size"` 46 } `yaml:"long-polling"` 47 Post []httpServerPost `yaml:"post"` 48 49 MiddleWares `yaml:"middlewares"` 50 } 51 52 type httpServerPost struct { 53 URL string `yaml:"url"` 54 Secret string `yaml:"secret"` 55 MaxRetries *uint64 `yaml:"max-retries"` 56 RetriesInterval *uint64 `yaml:"retries-interval"` 57 } 58 59 type httpServer struct { 60 api *api.Caller 61 accessToken string 62 spec *onebot.Spec // onebot spec 63 } 64 65 // HTTPClient 反向HTTP上报客户端 66 type HTTPClient struct { 67 bot *coolq.CQBot 68 secret string 69 addr string 70 filter string 71 apiPort int 72 timeout int32 73 client *http.Client 74 MaxRetries uint64 75 RetriesInterval uint64 76 } 77 78 type httpCtx struct { 79 json gjson.Result 80 query url.Values 81 postForm url.Values 82 } 83 84 const httpDefault = ` 85 - http: # HTTP 通信设置 86 address: 0.0.0.0:5700 # HTTP监听地址 87 version: 11 # OneBot协议版本, 支持 11/12 88 timeout: 5 # 反向 HTTP 超时时间, 单位秒,<5 时将被忽略 89 long-polling: # 长轮询拓展 90 enabled: false # 是否开启 91 max-queue-size: 2000 # 消息队列大小,0 表示不限制队列大小,谨慎使用 92 middlewares: 93 <<: *default # 引用默认中间件 94 post: # 反向HTTP POST地址列表 95 #- url: '' # 地址 96 # secret: '' # 密钥 97 # max-retries: 3 # 最大重试,0 时禁用 98 # retries-interval: 1500 # 重试时间,单位毫秒,0 时立即 99 #- url: http://127.0.0.1:5701/ # 地址 100 # secret: '' # 密钥 101 # max-retries: 10 # 最大重试,0 时禁用 102 # retries-interval: 1000 # 重试时间,单位毫秒,0 时立即 103 ` 104 105 func init() { 106 config.AddServer(&config.Server{Brief: "HTTP通信", Default: httpDefault}) 107 } 108 109 var joinQuery = regexp.MustCompile(`\[(.+?),(.+?)]\.0`) 110 111 func mayJSONParam(p string) bool { 112 if strings.HasPrefix(p, "{") || strings.HasPrefix(p, "[") { 113 return gjson.Valid(p) 114 } 115 return false 116 } 117 118 func (h *httpCtx) get(pattern string, join bool) gjson.Result { 119 // support gjson advanced syntax: 120 // h.Get("[a,b].0") see usage in http_test.go. See issue #1241, #1325. 121 if join && strings.HasPrefix(pattern, "[") && joinQuery.MatchString(pattern) { 122 matched := joinQuery.FindStringSubmatch(pattern) 123 if r := h.get(matched[1], false); r.Exists() { 124 return r 125 } 126 return h.get(matched[2], false) 127 } 128 129 if h.postForm != nil { 130 if form := h.postForm.Get(pattern); form != "" { 131 if mayJSONParam(form) { 132 return gjson.Result{Type: gjson.JSON, Raw: form} 133 } 134 return gjson.Result{Type: gjson.String, Str: form} 135 } 136 } 137 if h.query != nil { 138 if query := h.query.Get(pattern); query != "" { 139 if mayJSONParam(query) { 140 return gjson.Result{Type: gjson.JSON, Raw: query} 141 } 142 return gjson.Result{Type: gjson.String, Str: query} 143 } 144 } 145 return gjson.Result{} 146 } 147 148 func (h *httpCtx) Get(s string) gjson.Result { 149 j := h.json.Get(s) 150 if j.Exists() { 151 return j 152 } 153 return h.get(s, true) 154 } 155 156 func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 157 var ctx httpCtx 158 contentType := request.Header.Get("Content-Type") 159 switch request.Method { 160 case http.MethodPost: 161 // todo: msg pack 162 if s.spec.Version == 12 && strings.Contains(contentType, "application/msgpack") { 163 log.Warnf("请求 %v 数据类型暂不支持: MsgPack", request.RequestURI) 164 writer.WriteHeader(http.StatusUnsupportedMediaType) 165 return 166 } 167 168 if strings.Contains(contentType, "application/json") { 169 body, err := io.ReadAll(request.Body) 170 if err != nil { 171 log.Warnf("获取请求 %v 的Body时出现错误: %v", request.RequestURI, err) 172 writer.WriteHeader(http.StatusBadRequest) 173 return 174 } 175 if !gjson.ValidBytes(body) { 176 log.Warnf("已拒绝客户端 %v 的请求: 非法Json", request.RemoteAddr) 177 writer.WriteHeader(http.StatusBadRequest) 178 return 179 } 180 ctx.json = gjson.Parse(utils.B2S(body)) 181 } 182 if strings.Contains(contentType, "application/x-www-form-urlencoded") { 183 err := request.ParseForm() 184 if err != nil { 185 log.Warnf("已拒绝客户端 %v 的请求: %v", request.RemoteAddr, err) 186 writer.WriteHeader(http.StatusBadRequest) 187 } 188 ctx.postForm = request.PostForm 189 } 190 fallthrough 191 case http.MethodGet: 192 ctx.query = request.URL.Query() 193 194 default: 195 log.Warnf("已拒绝客户端 %v 的请求: 方法错误", request.RemoteAddr) 196 writer.WriteHeader(http.StatusNotFound) 197 return 198 } 199 if status := checkAuth(request, s.accessToken); status != http.StatusOK { 200 writer.WriteHeader(status) 201 return 202 } 203 204 var response global.MSG 205 if request.URL.Path == "/" { 206 action := strings.TrimSuffix(ctx.Get("action").Str, "_async") 207 log.Debugf("HTTPServer接收到API调用: %v", action) 208 response = s.api.Call(action, s.spec, ctx.Get("params")) 209 } else { 210 action := strings.TrimPrefix(request.URL.Path, "/") 211 action = strings.TrimSuffix(action, "_async") 212 log.Debugf("HTTPServer接收到API调用: %v", action) 213 response = s.api.Call(action, s.spec, &ctx) 214 } 215 216 writer.Header().Set("Content-Type", "application/json; charset=utf-8") 217 writer.WriteHeader(http.StatusOK) 218 _ = json.NewEncoder(writer).Encode(response) 219 } 220 221 func checkAuth(req *http.Request, token string) int { 222 if token == "" { // quick path 223 return http.StatusOK 224 } 225 226 auth := req.Header.Get("Authorization") 227 if auth == "" { 228 auth = req.URL.Query().Get("access_token") 229 } else { 230 _, after, ok := strings.Cut(auth, " ") 231 if ok { 232 auth = after 233 } 234 } 235 236 switch auth { 237 case token: 238 return http.StatusOK 239 case "": 240 return http.StatusUnauthorized 241 default: 242 return http.StatusForbidden 243 } 244 } 245 246 func puint64Operator(p *uint64, def uint64) uint64 { 247 if p == nil { 248 return def 249 } 250 return *p 251 } 252 253 // runHTTP 启动HTTP服务器与HTTP上报客户端 254 func runHTTP(bot *coolq.CQBot, node yaml.Node) { 255 var conf HTTPServer 256 switch err := node.Decode(&conf); { 257 case err != nil: 258 log.Warn("读取http配置失败 :", err) 259 fallthrough 260 case conf.Disabled: 261 return 262 } 263 network, addr := "tcp", conf.Address 264 s := &httpServer{accessToken: conf.AccessToken} 265 switch conf.Version { 266 default: 267 // default v11 268 s.spec = onebot.V11 269 case 12: 270 s.spec = onebot.V12 271 } 272 switch { 273 case conf.Address != "": 274 uri, err := url.Parse(conf.Address) 275 if err == nil && uri.Scheme != "" { 276 network = uri.Scheme 277 addr = uri.Host + uri.Path 278 } 279 case conf.Host != "" || conf.Port != 0: 280 addr = fmt.Sprintf("%s:%d", conf.Host, conf.Port) 281 log.Warnln("HTTP 服务器使用了过时的配置格式,请更新配置文件!") 282 default: 283 goto client 284 } 285 s.api = api.NewCaller(bot) 286 if conf.RateLimit.Enabled { 287 s.api.Use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket)) 288 } 289 if conf.LongPolling.Enabled { 290 s.api.Use(longPolling(bot, conf.LongPolling.MaxQueueSize)) 291 } 292 go func() { 293 listener, err := net.Listen(network, addr) 294 if err != nil { 295 log.Infof("HTTP 服务启动失败, 请检查端口是否被占用: %v", err) 296 log.Warnf("将在五秒后退出.") 297 time.Sleep(time.Second * 5) 298 os.Exit(1) 299 } 300 log.Infof("CQ HTTP 服务器已启动: %v", listener.Addr()) 301 log.Fatal(http.Serve(listener, s)) 302 }() 303 client: 304 for _, c := range conf.Post { 305 if c.URL != "" { 306 go HTTPClient{ 307 bot: bot, 308 secret: c.Secret, 309 addr: c.URL, 310 apiPort: conf.Port, 311 filter: conf.Filter, 312 timeout: conf.Timeout, 313 MaxRetries: puint64Operator(c.MaxRetries, 3), 314 RetriesInterval: puint64Operator(c.RetriesInterval, 1500), 315 }.Run() 316 } 317 } 318 } 319 320 // Run 运行反向HTTP服务 321 func (c HTTPClient) Run() { 322 filter.Add(c.filter) 323 if c.timeout < 5 { 324 c.timeout = 5 325 } 326 rawAddress := c.addr 327 network, address := resolveURI(c.addr) 328 client := &http.Client{ 329 Timeout: time.Second * time.Duration(c.timeout), 330 Transport: &http.Transport{ 331 DialContext: func(_ context.Context, _, addr string) (net.Conn, error) { 332 if network == "unix" { 333 host, _, err := net.SplitHostPort(addr) 334 if err != nil { 335 host = addr 336 } 337 filepath, err := base64.RawURLEncoding.DecodeString(host) 338 if err == nil { 339 addr = string(filepath) 340 } 341 } 342 return net.Dial(network, addr) 343 }, 344 }, 345 } 346 c.addr = address // clean path 347 c.client = client 348 log.Infof("HTTP POST上报器已启动: %v", rawAddress) 349 c.bot.OnEventPush(c.onBotPushEvent) 350 } 351 352 func (c *HTTPClient) onBotPushEvent(e *coolq.Event) { 353 if c.filter != "" { 354 flt := filter.Find(c.filter) 355 if flt != nil && !flt.Eval(gjson.Parse(e.JSONString())) { 356 log.Debugf("上报Event %v 到 HTTP 服务器 %s 时被过滤.", c.addr, e.JSONBytes()) 357 return 358 } 359 } 360 361 header := make(http.Header) 362 header.Set("X-Self-ID", strconv.FormatInt(c.bot.Client.Uin, 10)) 363 header.Set("User-Agent", "CQHttp/4.15.0") 364 header.Set("Content-Type", "application/json") 365 if c.secret != "" { 366 mac := hmac.New(sha1.New, []byte(c.secret)) 367 _, _ = mac.Write(e.JSONBytes()) 368 header.Set("X-Signature", "sha1="+hex.EncodeToString(mac.Sum(nil))) 369 } 370 if c.apiPort != 0 { 371 header.Set("X-API-Port", strconv.FormatInt(int64(c.apiPort), 10)) 372 } 373 374 var req *http.Request 375 var res *http.Response 376 var err error 377 for i := uint64(0); i <= c.MaxRetries; i++ { 378 // see https://stackoverflow.com/questions/31337891/net-http-http-contentlength-222-with-body-length-0 379 // we should create a new request for every single post trial 380 req, err = http.NewRequest(http.MethodPost, c.addr, bytes.NewReader(e.JSONBytes())) 381 if err != nil { 382 log.Warnf("上报 Event 数据到 %v 时创建请求失败: %v", c.addr, err) 383 return 384 } 385 req.Header = header 386 res, err = c.client.Do(req) // nolint:bodyclose 387 if err == nil { 388 break 389 } 390 if i < c.MaxRetries { 391 log.Warnf("上报 Event 数据到 %v 失败: %v 将进行第 %d 次重试", c.addr, err, i+1) 392 } else { 393 log.Warnf("上报 Event 数据 %s 到 %v 失败: %v 停止上报:已达重试上限", e.JSONBytes(), c.addr, err) 394 return 395 } 396 time.Sleep(time.Millisecond * time.Duration(c.RetriesInterval)) 397 } 398 defer res.Body.Close() 399 400 log.Debugf("上报Event数据 %s 到 %v", e.JSONBytes(), c.addr) 401 r, err := io.ReadAll(res.Body) 402 if err != nil { 403 return 404 } 405 if gjson.ValidBytes(r) { 406 c.bot.CQHandleQuickOperation(gjson.Parse(e.JSONString()), gjson.ParseBytes(r)) 407 } 408 }