github.com/sagernet/sing-shadowsocks2@v0.2.0/internal/shadowio/reader.go (about)

     1  package shadowio
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"encoding/binary"
     6  	"io"
     7  
     8  	"github.com/sagernet/sing/common/buf"
     9  	N "github.com/sagernet/sing/common/network"
    10  )
    11  
    12  const PacketLengthBufferSize = 2
    13  
    14  const (
    15  	// Overhead
    16  	// crypto/cipher.gcmTagSize
    17  	// golang.org/x/crypto/chacha20poly1305.Overhead
    18  	Overhead = 16
    19  )
    20  
    21  var (
    22  	_ N.ExtendedReader = (*Reader)(nil)
    23  	_ N.ReadWaiter     = (*Reader)(nil)
    24  )
    25  
    26  type Reader struct {
    27  	reader          io.Reader
    28  	cipher          cipher.AEAD
    29  	nonce           []byte
    30  	cache           *buf.Buffer
    31  	readWaitOptions N.ReadWaitOptions
    32  }
    33  
    34  func NewReader(upstream io.Reader, cipher cipher.AEAD) *Reader {
    35  	return &Reader{
    36  		reader: upstream,
    37  		cipher: cipher,
    38  		nonce:  make([]byte, cipher.NonceSize()),
    39  	}
    40  }
    41  
    42  func (r *Reader) ReadFixedBuffer(pLen int) (*buf.Buffer, error) {
    43  	buffer := buf.NewSize(pLen + Overhead)
    44  	_, err := buffer.ReadFullFrom(r.reader, buffer.FreeLen())
    45  	if err != nil {
    46  		buffer.Release()
    47  		return nil, err
    48  	}
    49  	err = r.Decrypt(buffer.Index(0), buffer.Bytes())
    50  	if err != nil {
    51  		buffer.Release()
    52  		return nil, err
    53  	}
    54  	buffer.Truncate(pLen)
    55  	r.cache = buffer
    56  	return buffer, nil
    57  }
    58  
    59  func (r *Reader) Decrypt(destination []byte, source []byte) error {
    60  	_, err := r.cipher.Open(destination[:0], r.nonce, source, nil)
    61  	if err != nil {
    62  		return err
    63  	}
    64  	increaseNonce(r.nonce)
    65  	return nil
    66  }
    67  
    68  func (r *Reader) Read(p []byte) (n int, err error) {
    69  	for {
    70  		if r.cache != nil {
    71  			if r.cache.IsEmpty() {
    72  				r.cache.Release()
    73  				r.cache = nil
    74  			} else {
    75  				n = copy(p, r.cache.Bytes())
    76  				if n > 0 {
    77  					r.cache.Advance(n)
    78  					return
    79  				}
    80  			}
    81  		}
    82  		r.cache, err = r.readBuffer()
    83  		if err != nil {
    84  			return
    85  		}
    86  	}
    87  }
    88  
    89  func (r *Reader) ReadBuffer(buffer *buf.Buffer) error {
    90  	var err error
    91  	for {
    92  		if r.cache != nil {
    93  			if r.cache.IsEmpty() {
    94  				r.cache.Release()
    95  				r.cache = nil
    96  			} else {
    97  				n := copy(buffer.FreeBytes(), r.cache.Bytes())
    98  				if n > 0 {
    99  					buffer.Truncate(n)
   100  					r.cache.Advance(n)
   101  					return nil
   102  				}
   103  			}
   104  		}
   105  		r.cache, err = r.readBuffer()
   106  		if err != nil {
   107  			return err
   108  		}
   109  	}
   110  }
   111  
   112  func (r *Reader) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
   113  	r.readWaitOptions = options
   114  	return options.NeedHeadroom()
   115  }
   116  
   117  func (r *Reader) WaitReadBuffer() (buffer *buf.Buffer, err error) {
   118  	if r.readWaitOptions.NeedHeadroom() {
   119  		for {
   120  			if r.cache != nil {
   121  				if r.cache.IsEmpty() {
   122  					r.cache.Release()
   123  					r.cache = nil
   124  				} else {
   125  					buffer = r.readWaitOptions.NewBuffer()
   126  					var n int
   127  					n, err = buffer.Write(r.cache.Bytes())
   128  					if err != nil {
   129  						buffer.Release()
   130  						return
   131  					}
   132  					buffer.Truncate(n)
   133  					r.cache.Advance(n)
   134  					r.readWaitOptions.PostReturn(buffer)
   135  					return
   136  				}
   137  			}
   138  			r.cache, err = r.readBuffer()
   139  			if err != nil {
   140  				return
   141  			}
   142  		}
   143  	} else {
   144  		cache := r.cache
   145  		if cache != nil {
   146  			r.cache = nil
   147  			return cache, nil
   148  		}
   149  		return r.readBuffer()
   150  	}
   151  }
   152  
   153  func (r *Reader) readBuffer() (*buf.Buffer, error) {
   154  	buffer := buf.NewSize(PacketLengthBufferSize + Overhead)
   155  	_, err := buffer.ReadFullFrom(r.reader, buffer.FreeLen())
   156  	if err != nil {
   157  		buffer.Release()
   158  		return nil, err
   159  	}
   160  	_, err = r.cipher.Open(buffer.Index(0), r.nonce, buffer.Bytes(), nil)
   161  	if err != nil {
   162  		buffer.Release()
   163  		return nil, err
   164  	}
   165  	increaseNonce(r.nonce)
   166  	length := int(binary.BigEndian.Uint16(buffer.To(PacketLengthBufferSize)))
   167  	buffer.Release()
   168  	buffer = buf.NewSize(length + Overhead)
   169  	_, err = buffer.ReadFullFrom(r.reader, buffer.FreeLen())
   170  	if err != nil {
   171  		buffer.Release()
   172  		return nil, err
   173  	}
   174  	_, err = r.cipher.Open(buffer.Index(0), r.nonce, buffer.Bytes(), nil)
   175  	if err != nil {
   176  		buffer.Release()
   177  		return nil, err
   178  	}
   179  	increaseNonce(r.nonce)
   180  	buffer.Truncate(length)
   181  	return buffer, nil
   182  }