github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/common/badtls/badtls.go (about)

     1  //go:build go1.20 && !go1.21
     2  
     3  package badtls
     4  
     5  import (
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"crypto/tls"
     9  	"encoding/binary"
    10  	"io"
    11  	"net"
    12  	"reflect"
    13  	"sync"
    14  	"sync/atomic"
    15  	"unsafe"
    16  
    17  	"github.com/inazumav/sing-box/log"
    18  	"github.com/sagernet/sing/common"
    19  	"github.com/sagernet/sing/common/buf"
    20  	"github.com/sagernet/sing/common/bufio"
    21  	E "github.com/sagernet/sing/common/exceptions"
    22  	N "github.com/sagernet/sing/common/network"
    23  	aTLS "github.com/sagernet/sing/common/tls"
    24  )
    25  
    26  type Conn struct {
    27  	*tls.Conn
    28  	writer              N.ExtendedWriter
    29  	isHandshakeComplete *atomic.Bool
    30  	activeCall          *atomic.Int32
    31  	closeNotifySent     *bool
    32  	version             *uint16
    33  	rand                io.Reader
    34  	halfAccess          *sync.Mutex
    35  	halfError           *error
    36  	cipher              cipher.AEAD
    37  	explicitNonceLen    int
    38  	halfPtr             uintptr
    39  	halfSeq             []byte
    40  	halfScratchBuf      []byte
    41  }
    42  
    43  func TryCreate(conn aTLS.Conn) aTLS.Conn {
    44  	tlsConn, ok := conn.(*tls.Conn)
    45  	if !ok {
    46  		return conn
    47  	}
    48  	badConn, err := Create(tlsConn)
    49  	if err != nil {
    50  		log.Warn("initialize badtls: ", err)
    51  		return conn
    52  	}
    53  	return badConn
    54  }
    55  
    56  func Create(conn *tls.Conn) (aTLS.Conn, error) {
    57  	rawConn := reflect.Indirect(reflect.ValueOf(conn))
    58  	rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete")
    59  	if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct {
    60  		return nil, E.New("badtls: invalid isHandshakeComplete")
    61  	}
    62  	isHandshakeComplete := (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr()))
    63  	if !isHandshakeComplete.Load() {
    64  		return nil, E.New("handshake not finished")
    65  	}
    66  	rawActiveCall := rawConn.FieldByName("activeCall")
    67  	if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct {
    68  		return nil, E.New("badtls: invalid active call")
    69  	}
    70  	activeCall := (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
    71  	rawHalfConn := rawConn.FieldByName("out")
    72  	if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
    73  		return nil, E.New("badtls: invalid half conn")
    74  	}
    75  	rawVersion := rawConn.FieldByName("vers")
    76  	if !rawVersion.IsValid() || rawVersion.Kind() != reflect.Uint16 {
    77  		return nil, E.New("badtls: invalid version")
    78  	}
    79  	version := (*uint16)(unsafe.Pointer(rawVersion.UnsafeAddr()))
    80  	rawCloseNotifySent := rawConn.FieldByName("closeNotifySent")
    81  	if !rawCloseNotifySent.IsValid() || rawCloseNotifySent.Kind() != reflect.Bool {
    82  		return nil, E.New("badtls: invalid notify")
    83  	}
    84  	closeNotifySent := (*bool)(unsafe.Pointer(rawCloseNotifySent.UnsafeAddr()))
    85  	rawConfig := reflect.Indirect(rawConn.FieldByName("config"))
    86  	if !rawConfig.IsValid() || rawConfig.Kind() != reflect.Struct {
    87  		return nil, E.New("badtls: bad config")
    88  	}
    89  	config := (*tls.Config)(unsafe.Pointer(rawConfig.UnsafeAddr()))
    90  	randReader := config.Rand
    91  	if randReader == nil {
    92  		randReader = rand.Reader
    93  	}
    94  	rawHalfMutex := rawHalfConn.FieldByName("Mutex")
    95  	if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct {
    96  		return nil, E.New("badtls: invalid half mutex")
    97  	}
    98  	halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr()))
    99  	rawHalfError := rawHalfConn.FieldByName("err")
   100  	if !rawHalfError.IsValid() || rawHalfError.Kind() != reflect.Interface {
   101  		return nil, E.New("badtls: invalid half error")
   102  	}
   103  	halfError := (*error)(unsafe.Pointer(rawHalfError.UnsafeAddr()))
   104  	rawHalfCipherInterface := rawHalfConn.FieldByName("cipher")
   105  	if !rawHalfCipherInterface.IsValid() || rawHalfCipherInterface.Kind() != reflect.Interface {
   106  		return nil, E.New("badtls: invalid cipher interface")
   107  	}
   108  	rawHalfCipher := rawHalfCipherInterface.Elem()
   109  	aeadCipher, loaded := valueInterface(rawHalfCipher, false).(cipher.AEAD)
   110  	if !loaded {
   111  		return nil, E.New("badtls: invalid AEAD cipher")
   112  	}
   113  	var explicitNonceLen int
   114  	switch cipherName := reflect.Indirect(rawHalfCipher).Type().String(); cipherName {
   115  	case "tls.prefixNonceAEAD":
   116  		explicitNonceLen = aeadCipher.NonceSize()
   117  	case "tls.xorNonceAEAD":
   118  	default:
   119  		return nil, E.New("badtls: unknown cipher type: ", cipherName)
   120  	}
   121  	rawHalfSeq := rawHalfConn.FieldByName("seq")
   122  	if !rawHalfSeq.IsValid() || rawHalfSeq.Kind() != reflect.Array {
   123  		return nil, E.New("badtls: invalid seq")
   124  	}
   125  	halfSeq := rawHalfSeq.Bytes()
   126  	rawHalfScratchBuf := rawHalfConn.FieldByName("scratchBuf")
   127  	if !rawHalfScratchBuf.IsValid() || rawHalfScratchBuf.Kind() != reflect.Array {
   128  		return nil, E.New("badtls: invalid scratchBuf")
   129  	}
   130  	halfScratchBuf := rawHalfScratchBuf.Bytes()
   131  	return &Conn{
   132  		Conn:                conn,
   133  		writer:              bufio.NewExtendedWriter(conn.NetConn()),
   134  		isHandshakeComplete: isHandshakeComplete,
   135  		activeCall:          activeCall,
   136  		closeNotifySent:     closeNotifySent,
   137  		version:             version,
   138  		halfAccess:          halfAccess,
   139  		halfError:           halfError,
   140  		cipher:              aeadCipher,
   141  		explicitNonceLen:    explicitNonceLen,
   142  		rand:                randReader,
   143  		halfPtr:             rawHalfConn.UnsafeAddr(),
   144  		halfSeq:             halfSeq,
   145  		halfScratchBuf:      halfScratchBuf,
   146  	}, nil
   147  }
   148  
   149  func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
   150  	if buffer.Len() > maxPlaintext {
   151  		defer buffer.Release()
   152  		return common.Error(c.Write(buffer.Bytes()))
   153  	}
   154  	for {
   155  		x := c.activeCall.Load()
   156  		if x&1 != 0 {
   157  			return net.ErrClosed
   158  		}
   159  		if c.activeCall.CompareAndSwap(x, x+2) {
   160  			break
   161  		}
   162  	}
   163  	defer c.activeCall.Add(-2)
   164  	c.halfAccess.Lock()
   165  	defer c.halfAccess.Unlock()
   166  	if err := *c.halfError; err != nil {
   167  		return err
   168  	}
   169  	if *c.closeNotifySent {
   170  		return errShutdown
   171  	}
   172  	dataLen := buffer.Len()
   173  	dataBytes := buffer.Bytes()
   174  	outBuf := buffer.ExtendHeader(recordHeaderLen + c.explicitNonceLen)
   175  	outBuf[0] = 23
   176  	version := *c.version
   177  	if version == 0 {
   178  		version = tls.VersionTLS10
   179  	} else if version == tls.VersionTLS13 {
   180  		version = tls.VersionTLS12
   181  	}
   182  	binary.BigEndian.PutUint16(outBuf[1:], version)
   183  	var nonce []byte
   184  	if c.explicitNonceLen > 0 {
   185  		nonce = outBuf[5 : 5+c.explicitNonceLen]
   186  		if c.explicitNonceLen < 16 {
   187  			copy(nonce, c.halfSeq)
   188  		} else {
   189  			if _, err := io.ReadFull(c.rand, nonce); err != nil {
   190  				return err
   191  			}
   192  		}
   193  	}
   194  	if len(nonce) == 0 {
   195  		nonce = c.halfSeq
   196  	}
   197  	if *c.version == tls.VersionTLS13 {
   198  		buffer.FreeBytes()[0] = 23
   199  		binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+1+c.cipher.Overhead()))
   200  		c.cipher.Seal(outBuf, nonce, outBuf[recordHeaderLen:recordHeaderLen+c.explicitNonceLen+dataLen+1], outBuf[:recordHeaderLen])
   201  		buffer.Extend(1 + c.cipher.Overhead())
   202  	} else {
   203  		binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen))
   204  		additionalData := append(c.halfScratchBuf[:0], c.halfSeq...)
   205  		additionalData = append(additionalData, outBuf[:recordHeaderLen]...)
   206  		c.cipher.Seal(outBuf, nonce, dataBytes, additionalData)
   207  		buffer.Extend(c.cipher.Overhead())
   208  		binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead()))
   209  	}
   210  	incSeq(c.halfPtr)
   211  	log.Trace("badtls write ", buffer.Len())
   212  	return c.writer.WriteBuffer(buffer)
   213  }
   214  
   215  func (c *Conn) FrontHeadroom() int {
   216  	return recordHeaderLen + c.explicitNonceLen
   217  }
   218  
   219  func (c *Conn) RearHeadroom() int {
   220  	return 1 + c.cipher.Overhead()
   221  }
   222  
   223  func (c *Conn) WriterMTU() int {
   224  	return maxPlaintext
   225  }
   226  
   227  func (c *Conn) Upstream() any {
   228  	return c.Conn
   229  }
   230  
   231  func (c *Conn) UpstreamWriter() any {
   232  	return c.NetConn()
   233  }