github.com/sagernet/sing-box@v1.2.7/common/tls/std_server.go (about)

     1  package tls
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"net"
     7  	"os"
     8  
     9  	"github.com/sagernet/sing-box/adapter"
    10  	"github.com/sagernet/sing-box/log"
    11  	"github.com/sagernet/sing-box/option"
    12  	"github.com/sagernet/sing/common"
    13  	E "github.com/sagernet/sing/common/exceptions"
    14  
    15  	"github.com/fsnotify/fsnotify"
    16  )
    17  
// errInsecureUnused is returned by NewSTDServer when the insecure option is
// set together with an ACME configuration.
var errInsecureUnused = E.New("tls: insecure unused")
    19  
// STDServerConfig is a server-side TLS configuration backed by crypto/tls.
//
// Certificates are supplied either by an ACME service or by a certificate/key
// pair. When file paths are configured for the pair, Start watches those
// files and the pair is reloaded after a write to either of them.
type STDServerConfig struct {
	config          *tls.Config       // underlying crypto/tls configuration handed to tls.Server / tls.Client
	logger          log.Logger        // receives watcher warnings and reload results
	acmeService     adapter.Service   // non-nil when NewSTDServer set up ACME; started and closed with this config
	certificate     []byte            // certificate bytes most recently loaded (inline option or file content)
	key             []byte            // private key bytes most recently loaded (inline option or file content)
	certificatePath string            // certificate file to watch and re-read on reload; empty if not configured
	keyPath         string            // key file to watch and re-read on reload; empty if not configured
	watcher         *fsnotify.Watcher // assigned by startWatcher; nil until a watcher was started successfully
}
    30  
    31  func (c *STDServerConfig) ServerName() string {
    32  	return c.config.ServerName
    33  }
    34  
    35  func (c *STDServerConfig) SetServerName(serverName string) {
    36  	c.config.ServerName = serverName
    37  }
    38  
    39  func (c *STDServerConfig) NextProtos() []string {
    40  	return c.config.NextProtos
    41  }
    42  
    43  func (c *STDServerConfig) SetNextProtos(nextProto []string) {
    44  	c.config.NextProtos = nextProto
    45  }
    46  
    47  func (c *STDServerConfig) Config() (*STDConfig, error) {
    48  	return c.config, nil
    49  }
    50  
    51  func (c *STDServerConfig) Client(conn net.Conn) (Conn, error) {
    52  	return tls.Client(conn, c.config), nil
    53  }
    54  
    55  func (c *STDServerConfig) Server(conn net.Conn) (Conn, error) {
    56  	return tls.Server(conn, c.config), nil
    57  }
    58  
    59  func (c *STDServerConfig) Clone() Config {
    60  	return &STDServerConfig{
    61  		config: c.config.Clone(),
    62  	}
    63  }
    64  
    65  func (c *STDServerConfig) Start() error {
    66  	if c.acmeService != nil {
    67  		return c.acmeService.Start()
    68  	} else {
    69  		if c.certificatePath == "" && c.keyPath == "" {
    70  			return nil
    71  		}
    72  		err := c.startWatcher()
    73  		if err != nil {
    74  			c.logger.Warn("create fsnotify watcher: ", err)
    75  		}
    76  		return nil
    77  	}
    78  }
    79  
    80  func (c *STDServerConfig) startWatcher() error {
    81  	watcher, err := fsnotify.NewWatcher()
    82  	if err != nil {
    83  		return err
    84  	}
    85  	if c.certificatePath != "" {
    86  		err = watcher.Add(c.certificatePath)
    87  		if err != nil {
    88  			return err
    89  		}
    90  	}
    91  	if c.keyPath != "" {
    92  		err = watcher.Add(c.keyPath)
    93  		if err != nil {
    94  			return err
    95  		}
    96  	}
    97  	c.watcher = watcher
    98  	go c.loopUpdate()
    99  	return nil
   100  }
   101  
   102  func (c *STDServerConfig) loopUpdate() {
   103  	for {
   104  		select {
   105  		case event, ok := <-c.watcher.Events:
   106  			if !ok {
   107  				return
   108  			}
   109  			if event.Op&fsnotify.Write != fsnotify.Write {
   110  				continue
   111  			}
   112  			err := c.reloadKeyPair()
   113  			if err != nil {
   114  				c.logger.Error(E.Cause(err, "reload TLS key pair"))
   115  			}
   116  		case err, ok := <-c.watcher.Errors:
   117  			if !ok {
   118  				return
   119  			}
   120  			c.logger.Error(E.Cause(err, "fsnotify error"))
   121  		}
   122  	}
   123  }
   124  
   125  func (c *STDServerConfig) reloadKeyPair() error {
   126  	if c.certificatePath != "" {
   127  		certificate, err := os.ReadFile(c.certificatePath)
   128  		if err != nil {
   129  			return E.Cause(err, "reload certificate from ", c.certificatePath)
   130  		}
   131  		c.certificate = certificate
   132  	}
   133  	if c.keyPath != "" {
   134  		key, err := os.ReadFile(c.keyPath)
   135  		if err != nil {
   136  			return E.Cause(err, "reload key from ", c.keyPath)
   137  		}
   138  		c.key = key
   139  	}
   140  	keyPair, err := tls.X509KeyPair(c.certificate, c.key)
   141  	if err != nil {
   142  		return E.Cause(err, "reload key pair")
   143  	}
   144  	c.config.Certificates = []tls.Certificate{keyPair}
   145  	c.logger.Info("reloaded TLS certificate")
   146  	return nil
   147  }
   148  
   149  func (c *STDServerConfig) Close() error {
   150  	if c.acmeService != nil {
   151  		return c.acmeService.Close()
   152  	}
   153  	if c.watcher != nil {
   154  		return c.watcher.Close()
   155  	}
   156  	return nil
   157  }
   158  
   159  func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
   160  	if !options.Enabled {
   161  		return nil, nil
   162  	}
   163  	var tlsConfig *tls.Config
   164  	var acmeService adapter.Service
   165  	var err error
   166  	if options.ACME != nil && len(options.ACME.Domain) > 0 {
   167  		tlsConfig, acmeService, err = startACME(ctx, common.PtrValueOrDefault(options.ACME))
   168  		//nolint:staticcheck
   169  		if err != nil {
   170  			return nil, err
   171  		}
   172  		if options.Insecure {
   173  			return nil, errInsecureUnused
   174  		}
   175  	} else {
   176  		tlsConfig = &tls.Config{}
   177  	}
   178  	tlsConfig.Time = router.TimeFunc()
   179  	if options.ServerName != "" {
   180  		tlsConfig.ServerName = options.ServerName
   181  	}
   182  	if len(options.ALPN) > 0 {
   183  		tlsConfig.NextProtos = append(options.ALPN, tlsConfig.NextProtos...)
   184  	}
   185  	if options.MinVersion != "" {
   186  		minVersion, err := ParseTLSVersion(options.MinVersion)
   187  		if err != nil {
   188  			return nil, E.Cause(err, "parse min_version")
   189  		}
   190  		tlsConfig.MinVersion = minVersion
   191  	}
   192  	if options.MaxVersion != "" {
   193  		maxVersion, err := ParseTLSVersion(options.MaxVersion)
   194  		if err != nil {
   195  			return nil, E.Cause(err, "parse max_version")
   196  		}
   197  		tlsConfig.MaxVersion = maxVersion
   198  	}
   199  	if options.CipherSuites != nil {
   200  	find:
   201  		for _, cipherSuite := range options.CipherSuites {
   202  			for _, tlsCipherSuite := range tls.CipherSuites() {
   203  				if cipherSuite == tlsCipherSuite.Name {
   204  					tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
   205  					continue find
   206  				}
   207  			}
   208  			return nil, E.New("unknown cipher_suite: ", cipherSuite)
   209  		}
   210  	}
   211  	var certificate []byte
   212  	var key []byte
   213  	if acmeService == nil {
   214  		if options.Certificate != "" {
   215  			certificate = []byte(options.Certificate)
   216  		} else if options.CertificatePath != "" {
   217  			content, err := os.ReadFile(options.CertificatePath)
   218  			if err != nil {
   219  				return nil, E.Cause(err, "read certificate")
   220  			}
   221  			certificate = content
   222  		}
   223  		if options.Key != "" {
   224  			key = []byte(options.Key)
   225  		} else if options.KeyPath != "" {
   226  			content, err := os.ReadFile(options.KeyPath)
   227  			if err != nil {
   228  				return nil, E.Cause(err, "read key")
   229  			}
   230  			key = content
   231  		}
   232  		if certificate == nil && key == nil && options.Insecure {
   233  			tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
   234  				return GenerateKeyPair(router.TimeFunc(), info.ServerName)
   235  			}
   236  		} else {
   237  			if certificate == nil {
   238  				return nil, E.New("missing certificate")
   239  			} else if key == nil {
   240  				return nil, E.New("missing key")
   241  			}
   242  
   243  			keyPair, err := tls.X509KeyPair(certificate, key)
   244  			if err != nil {
   245  				return nil, E.Cause(err, "parse x509 key pair")
   246  			}
   247  			tlsConfig.Certificates = []tls.Certificate{keyPair}
   248  		}
   249  	}
   250  	return &STDServerConfig{
   251  		config:          tlsConfig,
   252  		logger:          logger,
   253  		acmeService:     acmeService,
   254  		certificate:     certificate,
   255  		key:             key,
   256  		certificatePath: options.CertificatePath,
   257  		keyPath:         options.KeyPath,
   258  	}, nil
   259  }