github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/common/tls/std_server.go (about) 1 package tls 2 3 import ( 4 "context" 5 "crypto/tls" 6 "net" 7 "os" 8 "strings" 9 10 "github.com/inazumav/sing-box/adapter" 11 "github.com/inazumav/sing-box/log" 12 "github.com/inazumav/sing-box/option" 13 "github.com/sagernet/sing/common" 14 E "github.com/sagernet/sing/common/exceptions" 15 "github.com/sagernet/sing/common/ntp" 16 17 "github.com/fsnotify/fsnotify" 18 ) 19 20 var errInsecureUnused = E.New("tls: insecure unused") 21 22 type STDServerConfig struct { 23 config *tls.Config 24 logger log.Logger 25 acmeService adapter.Service 26 certificate []byte 27 key []byte 28 certificatePath string 29 keyPath string 30 watcher *fsnotify.Watcher 31 } 32 33 func (c *STDServerConfig) ServerName() string { 34 return c.config.ServerName 35 } 36 37 func (c *STDServerConfig) SetServerName(serverName string) { 38 c.config.ServerName = serverName 39 } 40 41 func (c *STDServerConfig) NextProtos() []string { 42 return c.config.NextProtos 43 } 44 45 func (c *STDServerConfig) SetNextProtos(nextProto []string) { 46 c.config.NextProtos = nextProto 47 } 48 49 func (c *STDServerConfig) Config() (*STDConfig, error) { 50 return c.config, nil 51 } 52 53 func (c *STDServerConfig) Client(conn net.Conn) (Conn, error) { 54 return tls.Client(conn, c.config), nil 55 } 56 57 func (c *STDServerConfig) Server(conn net.Conn) (Conn, error) { 58 return tls.Server(conn, c.config), nil 59 } 60 61 func (c *STDServerConfig) Clone() Config { 62 return &STDServerConfig{ 63 config: c.config.Clone(), 64 } 65 } 66 67 func (c *STDServerConfig) Start() error { 68 if c.acmeService != nil { 69 return c.acmeService.Start() 70 } else { 71 if c.certificatePath == "" && c.keyPath == "" { 72 return nil 73 } 74 err := c.startWatcher() 75 if err != nil { 76 c.logger.Warn("create fsnotify watcher: ", err) 77 } 78 return nil 79 } 80 } 81 82 func (c *STDServerConfig) startWatcher() error { 83 watcher, err := fsnotify.NewWatcher() 84 if err != nil { 85 return err 86 } 87 if c.certificatePath != "" { 88 err = watcher.Add(c.certificatePath) 89 if err != nil { 90 return err 91 } 92 } 93 if c.keyPath != "" { 94 err = watcher.Add(c.keyPath) 95 if err != nil { 96 return err 97 } 98 } 99 c.watcher = watcher 100 go c.loopUpdate() 101 return nil 102 } 103 104 func (c *STDServerConfig) loopUpdate() { 105 for { 106 select { 107 case event, ok := <-c.watcher.Events: 108 if !ok { 109 return 110 } 111 if event.Op&fsnotify.Write != fsnotify.Write { 112 continue 113 } 114 err := c.reloadKeyPair() 115 if err != nil { 116 c.logger.Error(E.Cause(err, "reload TLS key pair")) 117 } 118 case err, ok := <-c.watcher.Errors: 119 if !ok { 120 return 121 } 122 c.logger.Error(E.Cause(err, "fsnotify error")) 123 } 124 } 125 } 126 127 func (c *STDServerConfig) reloadKeyPair() error { 128 if c.certificatePath != "" { 129 certificate, err := os.ReadFile(c.certificatePath) 130 if err != nil { 131 return E.Cause(err, "reload certificate from ", c.certificatePath) 132 } 133 c.certificate = certificate 134 } 135 if c.keyPath != "" { 136 key, err := os.ReadFile(c.keyPath) 137 if err != nil { 138 return E.Cause(err, "reload key from ", c.keyPath) 139 } 140 c.key = key 141 } 142 keyPair, err := tls.X509KeyPair(c.certificate, c.key) 143 if err != nil { 144 return E.Cause(err, "reload key pair") 145 } 146 c.config.Certificates = []tls.Certificate{keyPair} 147 c.logger.Info("reloaded TLS certificate") 148 return nil 149 } 150 151 func (c *STDServerConfig) Close() error { 152 if c.acmeService != nil { 153 return c.acmeService.Close() 154 } 155 if c.watcher != nil { 156 return c.watcher.Close() 157 } 158 return nil 159 } 160 161 func NewSTDServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) { 162 if !options.Enabled { 163 return nil, nil 164 } 165 var tlsConfig *tls.Config 166 var acmeService adapter.Service 167 var err error 168 if options.ACME != nil && len(options.ACME.Domain) > 0 { 169 //nolint:staticcheck 170 tlsConfig, acmeService, err = startACME(ctx, common.PtrValueOrDefault(options.ACME)) 171 if err != nil { 172 return nil, err 173 } 174 if options.Insecure { 175 return nil, errInsecureUnused 176 } 177 } else { 178 tlsConfig = &tls.Config{} 179 } 180 tlsConfig.Time = ntp.TimeFuncFromContext(ctx) 181 if options.ServerName != "" { 182 tlsConfig.ServerName = options.ServerName 183 } 184 if len(options.ALPN) > 0 { 185 tlsConfig.NextProtos = append(options.ALPN, tlsConfig.NextProtos...) 186 } 187 if options.MinVersion != "" { 188 minVersion, err := ParseTLSVersion(options.MinVersion) 189 if err != nil { 190 return nil, E.Cause(err, "parse min_version") 191 } 192 tlsConfig.MinVersion = minVersion 193 } 194 if options.MaxVersion != "" { 195 maxVersion, err := ParseTLSVersion(options.MaxVersion) 196 if err != nil { 197 return nil, E.Cause(err, "parse max_version") 198 } 199 tlsConfig.MaxVersion = maxVersion 200 } 201 if options.CipherSuites != nil { 202 find: 203 for _, cipherSuite := range options.CipherSuites { 204 for _, tlsCipherSuite := range tls.CipherSuites() { 205 if cipherSuite == tlsCipherSuite.Name { 206 tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID) 207 continue find 208 } 209 } 210 return nil, E.New("unknown cipher_suite: ", cipherSuite) 211 } 212 } 213 var certificate []byte 214 var key []byte 215 if acmeService == nil { 216 if len(options.Certificate) > 0 { 217 certificate = []byte(strings.Join(options.Certificate, "\n")) 218 } else if options.CertificatePath != "" { 219 content, err := os.ReadFile(options.CertificatePath) 220 if err != nil { 221 return nil, E.Cause(err, "read certificate") 222 } 223 certificate = content 224 } 225 if len(options.Key) > 0 { 226 key = []byte(strings.Join(options.Key, "\n")) 227 } else if options.KeyPath != "" { 228 content, err := os.ReadFile(options.KeyPath) 229 if err != nil { 230 return nil, E.Cause(err, "read key") 231 } 232 key = content 233 } 234 if certificate == nil && key == nil && options.Insecure { 235 tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { 236 return GenerateKeyPair(ntp.TimeFuncFromContext(ctx), info.ServerName) 237 } 238 } else { 239 if certificate == nil { 240 return nil, E.New("missing certificate") 241 } else if key == nil { 242 return nil, E.New("missing key") 243 } 244 245 keyPair, err := tls.X509KeyPair(certificate, key) 246 if err != nil { 247 return nil, E.Cause(err, "parse x509 key pair") 248 } 249 tlsConfig.Certificates = []tls.Certificate{keyPair} 250 } 251 } 252 return &STDServerConfig{ 253 config: tlsConfig, 254 logger: logger, 255 acmeService: acmeService, 256 certificate: certificate, 257 key: key, 258 certificatePath: options.CertificatePath, 259 keyPath: options.KeyPath, 260 }, nil 261 }