github.com/Mrs4s/go-cqhttp@v1.2.0/server/middlewares.go (about)

     1  package server
     2  
     3  import (
     4  	"container/list"
     5  	"context"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/Mrs4s/go-cqhttp/coolq"
    10  	"github.com/Mrs4s/go-cqhttp/global"
    11  	"github.com/Mrs4s/go-cqhttp/modules/api"
    12  	"github.com/Mrs4s/go-cqhttp/pkg/onebot"
    13  
    14  	"golang.org/x/time/rate"
    15  )
    16  
// MiddleWares holds the communication-middleware settings, configured through
// the YAML keys given in the struct tags.
type MiddleWares struct {
	// AccessToken is read from the "access-token" key.
	// NOTE(review): presumably the token clients must present — confirm at its consumer.
	AccessToken string `yaml:"access-token"`
	// Filter is read from the "filter" key.
	// NOTE(review): presumably selects an event filter — confirm at its consumer.
	Filter      string `yaml:"filter"`
	// RateLimit is read from the "rate-limit" section. Frequency and Bucket
	// look like the frequency/bucketSize arguments of rateLimit in this file,
	// gated by Enabled — TODO confirm at the call site.
	RateLimit   struct {
		Enabled   bool    `yaml:"enabled"`
		Frequency float64 `yaml:"frequency"`
		Bucket    int     `yaml:"bucket"`
	} `yaml:"rate-limit"`
}
    27  
    28  func rateLimit(frequency float64, bucketSize int) api.Handler {
    29  	limiter := rate.NewLimiter(rate.Limit(frequency), bucketSize)
    30  	return func(_ string, _ *onebot.Spec, _ api.Getter) global.MSG {
    31  		_ = limiter.Wait(context.Background())
    32  		return nil
    33  	}
    34  }
    35  
    36  func longPolling(bot *coolq.CQBot, maxSize int) api.Handler {
    37  	var mutex sync.Mutex
    38  	cond := sync.NewCond(&mutex)
    39  	queue := list.New()
    40  	bot.OnEventPush(func(event *coolq.Event) {
    41  		mutex.Lock()
    42  		defer mutex.Unlock()
    43  		queue.PushBack(event.Raw)
    44  		for maxSize != 0 && queue.Len() > maxSize {
    45  			queue.Remove(queue.Front())
    46  		}
    47  		cond.Signal()
    48  	})
    49  	return func(action string, spec *onebot.Spec, p api.Getter) global.MSG {
    50  		switch {
    51  		case spec.Version == 11 && action == "get_updates": // ok
    52  		case spec.Version == 12 && action == "get_latest_events": // ok
    53  		default:
    54  			return nil
    55  		}
    56  		var (
    57  			ch      = make(chan []any)
    58  			timeout = time.Duration(p.Get("timeout").Int()) * time.Second
    59  		)
    60  		go func() {
    61  			mutex.Lock()
    62  			defer mutex.Unlock()
    63  			for queue.Len() == 0 {
    64  				cond.Wait()
    65  			}
    66  			limit := int(p.Get("limit").Int())
    67  			if limit <= 0 || queue.Len() < limit {
    68  				limit = queue.Len()
    69  			}
    70  			ret := make([]any, limit)
    71  			elem := queue.Front()
    72  			for i := 0; i < limit; i++ {
    73  				ret[i] = elem.Value
    74  				elem = elem.Next()
    75  			}
    76  			select {
    77  			case ch <- ret:
    78  				for i := 0; i < limit; i++ { // remove sent msg
    79  					queue.Remove(queue.Front())
    80  				}
    81  			default:
    82  				// don't block if parent already return due to timeout
    83  			}
    84  		}()
    85  		if timeout != 0 {
    86  			select {
    87  			case <-time.After(timeout):
    88  				return coolq.OK([]any{})
    89  			case ret := <-ch:
    90  				return coolq.OK(ret)
    91  			}
    92  		}
    93  		return coolq.OK(<-ch)
    94  	}
    95  }