github.com/sagernet/sing-shadowsocks2@v0.2.0/internal/shadowio/writer.go

     1  package shadowio
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"encoding/binary"
     6  	"io"
     7  	"sync"
     8  
     9  	"github.com/sagernet/sing/common"
    10  	"github.com/sagernet/sing/common/buf"
    11  	"github.com/sagernet/sing/common/bufio"
    12  	N "github.com/sagernet/sing/common/network"
    13  )
    14  
    15  type Writer struct {
    16  	WriterInterface
    17  	writer        N.ExtendedWriter
    18  	cipher        cipher.AEAD
    19  	maxPacketSize int
    20  	nonce         []byte
    21  	access        sync.Mutex
    22  }
    23  
    24  func NewWriter(writer io.Writer, cipher cipher.AEAD, nonce []byte, maxPacketSize int) *Writer {
    25  	if len(nonce) == 0 {
    26  		nonce = make([]byte, cipher.NonceSize())
    27  	}
    28  	return &Writer{
    29  		writer:        bufio.NewExtendedWriter(writer),
    30  		cipher:        cipher,
    31  		nonce:         nonce,
    32  		maxPacketSize: maxPacketSize,
    33  	}
    34  }
    35  
    36  func (w *Writer) Encrypt(destination []byte, source []byte) {
    37  	w.cipher.Seal(destination, w.nonce, source, nil)
    38  	increaseNonce(w.nonce)
    39  }
    40  
    41  func (w *Writer) Write(p []byte) (n int, err error) {
    42  	if len(p) == 0 {
    43  		return
    44  	}
    45  	w.access.Lock()
    46  	defer w.access.Unlock()
    47  	for pLen := len(p); pLen > 0; {
    48  		var data []byte
    49  		if pLen > w.maxPacketSize {
    50  			data = p[:w.maxPacketSize]
    51  			p = p[w.maxPacketSize:]
    52  			pLen -= w.maxPacketSize
    53  		} else {
    54  			data = p
    55  			pLen = 0
    56  		}
    57  		bufferSize := PacketLengthBufferSize + 2*Overhead + len(data)
    58  		buffer := buf.NewSize(bufferSize)
    59  		common.Must(binary.Write(buffer, binary.BigEndian, uint16(len(data))))
    60  		w.cipher.Seal(buffer.Index(0), w.nonce, buffer.To(PacketLengthBufferSize), nil)
    61  		increaseNonce(w.nonce)
    62  		buffer.Extend(Overhead)
    63  		w.cipher.Seal(buffer.Index(buffer.Len()), w.nonce, data, nil)
    64  		buffer.Extend(len(data) + Overhead)
    65  		increaseNonce(w.nonce)
    66  		_, err = w.writer.Write(buffer.Bytes())
    67  		buffer.Release()
    68  		if err != nil {
    69  			return
    70  		}
    71  		n += len(data)
    72  	}
    73  	return
    74  }
    75  
    76  func (w *Writer) WriteBuffer(buffer *buf.Buffer) error {
    77  	if buffer.Len() > w.maxPacketSize {
    78  		defer buffer.Release()
    79  		return common.Error(w.Write(buffer.Bytes()))
    80  	}
    81  	pLen := buffer.Len()
    82  	headerOffset := PacketLengthBufferSize + Overhead
    83  	header := buffer.ExtendHeader(headerOffset)
    84  	binary.BigEndian.PutUint16(header, uint16(pLen))
    85  	w.cipher.Seal(header[:0], w.nonce, header[:PacketLengthBufferSize], nil)
    86  	increaseNonce(w.nonce)
    87  	w.cipher.Seal(buffer.Index(headerOffset), w.nonce, buffer.From(headerOffset), nil)
    88  	increaseNonce(w.nonce)
    89  	buffer.Extend(Overhead)
    90  	return w.writer.WriteBuffer(buffer)
    91  }
    92  
    93  func (w *Writer) TakeNonce() []byte {
    94  	return w.nonce
    95  }
    96  
    97  func (w *Writer) Upstream() any {
    98  	return w.writer
    99  }
   100  
   101  type WriterInterface struct{}
   102  
   103  func (w *WriterInterface) FrontHeadroom() int {
   104  	return PacketLengthBufferSize + Overhead
   105  }
   106  
   107  func (w *WriterInterface) RearHeadroom() int {
   108  	return Overhead
   109  }