github.com/sagernet/sing-mux@v0.2.1-0.20240124034317-9bfb33698bb6/padding.go (about)

     1  package mux
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  	"math/rand"
     7  	"net"
     8  
     9  	"github.com/sagernet/sing/common"
    10  	"github.com/sagernet/sing/common/buf"
    11  	"github.com/sagernet/sing/common/bufio"
    12  	N "github.com/sagernet/sing/common/network"
    13  	"github.com/sagernet/sing/common/rw"
    14  )
    15  
    16  const kFirstPaddings = 16
    17  
// paddingConn wraps a connection and, for the first kFirstPaddings frames in
// each direction, carries payload inside frames laid out as
//
//	[data length: uint16 BE][padding length: uint16 BE][data][padding]
//
// Once that many frames have been exchanged in a direction, reads or writes in
// that direction go straight to ExtendedConn.
type paddingConn struct {
	N.ExtendedConn
	// writer is populated by newPaddingConn, but no method visible in this file
	// reads it; vectorisedPaddingConn declares its own writer that shadows it.
	writer           N.VectorisedWriter
	// readPadding counts padded frames Read/ReadBuffer have begun consuming;
	// frame parsing stops once it reaches kFirstPaddings.
	readPadding      int
	// writePadding counts padded frames written; padding stops once it
	// reaches kFirstPaddings.
	writePadding     int
	// readRemaining is the number of data bytes of the current frame not yet
	// returned to the caller.
	readRemaining    int
	// paddingRemaining is the number of padding bytes that must be skipped
	// before the next frame header (or pass-through data).
	paddingRemaining int
}
    26  
    27  func newPaddingConn(conn net.Conn) net.Conn {
    28  	writer, isVectorised := bufio.CreateVectorisedWriter(conn)
    29  	if isVectorised {
    30  		return &vectorisedPaddingConn{
    31  			paddingConn{
    32  				ExtendedConn: bufio.NewExtendedConn(conn),
    33  				writer:       bufio.NewVectorisedWriter(conn),
    34  			},
    35  			writer,
    36  		}
    37  	} else {
    38  		return &paddingConn{
    39  			ExtendedConn: bufio.NewExtendedConn(conn),
    40  			writer:       bufio.NewVectorisedWriter(conn),
    41  		}
    42  	}
    43  }
    44  
    45  func (c *paddingConn) Read(p []byte) (n int, err error) {
    46  	if c.readRemaining > 0 {
    47  		if len(p) > c.readRemaining {
    48  			p = p[:c.readRemaining]
    49  		}
    50  		n, err = c.ExtendedConn.Read(p)
    51  		if err != nil {
    52  			return
    53  		}
    54  		c.readRemaining -= n
    55  		return
    56  	}
    57  	if c.paddingRemaining > 0 {
    58  		err = rw.SkipN(c.ExtendedConn, c.paddingRemaining)
    59  		if err != nil {
    60  			return
    61  		}
    62  		c.paddingRemaining = 0
    63  	}
    64  	if c.readPadding < kFirstPaddings {
    65  		var paddingHdr []byte
    66  		if len(p) >= 4 {
    67  			paddingHdr = p[:4]
    68  		} else {
    69  			paddingHdr = make([]byte, 4)
    70  		}
    71  		_, err = io.ReadFull(c.ExtendedConn, paddingHdr)
    72  		if err != nil {
    73  			return
    74  		}
    75  		originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
    76  		paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:]))
    77  		if len(p) > originalDataSize {
    78  			p = p[:originalDataSize]
    79  		}
    80  		n, err = c.ExtendedConn.Read(p)
    81  		if err != nil {
    82  			return
    83  		}
    84  		c.readPadding++
    85  		c.readRemaining = originalDataSize - n
    86  		c.paddingRemaining = paddingLen
    87  		return
    88  	}
    89  	return c.ExtendedConn.Read(p)
    90  }
    91  
    92  func (c *paddingConn) Write(p []byte) (n int, err error) {
    93  	for pLen := len(p); pLen > 0; {
    94  		var data []byte
    95  		if pLen > 65535 {
    96  			data = p[:65535]
    97  			p = p[65535:]
    98  			pLen -= 65535
    99  		} else {
   100  			data = p
   101  			pLen = 0
   102  		}
   103  		var writeN int
   104  		writeN, err = c.write(data)
   105  		n += writeN
   106  		if err != nil {
   107  			break
   108  		}
   109  	}
   110  	return n, err
   111  }
   112  
   113  func (c *paddingConn) write(p []byte) (n int, err error) {
   114  	if c.writePadding < kFirstPaddings {
   115  		paddingLen := 256 + rand.Intn(512)
   116  		buffer := buf.NewSize(4 + len(p) + paddingLen)
   117  		defer buffer.Release()
   118  		header := buffer.Extend(4)
   119  		binary.BigEndian.PutUint16(header[:2], uint16(len(p)))
   120  		binary.BigEndian.PutUint16(header[2:], uint16(paddingLen))
   121  		common.Must1(buffer.Write(p))
   122  		buffer.Extend(paddingLen)
   123  		_, err = c.ExtendedConn.Write(buffer.Bytes())
   124  		if err == nil {
   125  			n = len(p)
   126  		}
   127  		c.writePadding++
   128  		return
   129  	}
   130  	return c.ExtendedConn.Write(p)
   131  }
   132  
   133  func (c *paddingConn) ReadBuffer(buffer *buf.Buffer) error {
   134  	p := buffer.FreeBytes()
   135  	if c.readRemaining > 0 {
   136  		if len(p) > c.readRemaining {
   137  			p = p[:c.readRemaining]
   138  		}
   139  		n, err := c.ExtendedConn.Read(p)
   140  		if err != nil {
   141  			return err
   142  		}
   143  		c.readRemaining -= n
   144  		buffer.Truncate(n)
   145  		return nil
   146  	}
   147  	if c.paddingRemaining > 0 {
   148  		err := rw.SkipN(c.ExtendedConn, c.paddingRemaining)
   149  		if err != nil {
   150  			return err
   151  		}
   152  		c.paddingRemaining = 0
   153  	}
   154  	if c.readPadding < kFirstPaddings {
   155  		var paddingHdr []byte
   156  		if len(p) >= 4 {
   157  			paddingHdr = p[:4]
   158  		} else {
   159  			paddingHdr = make([]byte, 4)
   160  		}
   161  		_, err := io.ReadFull(c.ExtendedConn, paddingHdr)
   162  		if err != nil {
   163  			return err
   164  		}
   165  		originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
   166  		paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:]))
   167  
   168  		if len(p) > originalDataSize {
   169  			p = p[:originalDataSize]
   170  		}
   171  		n, err := c.ExtendedConn.Read(p)
   172  		if err != nil {
   173  			return err
   174  		}
   175  		c.readPadding++
   176  		c.readRemaining = originalDataSize - n
   177  		c.paddingRemaining = paddingLen
   178  		buffer.Truncate(n)
   179  		return nil
   180  	}
   181  	return c.ExtendedConn.ReadBuffer(buffer)
   182  }
   183  
   184  func (c *paddingConn) WriteBuffer(buffer *buf.Buffer) error {
   185  	if c.writePadding < kFirstPaddings {
   186  		bufferLen := buffer.Len()
   187  		if bufferLen > 65535 {
   188  			return common.Error(c.Write(buffer.Bytes()))
   189  		}
   190  		paddingLen := 256 + rand.Intn(512)
   191  		header := buffer.ExtendHeader(4)
   192  		binary.BigEndian.PutUint16(header[:2], uint16(bufferLen))
   193  		binary.BigEndian.PutUint16(header[2:], uint16(paddingLen))
   194  		buffer.Extend(paddingLen)
   195  		c.writePadding++
   196  	}
   197  	return c.ExtendedConn.WriteBuffer(buffer)
   198  }
   199  
   200  func (c *paddingConn) FrontHeadroom() int {
   201  	return 4 + 256 + 1024
   202  }
   203  
   204  func (c *paddingConn) Upstream() any {
   205  	return c.ExtendedConn
   206  }
   207  
// vectorisedPaddingConn is the paddingConn variant returned by newPaddingConn
// when bufio.CreateVectorisedWriter reports vectorised-write support for the
// underlying connection.
type vectorisedPaddingConn struct {
	paddingConn
	// writer is the writer obtained from bufio.CreateVectorisedWriter; it
	// shadows paddingConn.writer and is what WriteVectorised writes through.
	writer N.VectorisedWriter
}
   212  
   213  func (c *vectorisedPaddingConn) WriteVectorised(buffers []*buf.Buffer) error {
   214  	if c.writePadding < kFirstPaddings {
   215  		bufferLen := buf.LenMulti(buffers)
   216  		if bufferLen > 65535 {
   217  			defer buf.ReleaseMulti(buffers)
   218  			for _, buffer := range buffers {
   219  				_, err := c.Write(buffer.Bytes())
   220  				if err != nil {
   221  					return err
   222  				}
   223  			}
   224  			return nil
   225  		}
   226  		paddingLen := 256 + rand.Intn(512)
   227  		header := buf.NewSize(4)
   228  		common.Must(
   229  			binary.Write(header, binary.BigEndian, uint16(bufferLen)),
   230  			binary.Write(header, binary.BigEndian, uint16(paddingLen)),
   231  		)
   232  		c.writePadding++
   233  		padding := buf.NewSize(paddingLen)
   234  		padding.Extend(paddingLen)
   235  		buffers = append(append([]*buf.Buffer{header}, buffers...), padding)
   236  	}
   237  	return c.writer.WriteVectorised(buffers)
   238  }