github.com/MerlinKodo/sing-shadowsocks2@v0.1.6/internal/shadowio/reader.go (about)

     1  package shadowio
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"encoding/binary"
     6  	"io"
     7  
     8  	"github.com/sagernet/sing/common/buf"
     9  )
    10  
    11  const PacketLengthBufferSize = 2
    12  
    13  const (
    14  	// Overhead
    15  	// crypto/cipher.gcmTagSize
    16  	// golang.org/x/crypto/chacha20poly1305.Overhead
    17  	Overhead = 16
    18  )
    19  
    20  type Reader struct {
    21  	reader io.Reader
    22  	cipher cipher.AEAD
    23  	nonce  []byte
    24  	cache  *buf.Buffer
    25  }
    26  
    27  func NewReader(upstream io.Reader, cipher cipher.AEAD) *Reader {
    28  	return &Reader{
    29  		reader: upstream,
    30  		cipher: cipher,
    31  		nonce:  make([]byte, cipher.NonceSize()),
    32  	}
    33  }
    34  
    35  func (r *Reader) ReadFixedBuffer(pLen int) (*buf.Buffer, error) {
    36  	buffer := buf.NewSize(pLen + Overhead)
    37  	_, err := buffer.ReadFullFrom(r.reader, buffer.FreeLen())
    38  	if err != nil {
    39  		buffer.Release()
    40  		return nil, err
    41  	}
    42  	err = r.Decrypt(buffer.Index(0), buffer.Bytes())
    43  	if err != nil {
    44  		buffer.Release()
    45  		return nil, err
    46  	}
    47  	buffer.Truncate(pLen)
    48  	r.cache = buffer
    49  	return buffer, nil
    50  }
    51  
    52  func (r *Reader) Decrypt(destination []byte, source []byte) error {
    53  	_, err := r.cipher.Open(destination[:0], r.nonce, source, nil)
    54  	if err != nil {
    55  		return err
    56  	}
    57  	increaseNonce(r.nonce)
    58  	return nil
    59  }
    60  
    61  func (r *Reader) Read(p []byte) (n int, err error) {
    62  	for {
    63  		if r.cache != nil {
    64  			if r.cache.IsEmpty() {
    65  				r.cache.Release()
    66  				r.cache = nil
    67  			} else {
    68  				n = copy(p, r.cache.Bytes())
    69  				if n > 0 {
    70  					r.cache.Advance(n)
    71  					return
    72  				}
    73  			}
    74  		}
    75  		r.cache, err = r.readBuffer()
    76  		if err != nil {
    77  			return
    78  		}
    79  	}
    80  }
    81  
    82  func (r *Reader) ReadBuffer(buffer *buf.Buffer) error {
    83  	var err error
    84  	for {
    85  		if r.cache != nil {
    86  			if r.cache.IsEmpty() {
    87  				r.cache.Release()
    88  				r.cache = nil
    89  			} else {
    90  				n := copy(buffer.FreeBytes(), r.cache.Bytes())
    91  				if n > 0 {
    92  					buffer.Truncate(n)
    93  					r.cache.Advance(n)
    94  					return nil
    95  				}
    96  			}
    97  		}
    98  		r.cache, err = r.readBuffer()
    99  		if err != nil {
   100  			return err
   101  		}
   102  	}
   103  }
   104  
   105  func (r *Reader) ReadBufferThreadSafe() (buffer *buf.Buffer, err error) {
   106  	cache := r.cache
   107  	if cache != nil {
   108  		r.cache = nil
   109  		return cache, nil
   110  	}
   111  	return r.readBuffer()
   112  }
   113  
   114  func (r *Reader) readBuffer() (*buf.Buffer, error) {
   115  	buffer := buf.NewSize(PacketLengthBufferSize + Overhead)
   116  	_, err := buffer.ReadFullFrom(r.reader, buffer.FreeLen())
   117  	if err != nil {
   118  		buffer.Release()
   119  		return nil, err
   120  	}
   121  	_, err = r.cipher.Open(buffer.Index(0), r.nonce, buffer.Bytes(), nil)
   122  	if err != nil {
   123  		buffer.Release()
   124  		return nil, err
   125  	}
   126  	increaseNonce(r.nonce)
   127  	length := int(binary.BigEndian.Uint16(buffer.To(PacketLengthBufferSize)))
   128  	buffer.Release()
   129  	buffer = buf.NewSize(length + Overhead)
   130  	_, err = buffer.ReadFullFrom(r.reader, buffer.FreeLen())
   131  	if err != nil {
   132  		buffer.Release()
   133  		return nil, err
   134  	}
   135  	_, err = r.cipher.Open(buffer.Index(0), r.nonce, buffer.Bytes(), nil)
   136  	if err != nil {
   137  		buffer.Release()
   138  		return nil, err
   139  	}
   140  	increaseNonce(r.nonce)
   141  	buffer.Truncate(length)
   142  	return buffer, nil
   143  }