github.com/sagernet/sing-box@v1.9.0-rc.20/common/tls/std_server.go (about) 1 package tls 2 3 import ( 4 "context" 5 "crypto/tls" 6 "net" 7 "os" 8 "strings" 9 10 "github.com/sagernet/sing-box/adapter" 11 "github.com/sagernet/sing-box/log" 12 "github.com/sagernet/sing-box/option" 13 "github.com/sagernet/sing/common" 14 E "github.com/sagernet/sing/common/exceptions" 15 "github.com/sagernet/sing/common/ntp" 16 17 "github.com/fsnotify/fsnotify" 18 ) 19 20 var errInsecureUnused = E.New("tls: insecure unused") 21 22 type STDServerConfig struct { 23 config *tls.Config 24 logger log.Logger 25 acmeService adapter.Service 26 certificate []byte 27 key []byte 28 certificatePath string 29 keyPath string 30 watcher *fsnotify.Watcher 31 } 32 33 func (c *STDServerConfig) ServerName() string { 34 return c.config.ServerName 35 } 36 37 func (c *STDServerConfig) SetServerName(serverName string) { 38 c.config.ServerName = serverName 39 } 40 41 func (c *STDServerConfig) NextProtos() []string { 42 if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol { 43 return c.config.NextProtos[1:] 44 } else { 45 return c.config.NextProtos 46 } 47 } 48 49 func (c *STDServerConfig) SetNextProtos(nextProto []string) { 50 if c.acmeService != nil && len(c.config.NextProtos) > 1 && c.config.NextProtos[0] == ACMETLS1Protocol { 51 c.config.NextProtos = append(c.config.NextProtos[:1], nextProto...) 52 } else { 53 c.config.NextProtos = nextProto 54 } 55 } 56 57 func (c *STDServerConfig) Config() (*STDConfig, error) { 58 return c.config, nil 59 } 60 61 func (c *STDServerConfig) Client(conn net.Conn) (Conn, error) { 62 return tls.Client(conn, c.config), nil 63 } 64 65 func (c *STDServerConfig) Server(conn net.Conn) (Conn, error) { 66 return tls.Server(conn, c.config), nil 67 } 68 69 func (c *STDServerConfig) Clone() Config { 70 return &STDServerConfig{ 71 config: c.config.Clone(), 72 } 73 } 74 75 func (c *STDServerConfig) Start() error { 76 if c.acmeService != nil { 77 return c.acmeService.Start() 78 } else { 79 if c.certificatePath == "" && c.keyPath == "" { 80 return nil 81 } 82 err := c.startWatcher() 83 if err != nil { 84 c.logger.Warn("create fsnotify watcher: ", err) 85 } 86 return nil 87 } 88 } 89 90 func (c *STDServerConfig) startWatcher() error { 91 watcher, err := fsnotify.NewWatcher() 92 if err != nil { 93 return err 94 } 95 if c.certificatePath != "" { 96 err = watcher.Add(c.certificatePath) 97 if err != nil { 98 return err 99 } 100 } 101 if c.keyPath != "" { 102 err = watcher.Add(c.keyPath) 103 if err != nil { 104 return err 105 } 106 } 107 c.watcher = watcher 108 go c.loopUpdate() 109 return nil 110 } 111 112 func (c *STDServerConfig) loopUpdate() { 113 for { 114 select { 115 case event, ok := <-c.watcher.Events: 116 if !ok { 117 return 118 } 119 if event.Op&fsnotify.Write != fsnotify.Write { 120 continue 121 } 122 err := c.reloadKeyPair() 123 if err != nil { 124 c.logger.Error(E.Cause(err, "reload TLS key pair")) 125 } 126 case err, ok := <-c.watcher.Errors: 127 if !ok { 128 return 129 } 130 c.logger.Error(E.Cause(err, "fsnotify error")) 131 } 132 } 133 } 134 135 func (c *STDServerConfig) reloadKeyPair() error { 136 if c.certificatePath != "" { 137 certificate, err := os.ReadFile(c.certificatePath) 138 if err != nil { 139 return E.Cause(err, "reload certificate from ", c.certificatePath) 140 } 141 c.certificate = certificate 142 } 143 if c.keyPath != "" { 144 key, err := os.ReadFile(c.keyPath) 145 if err != nil { 146 return E.Cause(err, "reload key from ", c.keyPath) 147 } 148 c.key = key 149 } 150 keyPair, err := tls.X509KeyPair(c.certificate, c.key) 151 if err != nil { 152 return E.Cause(err, "reload key pair") 153 } 154 c.config.Certificates = []tls.Certificate{keyPair} 155 c.logger.Info("reloaded TLS certificate") 156 return nil 157 } 158 159 func (c *STDServerConfig) Close() error { 160 if c.acmeService != nil { 161 return c.acmeService.Close() 162 } 163 if c.watcher != nil { 164 return c.watcher.Close() 165 } 166 return nil 167 } 168 169 func NewSTDServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) { 170 if !options.Enabled { 171 return nil, nil 172 } 173 var tlsConfig *tls.Config 174 var acmeService adapter.Service 175 var err error 176 if options.ACME != nil && len(options.ACME.Domain) > 0 { 177 //nolint:staticcheck 178 tlsConfig, acmeService, err = startACME(ctx, common.PtrValueOrDefault(options.ACME)) 179 if err != nil { 180 return nil, err 181 } 182 if options.Insecure { 183 return nil, errInsecureUnused 184 } 185 } else { 186 tlsConfig = &tls.Config{} 187 } 188 tlsConfig.Time = ntp.TimeFuncFromContext(ctx) 189 if options.ServerName != "" { 190 tlsConfig.ServerName = options.ServerName 191 } 192 if len(options.ALPN) > 0 { 193 tlsConfig.NextProtos = append(options.ALPN, tlsConfig.NextProtos...) 194 } 195 if options.MinVersion != "" { 196 minVersion, err := ParseTLSVersion(options.MinVersion) 197 if err != nil { 198 return nil, E.Cause(err, "parse min_version") 199 } 200 tlsConfig.MinVersion = minVersion 201 } 202 if options.MaxVersion != "" { 203 maxVersion, err := ParseTLSVersion(options.MaxVersion) 204 if err != nil { 205 return nil, E.Cause(err, "parse max_version") 206 } 207 tlsConfig.MaxVersion = maxVersion 208 } 209 if options.CipherSuites != nil { 210 find: 211 for _, cipherSuite := range options.CipherSuites { 212 for _, tlsCipherSuite := range tls.CipherSuites() { 213 if cipherSuite == tlsCipherSuite.Name { 214 tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID) 215 continue find 216 } 217 } 218 return nil, E.New("unknown cipher_suite: ", cipherSuite) 219 } 220 } 221 var certificate []byte 222 var key []byte 223 if acmeService == nil { 224 if len(options.Certificate) > 0 { 225 certificate = []byte(strings.Join(options.Certificate, "\n")) 226 } else if options.CertificatePath != "" { 227 content, err := os.ReadFile(options.CertificatePath) 228 if err != nil { 229 return nil, E.Cause(err, "read certificate") 230 } 231 certificate = content 232 } 233 if len(options.Key) > 0 { 234 key = []byte(strings.Join(options.Key, "\n")) 235 } else if options.KeyPath != "" { 236 content, err := os.ReadFile(options.KeyPath) 237 if err != nil { 238 return nil, E.Cause(err, "read key") 239 } 240 key = content 241 } 242 if certificate == nil && key == nil && options.Insecure { 243 tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { 244 return GenerateCertificate(ntp.TimeFuncFromContext(ctx), info.ServerName) 245 } 246 } else { 247 if certificate == nil { 248 return nil, E.New("missing certificate") 249 } else if key == nil { 250 return nil, E.New("missing key") 251 } 252 253 keyPair, err := tls.X509KeyPair(certificate, key) 254 if err != nil { 255 return nil, E.Cause(err, "parse x509 key pair") 256 } 257 tlsConfig.Certificates = []tls.Certificate{keyPair} 258 } 259 } 260 return &STDServerConfig{ 261 config: tlsConfig, 262 logger: logger, 263 acmeService: acmeService, 264 certificate: certificate, 265 key: key, 266 certificatePath: options.CertificatePath, 267 keyPath: options.KeyPath, 268 }, nil 269 }