github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/common/tls/reality_client.go (about)

     1  //go:build with_utls
     2  
     3  package tls
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"crypto/aes"
     9  	"crypto/cipher"
    10  	"crypto/ed25519"
    11  	"crypto/hmac"
    12  	"crypto/sha256"
    13  	"crypto/sha512"
    14  	"crypto/tls"
    15  	"crypto/x509"
    16  	"encoding/base64"
    17  	"encoding/binary"
    18  	"encoding/hex"
    19  	"fmt"
    20  	"io"
    21  	mRand "math/rand"
    22  	"net"
    23  	"net/http"
    24  	"reflect"
    25  	"strings"
    26  	"time"
    27  	"unsafe"
    28  
    29  	"github.com/inazumav/sing-box/option"
    30  	"github.com/sagernet/sing/common/debug"
    31  	E "github.com/sagernet/sing/common/exceptions"
    32  	aTLS "github.com/sagernet/sing/common/tls"
    33  	utls "github.com/sagernet/utls"
    34  
    35  	"golang.org/x/crypto/hkdf"
    36  	"golang.org/x/net/http2"
    37  )
    38  
    39  var _ ConfigCompat = (*RealityClientConfig)(nil)
    40  
    41  type RealityClientConfig struct {
    42  	uClient   *UTLSClientConfig
    43  	publicKey []byte
    44  	shortID   [8]byte
    45  }
    46  
    47  func NewRealityClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (*RealityClientConfig, error) {
    48  	if options.UTLS == nil || !options.UTLS.Enabled {
    49  		return nil, E.New("uTLS is required by reality client")
    50  	}
    51  
    52  	uClient, err := NewUTLSClient(ctx, serverAddress, options)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	publicKey, err := base64.RawURLEncoding.DecodeString(options.Reality.PublicKey)
    58  	if err != nil {
    59  		return nil, E.Cause(err, "decode public_key")
    60  	}
    61  	if len(publicKey) != 32 {
    62  		return nil, E.New("invalid public_key")
    63  	}
    64  	var shortID [8]byte
    65  	decodedLen, err := hex.Decode(shortID[:], []byte(options.Reality.ShortID))
    66  	if err != nil {
    67  		return nil, E.Cause(err, "decode short_id")
    68  	}
    69  	if decodedLen > 8 {
    70  		return nil, E.New("invalid short_id")
    71  	}
    72  	return &RealityClientConfig{uClient, publicKey, shortID}, nil
    73  }
    74  
    75  func (e *RealityClientConfig) ServerName() string {
    76  	return e.uClient.ServerName()
    77  }
    78  
    79  func (e *RealityClientConfig) SetServerName(serverName string) {
    80  	e.uClient.SetServerName(serverName)
    81  }
    82  
    83  func (e *RealityClientConfig) NextProtos() []string {
    84  	return e.uClient.NextProtos()
    85  }
    86  
    87  func (e *RealityClientConfig) SetNextProtos(nextProto []string) {
    88  	e.uClient.SetNextProtos(nextProto)
    89  }
    90  
    91  func (e *RealityClientConfig) Config() (*STDConfig, error) {
    92  	return nil, E.New("unsupported usage for reality")
    93  }
    94  
    95  func (e *RealityClientConfig) Client(conn net.Conn) (Conn, error) {
    96  	return ClientHandshake(context.Background(), conn, e)
    97  }
    98  
    99  func (e *RealityClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) {
   100  	verifier := &realityVerifier{
   101  		serverName: e.uClient.ServerName(),
   102  	}
   103  	uConfig := e.uClient.config.Clone()
   104  	uConfig.InsecureSkipVerify = true
   105  	uConfig.SessionTicketsDisabled = true
   106  	uConfig.VerifyPeerCertificate = verifier.VerifyPeerCertificate
   107  	uConn := utls.UClient(conn, uConfig, e.uClient.id)
   108  	verifier.UConn = uConn
   109  	err := uConn.BuildHandshakeState()
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	if len(uConfig.NextProtos) > 0 {
   115  		for _, extension := range uConn.Extensions {
   116  			if alpnExtension, isALPN := extension.(*utls.ALPNExtension); isALPN {
   117  				alpnExtension.AlpnProtocols = uConfig.NextProtos
   118  				break
   119  			}
   120  		}
   121  	}
   122  
   123  	hello := uConn.HandshakeState.Hello
   124  	hello.SessionId = make([]byte, 32)
   125  	copy(hello.Raw[39:], hello.SessionId)
   126  
   127  	var nowTime time.Time
   128  	if uConfig.Time != nil {
   129  		nowTime = uConfig.Time()
   130  	} else {
   131  		nowTime = time.Now()
   132  	}
   133  	binary.BigEndian.PutUint64(hello.SessionId, uint64(nowTime.Unix()))
   134  
   135  	hello.SessionId[0] = 1
   136  	hello.SessionId[1] = 8
   137  	hello.SessionId[2] = 1
   138  	binary.BigEndian.PutUint32(hello.SessionId[4:], uint32(time.Now().Unix()))
   139  	copy(hello.SessionId[8:], e.shortID[:])
   140  
   141  	if debug.Enabled {
   142  		fmt.Printf("REALITY hello.sessionId[:16]: %v\n", hello.SessionId[:16])
   143  	}
   144  
   145  	authKey := uConn.HandshakeState.State13.EcdheParams.SharedKey(e.publicKey)
   146  	if authKey == nil {
   147  		return nil, E.New("nil auth_key")
   148  	}
   149  	verifier.authKey = authKey
   150  	_, err = hkdf.New(sha256.New, authKey, hello.Random[:20], []byte("REALITY")).Read(authKey)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  	aesBlock, _ := aes.NewCipher(authKey)
   155  	aesGcmCipher, _ := cipher.NewGCM(aesBlock)
   156  	aesGcmCipher.Seal(hello.SessionId[:0], hello.Random[20:], hello.SessionId[:16], hello.Raw)
   157  	copy(hello.Raw[39:], hello.SessionId)
   158  	if debug.Enabled {
   159  		fmt.Printf("REALITY hello.sessionId: %v\n", hello.SessionId)
   160  		fmt.Printf("REALITY uConn.AuthKey: %v\n", authKey)
   161  	}
   162  
   163  	err = uConn.HandshakeContext(ctx)
   164  	if err != nil {
   165  		return nil, err
   166  	}
   167  
   168  	if debug.Enabled {
   169  		fmt.Printf("REALITY Conn.Verified: %v\n", verifier.verified)
   170  	}
   171  
   172  	if !verifier.verified {
   173  		go realityClientFallback(uConn, e.uClient.ServerName(), e.uClient.id)
   174  		return nil, E.New("reality verification failed")
   175  	}
   176  
   177  	return &utlsConnWrapper{uConn}, nil
   178  }
   179  
   180  func realityClientFallback(uConn net.Conn, serverName string, fingerprint utls.ClientHelloID) {
   181  	defer uConn.Close()
   182  	client := &http.Client{
   183  		Transport: &http2.Transport{
   184  			DialTLSContext: func(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) {
   185  				return uConn, nil
   186  			},
   187  		},
   188  	}
   189  	request, _ := http.NewRequest("GET", "https://"+serverName, nil)
   190  	request.Header.Set("User-Agent", fingerprint.Client)
   191  	request.AddCookie(&http.Cookie{Name: "padding", Value: strings.Repeat("0", mRand.Intn(32)+30)})
   192  	response, err := client.Do(request)
   193  	if err != nil {
   194  		return
   195  	}
   196  	_, _ = io.Copy(io.Discard, response.Body)
   197  	response.Body.Close()
   198  }
   199  
   200  func (e *RealityClientConfig) SetSessionIDGenerator(generator func(clientHello []byte, sessionID []byte) error) {
   201  	e.uClient.config.SessionIDGenerator = generator
   202  }
   203  
   204  func (e *RealityClientConfig) Clone() Config {
   205  	return &RealityClientConfig{
   206  		e.uClient.Clone().(*UTLSClientConfig),
   207  		e.publicKey,
   208  		e.shortID,
   209  	}
   210  }
   211  
   212  type realityVerifier struct {
   213  	*utls.UConn
   214  	serverName string
   215  	authKey    []byte
   216  	verified   bool
   217  }
   218  
   219  func (c *realityVerifier) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   220  	p, _ := reflect.TypeOf(c.Conn).Elem().FieldByName("peerCertificates")
   221  	certs := *(*([]*x509.Certificate))(unsafe.Pointer(uintptr(unsafe.Pointer(c.Conn)) + p.Offset))
   222  	if pub, ok := certs[0].PublicKey.(ed25519.PublicKey); ok {
   223  		h := hmac.New(sha512.New, c.authKey)
   224  		h.Write(pub)
   225  		if bytes.Equal(h.Sum(nil), certs[0].Signature) {
   226  			c.verified = true
   227  			return nil
   228  		}
   229  	}
   230  	opts := x509.VerifyOptions{
   231  		DNSName:       c.serverName,
   232  		Intermediates: x509.NewCertPool(),
   233  	}
   234  	for _, cert := range certs[1:] {
   235  		opts.Intermediates.AddCert(cert)
   236  	}
   237  	if _, err := certs[0].Verify(opts); err != nil {
   238  		return err
   239  	}
   240  	return nil
   241  }