github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/common/tls/ech_server.go (about)

     1  //go:build with_ech
     2  
     3  package tls
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"encoding/pem"
     9  	"net"
    10  	"os"
    11  	"strings"
    12  
    13  	"github.com/inazumav/sing-box/log"
    14  	"github.com/inazumav/sing-box/option"
    15  	cftls "github.com/sagernet/cloudflare-tls"
    16  	E "github.com/sagernet/sing/common/exceptions"
    17  	"github.com/sagernet/sing/common/ntp"
    18  
    19  	"github.com/fsnotify/fsnotify"
    20  )
    21  
    22  type echServerConfig struct {
    23  	config          *cftls.Config
    24  	logger          log.Logger
    25  	certificate     []byte
    26  	key             []byte
    27  	certificatePath string
    28  	keyPath         string
    29  	watcher         *fsnotify.Watcher
    30  	echKeyPath      string
    31  	echWatcher      *fsnotify.Watcher
    32  }
    33  
    34  func (c *echServerConfig) ServerName() string {
    35  	return c.config.ServerName
    36  }
    37  
    38  func (c *echServerConfig) SetServerName(serverName string) {
    39  	c.config.ServerName = serverName
    40  }
    41  
    42  func (c *echServerConfig) NextProtos() []string {
    43  	return c.config.NextProtos
    44  }
    45  
    46  func (c *echServerConfig) SetNextProtos(nextProto []string) {
    47  	c.config.NextProtos = nextProto
    48  }
    49  
    50  func (c *echServerConfig) Config() (*STDConfig, error) {
    51  	return nil, E.New("unsupported usage for ECH")
    52  }
    53  
    54  func (c *echServerConfig) Client(conn net.Conn) (Conn, error) {
    55  	return &echConnWrapper{cftls.Client(conn, c.config)}, nil
    56  }
    57  
    58  func (c *echServerConfig) Server(conn net.Conn) (Conn, error) {
    59  	return &echConnWrapper{cftls.Server(conn, c.config)}, nil
    60  }
    61  
    62  func (c *echServerConfig) Clone() Config {
    63  	return &echServerConfig{
    64  		config: c.config.Clone(),
    65  	}
    66  }
    67  
    68  func (c *echServerConfig) Start() error {
    69  	if c.certificatePath != "" && c.keyPath != "" {
    70  		err := c.startWatcher()
    71  		if err != nil {
    72  			c.logger.Warn("create fsnotify watcher: ", err)
    73  		}
    74  	}
    75  	if c.echKeyPath != "" {
    76  		err := c.startECHWatcher()
    77  		if err != nil {
    78  			c.logger.Warn("create fsnotify watcher: ", err)
    79  		}
    80  	}
    81  	return nil
    82  }
    83  
    84  func (c *echServerConfig) startWatcher() error {
    85  	watcher, err := fsnotify.NewWatcher()
    86  	if err != nil {
    87  		return err
    88  	}
    89  	if c.certificatePath != "" {
    90  		err = watcher.Add(c.certificatePath)
    91  		if err != nil {
    92  			return err
    93  		}
    94  	}
    95  	if c.keyPath != "" {
    96  		err = watcher.Add(c.keyPath)
    97  		if err != nil {
    98  			return err
    99  		}
   100  	}
   101  	c.watcher = watcher
   102  	go c.loopUpdate()
   103  	return nil
   104  }
   105  
   106  func (c *echServerConfig) loopUpdate() {
   107  	for {
   108  		select {
   109  		case event, ok := <-c.watcher.Events:
   110  			if !ok {
   111  				return
   112  			}
   113  			if event.Op&fsnotify.Write != fsnotify.Write {
   114  				continue
   115  			}
   116  			err := c.reloadKeyPair()
   117  			if err != nil {
   118  				c.logger.Error(E.Cause(err, "reload TLS key pair"))
   119  			}
   120  		case err, ok := <-c.watcher.Errors:
   121  			if !ok {
   122  				return
   123  			}
   124  			c.logger.Error(E.Cause(err, "fsnotify error"))
   125  		}
   126  	}
   127  }
   128  
   129  func (c *echServerConfig) reloadKeyPair() error {
   130  	if c.certificatePath != "" {
   131  		certificate, err := os.ReadFile(c.certificatePath)
   132  		if err != nil {
   133  			return E.Cause(err, "reload certificate from ", c.certificatePath)
   134  		}
   135  		c.certificate = certificate
   136  	}
   137  	if c.keyPath != "" {
   138  		key, err := os.ReadFile(c.keyPath)
   139  		if err != nil {
   140  			return E.Cause(err, "reload key from ", c.keyPath)
   141  		}
   142  		c.key = key
   143  	}
   144  	keyPair, err := cftls.X509KeyPair(c.certificate, c.key)
   145  	if err != nil {
   146  		return E.Cause(err, "reload key pair")
   147  	}
   148  	c.config.Certificates = []cftls.Certificate{keyPair}
   149  	c.logger.Info("reloaded TLS certificate")
   150  	return nil
   151  }
   152  
   153  func (c *echServerConfig) startECHWatcher() error {
   154  	watcher, err := fsnotify.NewWatcher()
   155  	if err != nil {
   156  		return err
   157  	}
   158  	err = watcher.Add(c.echKeyPath)
   159  	if err != nil {
   160  		return err
   161  	}
   162  	c.echWatcher = watcher
   163  	go c.loopECHUpdate()
   164  	return nil
   165  }
   166  
   167  func (c *echServerConfig) loopECHUpdate() {
   168  	for {
   169  		select {
   170  		case event, ok := <-c.echWatcher.Events:
   171  			if !ok {
   172  				return
   173  			}
   174  			if event.Op&fsnotify.Write != fsnotify.Write {
   175  				continue
   176  			}
   177  			err := c.reloadECHKey()
   178  			if err != nil {
   179  				c.logger.Error(E.Cause(err, "reload ECH key"))
   180  			}
   181  		case err, ok := <-c.echWatcher.Errors:
   182  			if !ok {
   183  				return
   184  			}
   185  			c.logger.Error(E.Cause(err, "fsnotify error"))
   186  		}
   187  	}
   188  }
   189  
   190  func (c *echServerConfig) reloadECHKey() error {
   191  	echKeyContent, err := os.ReadFile(c.echKeyPath)
   192  	if err != nil {
   193  		return err
   194  	}
   195  	block, rest := pem.Decode(echKeyContent)
   196  	if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
   197  		return E.New("invalid ECH keys pem")
   198  	}
   199  	echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes)
   200  	if err != nil {
   201  		return E.Cause(err, "parse ECH keys")
   202  	}
   203  	echKeySet, err := cftls.EXP_NewECHKeySet(echKeys)
   204  	if err != nil {
   205  		return E.Cause(err, "create ECH key set")
   206  	}
   207  	c.config.ServerECHProvider = echKeySet
   208  	c.logger.Info("reloaded ECH keys")
   209  	return nil
   210  }
   211  
   212  func (c *echServerConfig) Close() error {
   213  	var err error
   214  	if c.watcher != nil {
   215  		err = E.Append(err, c.watcher.Close(), func(err error) error {
   216  			return E.Cause(err, "close certificate watcher")
   217  		})
   218  	}
   219  	if c.echWatcher != nil {
   220  		err = E.Append(err, c.echWatcher.Close(), func(err error) error {
   221  			return E.Cause(err, "close ECH key watcher")
   222  		})
   223  	}
   224  	return err
   225  }
   226  
   227  func NewECHServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
   228  	if !options.Enabled {
   229  		return nil, nil
   230  	}
   231  	var tlsConfig cftls.Config
   232  	if options.ACME != nil && len(options.ACME.Domain) > 0 {
   233  		return nil, E.New("acme is unavailable in ech")
   234  	}
   235  	tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
   236  	if options.ServerName != "" {
   237  		tlsConfig.ServerName = options.ServerName
   238  	}
   239  	if len(options.ALPN) > 0 {
   240  		tlsConfig.NextProtos = append(options.ALPN, tlsConfig.NextProtos...)
   241  	}
   242  	if options.MinVersion != "" {
   243  		minVersion, err := ParseTLSVersion(options.MinVersion)
   244  		if err != nil {
   245  			return nil, E.Cause(err, "parse min_version")
   246  		}
   247  		tlsConfig.MinVersion = minVersion
   248  	}
   249  	if options.MaxVersion != "" {
   250  		maxVersion, err := ParseTLSVersion(options.MaxVersion)
   251  		if err != nil {
   252  			return nil, E.Cause(err, "parse max_version")
   253  		}
   254  		tlsConfig.MaxVersion = maxVersion
   255  	}
   256  	if options.CipherSuites != nil {
   257  	find:
   258  		for _, cipherSuite := range options.CipherSuites {
   259  			for _, tlsCipherSuite := range tls.CipherSuites() {
   260  				if cipherSuite == tlsCipherSuite.Name {
   261  					tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
   262  					continue find
   263  				}
   264  			}
   265  			return nil, E.New("unknown cipher_suite: ", cipherSuite)
   266  		}
   267  	}
   268  	var certificate []byte
   269  	var key []byte
   270  	if len(options.Certificate) > 0 {
   271  		certificate = []byte(strings.Join(options.Certificate, "\n"))
   272  	} else if options.CertificatePath != "" {
   273  		content, err := os.ReadFile(options.CertificatePath)
   274  		if err != nil {
   275  			return nil, E.Cause(err, "read certificate")
   276  		}
   277  		certificate = content
   278  	}
   279  	if len(options.Key) > 0 {
   280  		key = []byte(strings.Join(options.Key, "\n"))
   281  	} else if options.KeyPath != "" {
   282  		content, err := os.ReadFile(options.KeyPath)
   283  		if err != nil {
   284  			return nil, E.Cause(err, "read key")
   285  		}
   286  		key = content
   287  	}
   288  
   289  	if certificate == nil {
   290  		return nil, E.New("missing certificate")
   291  	} else if key == nil {
   292  		return nil, E.New("missing key")
   293  	}
   294  
   295  	keyPair, err := cftls.X509KeyPair(certificate, key)
   296  	if err != nil {
   297  		return nil, E.Cause(err, "parse x509 key pair")
   298  	}
   299  	tlsConfig.Certificates = []cftls.Certificate{keyPair}
   300  
   301  	var echKey []byte
   302  	if len(options.ECH.Key) > 0 {
   303  		echKey = []byte(strings.Join(options.ECH.Key, "\n"))
   304  	} else if options.KeyPath != "" {
   305  		content, err := os.ReadFile(options.ECH.KeyPath)
   306  		if err != nil {
   307  			return nil, E.Cause(err, "read ECH key")
   308  		}
   309  		echKey = content
   310  	} else {
   311  		return nil, E.New("missing ECH key")
   312  	}
   313  
   314  	block, rest := pem.Decode(echKey)
   315  	if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
   316  		return nil, E.New("invalid ECH keys pem")
   317  	}
   318  
   319  	echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes)
   320  	if err != nil {
   321  		return nil, E.Cause(err, "parse ECH keys")
   322  	}
   323  
   324  	echKeySet, err := cftls.EXP_NewECHKeySet(echKeys)
   325  	if err != nil {
   326  		return nil, E.Cause(err, "create ECH key set")
   327  	}
   328  
   329  	tlsConfig.ECHEnabled = true
   330  	tlsConfig.PQSignatureSchemesEnabled = options.ECH.PQSignatureSchemesEnabled
   331  	tlsConfig.DynamicRecordSizingDisabled = options.ECH.DynamicRecordSizingDisabled
   332  	tlsConfig.ServerECHProvider = echKeySet
   333  
   334  	return &echServerConfig{
   335  		config:          &tlsConfig,
   336  		logger:          logger,
   337  		certificate:     certificate,
   338  		key:             key,
   339  		certificatePath: options.CertificatePath,
   340  		keyPath:         options.KeyPath,
   341  		echKeyPath:      options.ECH.KeyPath,
   342  	}, nil
   343  }