github.com/Mrs4s/go-cqhttp@v1.2.0/modules/config/config.go (about)

     1  // Package config 包含go-cqhttp操作配置文件的相关函数
     2  package config
     3  
     4  import (
     5  	"bufio"
     6  	_ "embed" // embed the default config file
     7  	"fmt"
     8  	"os"
     9  	"regexp"
    10  	"strings"
    11  
    12  	log "github.com/sirupsen/logrus"
    13  	"gopkg.in/yaml.v3"
    14  )
    15  
// defaultConfig holds the contents of default_config.yml, embedded into the
// binary at build time. It is used as the beginning of any newly generated
// configuration file, before the snippets of the servers the user selects.
//
//go:embed default_config.yml
var defaultConfig string
    20  
// Reconnect holds reconnection settings. Account.ReLogin (the "relogin" key
// of the account section) is of this type.
type Reconnect struct {
	Disabled bool `yaml:"disabled"`
	// NOTE(review): the units of Delay and Interval, and whether MaxTimes == 0
	// means "unlimited", are not visible here — confirm with the code that
	// consumes these values.
	Delay    uint `yaml:"delay"`
	MaxTimes uint `yaml:"max-times"`
	Interval int  `yaml:"interval"`
}
    28  
// Account holds the settings decoded from the "account" section of the
// configuration file (see Config.Account).
type Account struct {
	Uin                  int64        `yaml:"uin"`
	Password             string       `yaml:"password"`
	// NOTE(review): presumably marks Password as stored in encrypted form — confirm.
	Encrypt              bool         `yaml:"encrypt"`
	Status               int          `yaml:"status"`
	// Settings from the "relogin" key; a pointer, so it stays nil when the
	// key is absent from the file.
	ReLogin              *Reconnect   `yaml:"relogin"`
	UseSSOAddress        bool         `yaml:"use-sso-address"`
	AllowTempSession     bool         `yaml:"allow-temp-session"`
	// Entries of the "sign-servers" list.
	SignServers          []SignServer `yaml:"sign-servers"`
	// NOTE(review): the meaning of the individual values of this rule is not
	// visible here — confirm with the sign-server selection logic.
	RuleChangeSignServer int          `yaml:"rule-change-sign-server"`
	MaxCheckCount        uint         `yaml:"max-check-count"`
	// NOTE(review): unit not visible here (possibly seconds) — confirm.
	SignServerTimeout    uint         `yaml:"sign-server-timeout"`
	IsBelow110           bool         `yaml:"is-below-110"`
	AutoRegister         bool         `yaml:"auto-register"`
	AutoRefreshToken     bool         `yaml:"auto-refresh-token"`
	// NOTE(review): unit not visible here — confirm.
	RefreshInterval      int64        `yaml:"refresh-interval"`
}
    47  
// SignServer describes one entry of the account's "sign-servers" list.
type SignServer struct {
	URL           string `yaml:"url"`
	// NOTE(review): Key and Authorization are presumably credentials sent to
	// the sign server — confirm how the client uses them.
	Key           string `yaml:"key"`
	Authorization string `yaml:"authorization"`
}
    54  
// Config is the top-level structure that Parse decodes the configuration file
// into; each field maps to the top-level YAML key named in its tag.
type Config struct {
	Account   *Account `yaml:"account"`
	// Settings from the "heartbeat" section.
	// NOTE(review): the unit of Interval is not visible here — confirm.
	Heartbeat struct {
		Disabled bool `yaml:"disabled"`
		Interval int  `yaml:"interval"`
	} `yaml:"heartbeat"`

	// Options from the "message" section.
	Message struct {
		PostFormat          string `yaml:"post-format"`
		ProxyRewrite        string `yaml:"proxy-rewrite"`
		IgnoreInvalidCQCode bool   `yaml:"ignore-invalid-cqcode"`
		ForceFragment       bool   `yaml:"force-fragment"`
		FixURL              bool   `yaml:"fix-url"`
		ReportSelfMessage   bool   `yaml:"report-self-message"`
		RemoveReplyAt       bool   `yaml:"remove-reply-at"`
		ExtraReplyData      bool   `yaml:"extra-reply-data"`
		SkipMimeScan        bool   `yaml:"skip-mime-scan"`
		ConvertWebpImage    bool   `yaml:"convert-webp-image"`
		HTTPTimeout         int    `yaml:"http-timeout"`
	} `yaml:"message"`

	// Options from the "output" section. LogColorful is a pointer, presumably
	// so that "not set" can be told apart from an explicit false — confirm.
	Output struct {
		LogLevel    string `yaml:"log-level"`
		LogAging    int    `yaml:"log-aging"`
		LogForceNew bool   `yaml:"log-force-new"`
		LogColorful *bool  `yaml:"log-colorful"`
		Debug       bool   `yaml:"debug"`
	} `yaml:"output"`

	// Servers and Database are kept as undecoded yaml.Node values; each entry
	// is presumably decoded later by the module that owns it — confirm.
	Servers  []map[string]yaml.Node `yaml:"servers"`
	Database map[string]yaml.Node   `yaml:"database"`
}
    88  
// Server describes a communication method offered when a new configuration
// file is generated: Brief is the description shown next to its number in the
// interactive menu, and Default is the text (a YAML snippet) appended to the
// generated file when the user selects it. Instances are registered with
// AddServer.
type Server struct {
	Brief   string
	Default string
}
    94  
    95  // Parse 从默认配置文件路径中获取
    96  func Parse(path string) *Config {
    97  	file, err := os.ReadFile(path)
    98  	config := &Config{}
    99  	if err == nil {
   100  		err = yaml.NewDecoder(strings.NewReader(expand(string(file), os.Getenv))).Decode(config)
   101  		if err != nil {
   102  			log.Fatal("配置文件不合法!", err)
   103  		}
   104  	} else {
   105  		generateConfig()
   106  		os.Exit(0)
   107  	}
   108  	return config
   109  }
   110  
   111  var serverconfs []*Server
   112  
   113  // AddServer 添加该服务的简介和默认配置
   114  func AddServer(s *Server) {
   115  	serverconfs = append(serverconfs, s)
   116  }
   117  
   118  // generateConfig 生成配置文件
   119  func generateConfig() {
   120  	fmt.Println("未找到配置文件,正在为您生成配置文件中!")
   121  	sb := strings.Builder{}
   122  	sb.WriteString(defaultConfig)
   123  	hint := "请选择你需要的通信方式:"
   124  	for i, s := range serverconfs {
   125  		hint += fmt.Sprintf("\n> %d: %s", i, s.Brief)
   126  	}
   127  	hint += `
   128  请输入你需要的编号(0-9),可输入多个,同一编号也可输入多个(如: 233)
   129  您的选择是:`
   130  	fmt.Print(hint)
   131  	input := bufio.NewReader(os.Stdin)
   132  	readString, err := input.ReadString('\n')
   133  	if err != nil {
   134  		log.Fatal("输入不合法: ", err)
   135  	}
   136  	rmax := len(serverconfs)
   137  	if rmax > 10 {
   138  		rmax = 10
   139  	}
   140  	for _, r := range readString {
   141  		r -= '0'
   142  		if r >= 0 && r < rune(rmax) {
   143  			sb.WriteString(serverconfs[r].Default)
   144  		}
   145  	}
   146  	_ = os.WriteFile("config.yml", []byte(sb.String()), 0o644)
   147  	fmt.Println("默认配置文件已生成,请修改 config.yml 后重新启动!")
   148  	_, _ = input.ReadString('\n')
   149  }
   150  
   151  // expand 使用正则进行环境变量展开
   152  // os.ExpandEnv 字符 $ 无法逃逸
   153  // https://github.com/golang/go/issues/43482
   154  func expand(s string, mapping func(string) string) string {
   155  	r := regexp.MustCompile(`\${([a-zA-Z_]+[a-zA-Z0-9_:/.]*)}`)
   156  	return r.ReplaceAllStringFunc(s, func(s string) string {
   157  		s = strings.Trim(s, "${}")
   158  		before, after, ok := strings.Cut(s, ":")
   159  		m := mapping(before)
   160  		if ok && m == "" {
   161  			return after
   162  		}
   163  		return m
   164  	})
   165  }