github.com/sagernet/sing-box@v1.9.0-rc.20/common/tls/reality_client.go (about)

     1  //go:build with_utls
     2  
     3  package tls
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"crypto/aes"
     9  	"crypto/cipher"
    10  	"crypto/ecdh"
    11  	"crypto/ed25519"
    12  	"crypto/hmac"
    13  	"crypto/sha256"
    14  	"crypto/sha512"
    15  	"crypto/tls"
    16  	"crypto/x509"
    17  	"encoding/base64"
    18  	"encoding/binary"
    19  	"encoding/hex"
    20  	"fmt"
    21  	"io"
    22  	mRand "math/rand"
    23  	"net"
    24  	"net/http"
    25  	"reflect"
    26  	"strings"
    27  	"time"
    28  	"unsafe"
    29  
    30  	"github.com/sagernet/sing-box/option"
    31  	"github.com/sagernet/sing/common/debug"
    32  	E "github.com/sagernet/sing/common/exceptions"
    33  	aTLS "github.com/sagernet/sing/common/tls"
    34  	utls "github.com/sagernet/utls"
    35  
    36  	"golang.org/x/crypto/hkdf"
    37  	"golang.org/x/net/http2"
    38  )
    39  
// Compile-time check that *RealityClientConfig implements ConfigCompat.
var _ ConfigCompat = (*RealityClientConfig)(nil)
    41  
// RealityClientConfig is a client TLS configuration for the REALITY protocol,
// layered on top of a uTLS client configuration.
//
// NOTE: values are built with positional composite literals in this file, so
// the field order is significant.
type RealityClientConfig struct {
	// uClient is the underlying uTLS configuration (server name, ALPN,
	// ClientHello fingerprint) that every handshake is cloned from.
	uClient *UTLSClientConfig
	// publicKey is the server's REALITY X25519 public key; NewRealityClient
	// only accepts exactly 32 bytes.
	publicKey []byte
	// shortID is copied into bytes 8..15 of the ClientHello session ID
	// plaintext before it is sealed.
	shortID [8]byte
}
    47  
    48  func NewRealityClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (*RealityClientConfig, error) {
    49  	if options.UTLS == nil || !options.UTLS.Enabled {
    50  		return nil, E.New("uTLS is required by reality client")
    51  	}
    52  
    53  	uClient, err := NewUTLSClient(ctx, serverAddress, options)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	publicKey, err := base64.RawURLEncoding.DecodeString(options.Reality.PublicKey)
    59  	if err != nil {
    60  		return nil, E.Cause(err, "decode public_key")
    61  	}
    62  	if len(publicKey) != 32 {
    63  		return nil, E.New("invalid public_key")
    64  	}
    65  	var shortID [8]byte
    66  	decodedLen, err := hex.Decode(shortID[:], []byte(options.Reality.ShortID))
    67  	if err != nil {
    68  		return nil, E.Cause(err, "decode short_id")
    69  	}
    70  	if decodedLen > 8 {
    71  		return nil, E.New("invalid short_id")
    72  	}
    73  	return &RealityClientConfig{uClient, publicKey, shortID}, nil
    74  }
    75  
    76  func (e *RealityClientConfig) ServerName() string {
    77  	return e.uClient.ServerName()
    78  }
    79  
    80  func (e *RealityClientConfig) SetServerName(serverName string) {
    81  	e.uClient.SetServerName(serverName)
    82  }
    83  
    84  func (e *RealityClientConfig) NextProtos() []string {
    85  	return e.uClient.NextProtos()
    86  }
    87  
    88  func (e *RealityClientConfig) SetNextProtos(nextProto []string) {
    89  	e.uClient.SetNextProtos(nextProto)
    90  }
    91  
    92  func (e *RealityClientConfig) Config() (*STDConfig, error) {
    93  	return nil, E.New("unsupported usage for reality")
    94  }
    95  
    96  func (e *RealityClientConfig) Client(conn net.Conn) (Conn, error) {
    97  	return ClientHandshake(context.Background(), conn, e)
    98  }
    99  
   100  func (e *RealityClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
   101  	verifier := &realityVerifier{
   102  		serverName: e.uClient.ServerName(),
   103  	}
   104  	uConfig := e.uClient.config.Clone()
   105  	uConfig.InsecureSkipVerify = true
   106  	uConfig.SessionTicketsDisabled = true
   107  	uConfig.VerifyPeerCertificate = verifier.VerifyPeerCertificate
   108  	uConn := utls.UClient(conn, uConfig, e.uClient.id)
   109  	verifier.UConn = uConn
   110  	err := uConn.BuildHandshakeState()
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	if len(uConfig.NextProtos) > 0 {
   116  		for _, extension := range uConn.Extensions {
   117  			if alpnExtension, isALPN := extension.(*utls.ALPNExtension); isALPN {
   118  				alpnExtension.AlpnProtocols = uConfig.NextProtos
   119  				break
   120  			}
   121  		}
   122  	}
   123  
   124  	hello := uConn.HandshakeState.Hello
   125  	hello.SessionId = make([]byte, 32)
   126  	copy(hello.Raw[39:], hello.SessionId)
   127  
   128  	var nowTime time.Time
   129  	if uConfig.Time != nil {
   130  		nowTime = uConfig.Time()
   131  	} else {
   132  		nowTime = time.Now()
   133  	}
   134  	binary.BigEndian.PutUint64(hello.SessionId, uint64(nowTime.Unix()))
   135  
   136  	hello.SessionId[0] = 1
   137  	hello.SessionId[1] = 8
   138  	hello.SessionId[2] = 1
   139  	binary.BigEndian.PutUint32(hello.SessionId[4:], uint32(time.Now().Unix()))
   140  	copy(hello.SessionId[8:], e.shortID[:])
   141  	if debug.Enabled {
   142  		fmt.Printf("REALITY hello.sessionId[:16]: %v\n", hello.SessionId[:16])
   143  	}
   144  	publicKey, err := ecdh.X25519().NewPublicKey(e.publicKey)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	ecdheKey := uConn.HandshakeState.State13.EcdheKey
   149  	if ecdheKey == nil {
   150  		return nil, E.New("nil ecdhe_key")
   151  	}
   152  	authKey, err := ecdheKey.ECDH(publicKey)
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	if authKey == nil {
   157  		return nil, E.New("nil auth_key")
   158  	}
   159  	verifier.authKey = authKey
   160  	_, err = hkdf.New(sha256.New, authKey, hello.Random[:20], []byte("REALITY")).Read(authKey)
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  	aesBlock, _ := aes.NewCipher(authKey)
   165  	aesGcmCipher, _ := cipher.NewGCM(aesBlock)
   166  	aesGcmCipher.Seal(hello.SessionId[:0], hello.Random[20:], hello.SessionId[:16], hello.Raw)
   167  	copy(hello.Raw[39:], hello.SessionId)
   168  	if debug.Enabled {
   169  		fmt.Printf("REALITY hello.sessionId: %v\n", hello.SessionId)
   170  		fmt.Printf("REALITY uConn.AuthKey: %v\n", authKey)
   171  	}
   172  
   173  	err = uConn.HandshakeContext(ctx)
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	if debug.Enabled {
   179  		fmt.Printf("REALITY Conn.Verified: %v\n", verifier.verified)
   180  	}
   181  
   182  	if !verifier.verified {
   183  		go realityClientFallback(uConn, e.uClient.ServerName(), e.uClient.id)
   184  		return nil, E.New("reality verification failed")
   185  	}
   186  
   187  	return &utlsConnWrapper{uConn}, nil
   188  }
   189  
   190  func realityClientFallback(uConn net.Conn, serverName string, fingerprint utls.ClientHelloID) {
   191  	defer uConn.Close()
   192  	client := &http.Client{
   193  		Transport: &http2.Transport{
   194  			DialTLSContext: func(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) {
   195  				return uConn, nil
   196  			},
   197  		},
   198  	}
   199  	request, _ := http.NewRequest("GET", "https://"+serverName, nil)
   200  	request.Header.Set("User-Agent", fingerprint.Client)
   201  	request.AddCookie(&http.Cookie{Name: "padding", Value: strings.Repeat("0", mRand.Intn(32)+30)})
   202  	response, err := client.Do(request)
   203  	if err != nil {
   204  		return
   205  	}
   206  	_, _ = io.Copy(io.Discard, response.Body)
   207  	response.Body.Close()
   208  }
   209  
   210  func (e *RealityClientConfig) SetSessionIDGenerator(generator func(clientHello []byte, sessionID []byte) error) {
   211  	e.uClient.config.SessionIDGenerator = generator
   212  }
   213  
   214  func (e *RealityClientConfig) Clone() Config {
   215  	return &RealityClientConfig{
   216  		e.uClient.Clone().(*UTLSClientConfig),
   217  		e.publicKey,
   218  		e.shortID,
   219  	}
   220  }
   221  
// realityVerifier carries the per-handshake state used to check the server
// certificate. ClientHandshake installs its VerifyPeerCertificate method as
// the uTLS certificate callback.
type realityVerifier struct {
	// UConn is the connection being handshaken; its peer certificates are
	// inspected during verification.
	*utls.UConn
	// serverName is the DNS name used by the standard x509 fallback check.
	serverName string
	// authKey is the HKDF-derived REALITY key, used as the HMAC key when
	// checking the certificate signature.
	authKey []byte
	// verified is set when the leaf certificate carries the expected REALITY
	// HMAC.
	verified bool
}
   228  
   229  func (c *realityVerifier) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   230  	p, _ := reflect.TypeOf(c.Conn).Elem().FieldByName("peerCertificates")
   231  	certs := *(*([]*x509.Certificate))(unsafe.Pointer(uintptr(unsafe.Pointer(c.Conn)) + p.Offset))
   232  	if pub, ok := certs[0].PublicKey.(ed25519.PublicKey); ok {
   233  		h := hmac.New(sha512.New, c.authKey)
   234  		h.Write(pub)
   235  		if bytes.Equal(h.Sum(nil), certs[0].Signature) {
   236  			c.verified = true
   237  			return nil
   238  		}
   239  	}
   240  	opts := x509.VerifyOptions{
   241  		DNSName:       c.serverName,
   242  		Intermediates: x509.NewCertPool(),
   243  	}
   244  	for _, cert := range certs[1:] {
   245  		opts.Intermediates.AddCert(cert)
   246  	}
   247  	if _, err := certs[0].Verify(opts); err != nil {
   248  		return err
   249  	}
   250  	return nil
   251  }