github.com/Mrs4s/go-cqhttp@v1.2.0/server/scf.go (about)

     1  package server
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/url"
    10  	"os"
    11  	"runtime/debug"
    12  	"strings"
    13  
    14  	"github.com/Mrs4s/MiraiGo/utils"
    15  	log "github.com/sirupsen/logrus"
    16  	"gopkg.in/yaml.v3"
    17  
    18  	"github.com/Mrs4s/go-cqhttp/coolq"
    19  	"github.com/Mrs4s/go-cqhttp/global"
    20  	api2 "github.com/Mrs4s/go-cqhttp/modules/api"
    21  	"github.com/Mrs4s/go-cqhttp/modules/config"
    22  )
    23  
    24  type lambdaClient struct {
    25  	nextURL     string
    26  	responseURL string
    27  	lambdaType  string
    28  
    29  	client http.Client
    30  }
    31  
// lambdaResponse is the JSON envelope that flush posts to the runtime's
// invocation/response endpoint. The field order and json tags define the
// serialized form, so keep them stable.
type lambdaResponse struct {
	IsBase64Encoded bool              `json:"isBase64Encoded"` // flush always sends false: Body is sent as-is
	StatusCode      int               `json:"statusCode"`      // HTTP status recorded by the handler
	Headers         map[string]string `json:"headers"`         // one value per header name
	Body            string            `json:"body"`            // buffered response body
}
    38  
    39  type lambdaResponseWriter struct {
    40  	statusCode int
    41  	buf        bytes.Buffer
    42  	header     http.Header
    43  }
    44  
    45  func (l *lambdaResponseWriter) Write(p []byte) (n int, err error) {
    46  	return l.buf.Write(p)
    47  }
    48  
    49  func (l *lambdaResponseWriter) Header() http.Header {
    50  	return l.header
    51  }
    52  
    53  func (l *lambdaResponseWriter) flush() error {
    54  	buffer := global.NewBuffer()
    55  	defer global.PutBuffer(buffer)
    56  	body := utils.B2S(l.buf.Bytes())
    57  	header := make(map[string]string, len(l.header))
    58  	for k, v := range l.header {
    59  		header[k] = v[0]
    60  	}
    61  	_ = json.NewEncoder(buffer).Encode(&lambdaResponse{
    62  		IsBase64Encoded: false,
    63  		StatusCode:      l.statusCode,
    64  		Headers:         header,
    65  		Body:            body,
    66  	})
    67  
    68  	r, _ := http.NewRequest(http.MethodPost, cli.responseURL, buffer)
    69  	do, err := cli.client.Do(r)
    70  	if err != nil {
    71  		return err
    72  	}
    73  	return do.Body.Close()
    74  }
    75  
    76  func (l *lambdaResponseWriter) WriteHeader(statusCode int) {
    77  	l.statusCode = statusCode
    78  }
    79  
// cli is the runtime client that runLambda assigns before entering its loop.
// flush reads cli.responseURL and cli.client from it. It is nil until
// runLambda sets it.
var cli *lambdaClient
    81  
    82  // runLambda  type: [scf,aws]
    83  func runLambda(bot *coolq.CQBot, node yaml.Node) {
    84  	var conf LambdaServer
    85  	switch err := node.Decode(&conf); {
    86  	case err != nil:
    87  		log.Warn("读取lambda配置失败 :", err)
    88  		fallthrough
    89  	case conf.Disabled:
    90  		return
    91  	}
    92  
    93  	cli = &lambdaClient{
    94  		lambdaType: conf.Type,
    95  		client:     http.Client{Timeout: 0},
    96  	}
    97  	switch cli.lambdaType { // todo: aws
    98  	case "scf": // tencent serverless function
    99  		base := fmt.Sprintf("http://%s:%s/runtime/",
   100  			os.Getenv("SCF_RUNTIME_API"),
   101  			os.Getenv("SCF_RUNTIME_API_PORT"))
   102  		cli.nextURL = base + "invocation/next"
   103  		cli.responseURL = base + "invocation/response"
   104  		post, err := http.Post(base+"init/ready", "", nil)
   105  		if err != nil {
   106  			log.Warnf("lambda 初始化失败: %v", err)
   107  			return
   108  		}
   109  		_ = post.Body.Close()
   110  	case "aws": // aws lambda
   111  		const apiVersion = "2018-06-01"
   112  		base := fmt.Sprintf("http://%s/%s/runtime/", os.Getenv("AWS_LAMBDA_RUNTIME_API"), apiVersion)
   113  		cli.nextURL = base + "invocation/next"
   114  		cli.responseURL = base + "invocation/response"
   115  	default:
   116  		log.Fatal("unknown lambda type:", conf.Type)
   117  	}
   118  
   119  	api := api2.NewCaller(bot)
   120  	if conf.RateLimit.Enabled {
   121  		api.Use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket))
   122  	}
   123  	server := &httpServer{
   124  		api:         api,
   125  		accessToken: conf.AccessToken,
   126  	}
   127  
   128  	for {
   129  		req := cli.next()
   130  		writer := lambdaResponseWriter{statusCode: 200, header: make(http.Header)}
   131  		func() {
   132  			defer func() {
   133  				if e := recover(); e != nil {
   134  					log.Warnf("Lambda 出现不可恢复错误: %v\n%s", e, debug.Stack())
   135  				}
   136  			}()
   137  			if req != nil {
   138  				server.ServeHTTP(&writer, req)
   139  			}
   140  		}()
   141  		if err := writer.flush(); err != nil {
   142  			log.Warnf("Lambda 发送响应失败: %v", err)
   143  		}
   144  	}
   145  }
   146  
   147  type lambdaInvoke struct {
   148  	Headers        map[string]string
   149  	HTTPMethod     string `json:"httpMethod"`
   150  	Body           string `json:"body"`
   151  	Path           string `json:"path"`
   152  	QueryString    map[string]string
   153  	RequestContext struct {
   154  		Path string `json:"path"`
   155  	} `json:"requestContext"`
   156  }
   157  
// lambdaDefault is the default YAML snippet for the lambda server; init
// passes it as Default to config.AddServer. The text is runtime data: keep
// its indentation and content as is. NOTE(review): presumably it is spliced
// into a generated config file, and `*default` is a YAML alias to an anchor
// defined elsewhere in that template. Confirm before editing.
// English gloss of the embedded YAML comments: "LambdaServer configuration";
// "scf: Tencent Cloud Functions, aws: AWS Lambda"; "reference the default
// middlewares".
const lambdaDefault = `  # LambdaServer 配置
  - lambda:
      type: scf # scf: 腾讯云函数 aws: aws Lambda
      middlewares:
        <<: *default # 引用默认中间件
`
   164  
// LambdaServer is the configuration of the cloud-function (serverless) server.
type LambdaServer struct {
	// Disabled makes runLambda return without starting the server.
	Disabled bool `yaml:"disabled"`
	// Type selects the runtime: "scf" (Tencent SCF) or "aws" (AWS Lambda).
	Type string `yaml:"type"`

	// MiddleWares carries the shared middleware settings (runLambda reads
	// AccessToken and RateLimit from it).
	MiddleWares `yaml:"middlewares"`
}
   172  
   173  func init() {
   174  	config.AddServer(&config.Server{
   175  		Brief:   "云函数服务",
   176  		Default: lambdaDefault,
   177  	})
   178  }
   179  
   180  func (c *lambdaClient) next() *http.Request {
   181  	r, err := http.NewRequest(http.MethodGet, c.nextURL, nil)
   182  	if err != nil {
   183  		return nil
   184  	}
   185  	resp, err := c.client.Do(r)
   186  	if err != nil {
   187  		return nil
   188  	}
   189  	defer resp.Body.Close()
   190  	if resp.StatusCode != http.StatusOK {
   191  		return nil
   192  	}
   193  	var req http.Request
   194  	var invoke lambdaInvoke
   195  	_ = json.NewDecoder(resp.Body).Decode(&invoke)
   196  	if invoke.HTTPMethod == "" { // 不是 api 网关
   197  		return nil
   198  	}
   199  
   200  	req.Method = invoke.HTTPMethod
   201  	req.Body = io.NopCloser(strings.NewReader(invoke.Body))
   202  	req.Header = make(map[string][]string)
   203  	for k, v := range invoke.Headers {
   204  		req.Header.Set(k, v)
   205  	}
   206  	req.URL = new(url.URL)
   207  	req.URL.Path = strings.TrimPrefix(invoke.Path, invoke.RequestContext.Path)
   208  	// todo: avoid encoding
   209  	query := make(url.Values)
   210  	for k, v := range invoke.QueryString {
   211  		query[k] = []string{v}
   212  	}
   213  	req.URL.RawQuery = query.Encode()
   214  	return &req
   215  }