github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/hysteria/xplus.go (about)

     1  package hysteria
     2  
     3  import (
     4  	"crypto/sha256"
     5  	"math/rand"
     6  	"net"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/sagernet/sing/common"
    11  	"github.com/sagernet/sing/common/buf"
    12  	"github.com/sagernet/sing/common/bufio"
    13  	M "github.com/sagernet/sing/common/metadata"
    14  	N "github.com/sagernet/sing/common/network"
    15  )
    16  
    17  const xplusSaltLen = 16
    18  
    19  func NewXPlusPacketConn(conn net.PacketConn, key []byte) net.PacketConn {
    20  	vectorisedWriter, isVectorised := bufio.CreateVectorisedPacketWriter(conn)
    21  	if isVectorised {
    22  		return &VectorisedXPlusConn{
    23  			XPlusPacketConn: XPlusPacketConn{
    24  				PacketConn: conn,
    25  				key:        key,
    26  				rand:       rand.New(rand.NewSource(time.Now().UnixNano())),
    27  			},
    28  			writer: vectorisedWriter,
    29  		}
    30  	} else {
    31  		return &XPlusPacketConn{
    32  			PacketConn: conn,
    33  			key:        key,
    34  			rand:       rand.New(rand.NewSource(time.Now().UnixNano())),
    35  		}
    36  	}
    37  }
    38  
    39  type XPlusPacketConn struct {
    40  	net.PacketConn
    41  	key        []byte
    42  	randAccess sync.Mutex
    43  	rand       *rand.Rand
    44  }
    45  
    46  func (c *XPlusPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
    47  	n, addr, err = c.PacketConn.ReadFrom(p)
    48  	if err != nil {
    49  		return
    50  	} else if n < xplusSaltLen {
    51  		n = 0
    52  		return
    53  	}
    54  	key := sha256.Sum256(append(c.key, p[:xplusSaltLen]...))
    55  	for i := range p[xplusSaltLen:] {
    56  		p[i] = p[xplusSaltLen+i] ^ key[i%sha256.Size]
    57  	}
    58  	n -= xplusSaltLen
    59  	return
    60  }
    61  
    62  func (c *XPlusPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
    63  	// can't use unsafe buffer on WriteTo
    64  	buffer := buf.NewSize(len(p) + xplusSaltLen)
    65  	defer buffer.Release()
    66  	salt := buffer.Extend(xplusSaltLen)
    67  	c.randAccess.Lock()
    68  	_, _ = c.rand.Read(salt)
    69  	c.randAccess.Unlock()
    70  	key := sha256.Sum256(append(c.key, salt...))
    71  	for i := range p {
    72  		common.Must(buffer.WriteByte(p[i] ^ key[i%sha256.Size]))
    73  	}
    74  	return c.PacketConn.WriteTo(buffer.Bytes(), addr)
    75  }
    76  
    77  func (c *XPlusPacketConn) Upstream() any {
    78  	return c.PacketConn
    79  }
    80  
    81  type VectorisedXPlusConn struct {
    82  	XPlusPacketConn
    83  	writer N.VectorisedPacketWriter
    84  }
    85  
    86  func (c *VectorisedXPlusConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
    87  	header := buf.NewSize(xplusSaltLen)
    88  	defer header.Release()
    89  	salt := header.Extend(xplusSaltLen)
    90  	c.randAccess.Lock()
    91  	_, _ = c.rand.Read(salt)
    92  	c.randAccess.Unlock()
    93  	key := sha256.Sum256(append(c.key, salt...))
    94  	for i := range p {
    95  		p[i] ^= key[i%sha256.Size]
    96  	}
    97  	return bufio.WriteVectorisedPacket(c.writer, [][]byte{header.Bytes(), p}, M.SocksaddrFromNet(addr))
    98  }
    99  
   100  func (c *VectorisedXPlusConn) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
   101  	header := buf.NewSize(xplusSaltLen)
   102  	defer header.Release()
   103  	salt := header.Extend(xplusSaltLen)
   104  	c.randAccess.Lock()
   105  	_, _ = c.rand.Read(salt)
   106  	c.randAccess.Unlock()
   107  	key := sha256.Sum256(append(c.key, salt...))
   108  	var index int
   109  	for _, buffer := range buffers {
   110  		data := buffer.Bytes()
   111  		for i := range data {
   112  			data[i] ^= key[index%sha256.Size]
   113  			index++
   114  		}
   115  	}
   116  	buffers = append([]*buf.Buffer{header}, buffers...)
   117  	return c.writer.WriteVectorisedPacket(buffers, destination)
   118  }