github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/conf/conf.go (about)

     1  package conf
     2  
     3  import (
     4  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
     5  	"github.com/go-ini/ini"
     6  	"github.com/go-playground/validator/v10"
     7  )
     8  
// database holds the [Database] section of the config file; Init maps it
// onto DatabaseConfig.
type database struct {
	// Type selects the database backend. Accepted values are not visible in
	// this file — TODO confirm against the database setup code.
	Type        string
	User        string
	Password    string
	Host        string
	Name        string
	TablePrefix string
	// DBFile is presumably the database file path for file-based backends
	// (e.g. SQLite) — verify.
	DBFile      string
	Port        int
	Charset     string
	// UnixSocket presumably makes the client connect through a Unix socket
	// instead of TCP — verify against the database setup code.
	UnixSocket  bool
}
    22  
// system holds the [System] section of the config file (general settings);
// Init maps it onto SystemConfig and validates it with the tags below.
type system struct {
	// Mode must be exactly "master" or "slave".
	Mode          string `validate:"eq=master|eq=slave"`
	// Listen is required; the generated default config sets it to ":5212".
	Listen        string `validate:"required"`
	// Debug, when false, makes Init lower the log level to informational.
	Debug         bool
	// SessionSecret and HashIDSalt are filled with random values when Init
	// generates a fresh config file.
	SessionSecret string
	HashIDSalt    string
	// GracePeriod must be >= 0. Its unit is not visible here — TODO confirm.
	GracePeriod   int    `validate:"gte=0"`
	// ProxyHeader is required whenever Listen is set. Listen is itself
	// required, so in practice this must always be non-empty. Presumably it
	// names the header carrying the client address behind a reverse proxy — verify.
	ProxyHeader   string `validate:"required_with=Listen"`
}
    33  
// ssl holds the [SSL] section of the config file; Init maps it onto SSLConfig.
type ssl struct {
	// NOTE(review): "omitempty,required" never rejects anything. omitempty
	// skips validation for an empty value, and a non-empty value always
	// satisfies required. Confirm the intended rule for CertPath/KeyPath.
	CertPath string `validate:"omitempty,required"`
	KeyPath  string `validate:"omitempty,required"`
	// Listen must be non-empty after mapping.
	Listen   string `validate:"required"`
}
    39  
// unix holds the [UnixSocket] section of the config file; Init maps it onto
// UnixConfig.
type unix struct {
	// Listen is presumably the Unix socket path to listen on — verify.
	Listen string
	// Perm is presumably the file mode applied to the socket file — verify.
	Perm   uint32
}
    44  
// slave holds the [Slave] section of the config file (settings used when
// running as a slave storage node); Init maps it onto SlaveConfig.
type slave struct {
	// Secret may be empty; when set, it must be at least 64 characters long.
	Secret          string `validate:"omitempty,gte=64"`
	// CallbackTimeout and SignatureTTL may be zero (unset); when set, they
	// must be >= 1. Their units are not visible here — TODO confirm.
	CallbackTimeout int    `validate:"omitempty,gte=1"`
	SignatureTTL    int    `validate:"omitempty,gte=1"`
}
    51  
    52  // redis 配置
    53  type redis struct {
    54  	Network  string
    55  	Server   string
    56  	User	 string
    57  	Password string
    58  	DB       string
    59  }
    60  
// cors holds the [CORS] section of the config file (cross-origin resource
// sharing settings); Init maps it onto CORSConfig.
type cors struct {
	AllowOrigins     []string
	AllowMethods     []string
	AllowHeaders     []string
	AllowCredentials bool
	ExposeHeaders    []string
	// SameSite and Secure are presumably cookie attributes applied alongside
	// the CORS policy — verify where they are consumed.
	SameSite         string
	Secure           bool
}
    71  
// cfg is the parsed config file. Init assigns it and mapSection reads sections from it.
var cfg *ini.File
    73  
// defaultConf is the template Init writes when the config file does not
// exist. Before writing, the {SessionSecret} and {HashIDSalt} placeholders
// are replaced with util.RandStringRunes(64) values.
const defaultConf = `[System]
Debug = false
Mode = master
Listen = :5212
SessionSecret = {SessionSecret}
HashIDSalt = {HashIDSalt}
`
    81  
    82  // Init 初始化配置文件
    83  func Init(path string) {
    84  	var err error
    85  
    86  	if path == "" || !util.Exists(path) {
    87  		// 创建初始配置文件
    88  		confContent := util.Replace(map[string]string{
    89  			"{SessionSecret}": util.RandStringRunes(64),
    90  			"{HashIDSalt}":    util.RandStringRunes(64),
    91  		}, defaultConf)
    92  		f, err := util.CreatNestedFile(path)
    93  		if err != nil {
    94  			util.Log().Panic("Failed to create config file: %s", err)
    95  		}
    96  
    97  		// 写入配置文件
    98  		_, err = f.WriteString(confContent)
    99  		if err != nil {
   100  			util.Log().Panic("Failed to write config file: %s", err)
   101  		}
   102  
   103  		f.Close()
   104  	}
   105  
   106  	cfg, err = ini.Load(path)
   107  	if err != nil {
   108  		util.Log().Panic("Failed to parse config file %q: %s", path, err)
   109  	}
   110  
   111  	sections := map[string]interface{}{
   112  		"Database":   DatabaseConfig,
   113  		"System":     SystemConfig,
   114  		"SSL":        SSLConfig,
   115  		"UnixSocket": UnixConfig,
   116  		"Redis":      RedisConfig,
   117  		"CORS":       CORSConfig,
   118  		"Slave":      SlaveConfig,
   119  	}
   120  	for sectionName, sectionStruct := range sections {
   121  		err = mapSection(sectionName, sectionStruct)
   122  		if err != nil {
   123  			util.Log().Panic("Failed to parse config section %q: %s", sectionName, err)
   124  		}
   125  	}
   126  
   127  	// 映射数据库配置覆盖
   128  	for _, key := range cfg.Section("OptionOverwrite").Keys() {
   129  		OptionOverwrite[key.Name()] = key.Value()
   130  	}
   131  
   132  	// 重设log等级
   133  	if !SystemConfig.Debug {
   134  		util.Level = util.LevelInformational
   135  		util.GloablLogger = nil
   136  		util.Log()
   137  	}
   138  
   139  }
   140  
   141  // mapSection 将配置文件的 Section 映射到结构体上
   142  func mapSection(section string, confStruct interface{}) error {
   143  	err := cfg.Section(section).MapTo(confStruct)
   144  	if err != nil {
   145  		return err
   146  	}
   147  
   148  	// 验证合法性
   149  	validate := validator.New()
   150  	err = validate.Struct(confStruct)
   151  	if err != nil {
   152  		return err
   153  	}
   154  
   155  	return nil
   156  }