github.com/metacubex/sing-shadowsocks@v0.2.6/shadowaead/aead.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"encoding/binary"
     6  	"io"
     7  	"sync"
     8  
     9  	"github.com/sagernet/sing/common/buf"
    10  )
    11  
    // Protocol constants for the Shadowsocks AEAD stream format described at
    // https://shadowsocks.org/en/wiki/AEAD-Ciphers.html
    const (
    	// MaxPacketSize is 16 KiB minus one byte (0x3FFF), the payload-length cap
    	// the specification above places on a single chunk.
    	MaxPacketSize = 16*1024 - 1

    	// PacketLengthBufferSize is the size in bytes of the big-endian payload
    	// length that heads every chunk.
    	PacketLengthBufferSize = 2

    	// Overhead is the authentication-tag size shared by every supported AEAD:
    	//   crypto/cipher.gcmTagSize
    	//   golang.org/x/crypto/chacha20poly1305.Overhead
    	//   github.com/sina-ghaderi/poly1305.TagSize
    	//   github.com/ericlagergren/siv.TagSize
    	//   github.com/ericlagergren/aegis.TagSize128L
    	//   github.com/ericlagergren/aegis.TagSize256
    	//   github.com/Yawning/aez.aeadOverhead
    	//   github.com/oasisprotocol/deoxysii.TagSize
    	Overhead = 16
    )
    30  
    31  type Reader struct {
    32  	upstream io.Reader
    33  	cipher   cipher.AEAD
    34  	buffer   []byte
    35  	nonce    []byte
    36  	index    int
    37  	cached   int
    38  }
    39  
    40  func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reader {
    41  	return &Reader{
    42  		upstream: upstream,
    43  		cipher:   cipher,
    44  		buffer:   make([]byte, maxPacketSize+Overhead),
    45  		nonce:    make([]byte, cipher.NonceSize()),
    46  	}
    47  }
    48  
    49  func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce []byte) *Reader {
    50  	return &Reader{
    51  		upstream: upstream,
    52  		cipher:   cipher,
    53  		buffer:   buffer,
    54  		nonce:    nonce,
    55  	}
    56  }
    57  
    58  func (r *Reader) Upstream() any {
    59  	return r.upstream
    60  }
    61  
    62  func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
    63  	if r.cached > 0 {
    64  		writeN, writeErr := writer.Write(r.buffer[r.index : r.index+r.cached])
    65  		if writeErr != nil {
    66  			return int64(writeN), writeErr
    67  		}
    68  		n += int64(writeN)
    69  	}
    70  	for {
    71  		start := PacketLengthBufferSize + Overhead
    72  		_, err = io.ReadFull(r.upstream, r.buffer[:start])
    73  		if err != nil {
    74  			return
    75  		}
    76  		_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
    77  		if err != nil {
    78  			return
    79  		}
    80  		increaseNonce(r.nonce)
    81  		length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
    82  		end := length + Overhead
    83  		_, err = io.ReadFull(r.upstream, r.buffer[:end])
    84  		if err != nil {
    85  			return
    86  		}
    87  		_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
    88  		if err != nil {
    89  			return
    90  		}
    91  		increaseNonce(r.nonce)
    92  		writeN, writeErr := writer.Write(r.buffer[:length])
    93  		if writeErr != nil {
    94  			return int64(writeN), writeErr
    95  		}
    96  		n += int64(writeN)
    97  	}
    98  }
    99  
   100  func (r *Reader) readInternal() (err error) {
   101  	start := PacketLengthBufferSize + Overhead
   102  	_, err = io.ReadFull(r.upstream, r.buffer[:start])
   103  	if err != nil {
   104  		return err
   105  	}
   106  	_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
   107  	if err != nil {
   108  		return err
   109  	}
   110  	increaseNonce(r.nonce)
   111  	length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
   112  	end := length + Overhead
   113  	_, err = io.ReadFull(r.upstream, r.buffer[:end])
   114  	if err != nil {
   115  		return err
   116  	}
   117  	_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
   118  	if err != nil {
   119  		return err
   120  	}
   121  	increaseNonce(r.nonce)
   122  	r.cached = length
   123  	r.index = 0
   124  	return nil
   125  }
   126  
   127  func (r *Reader) ReadByte() (byte, error) {
   128  	if r.cached == 0 {
   129  		err := r.readInternal()
   130  		if err != nil {
   131  			return 0, err
   132  		}
   133  	}
   134  	index := r.index
   135  	r.index++
   136  	r.cached--
   137  	return r.buffer[index], nil
   138  }
   139  
   140  func (r *Reader) Read(b []byte) (n int, err error) {
   141  	if r.cached > 0 {
   142  		n = copy(b, r.buffer[r.index:r.index+r.cached])
   143  		r.cached -= n
   144  		r.index += n
   145  		return
   146  	}
   147  	start := PacketLengthBufferSize + Overhead
   148  	_, err = io.ReadFull(r.upstream, r.buffer[:start])
   149  	if err != nil {
   150  		return 0, err
   151  	}
   152  	_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
   153  	if err != nil {
   154  		return 0, err
   155  	}
   156  	increaseNonce(r.nonce)
   157  	length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
   158  	end := length + Overhead
   159  
   160  	if len(b) >= end {
   161  		data := b[:end]
   162  		_, err = io.ReadFull(r.upstream, data)
   163  		if err != nil {
   164  			return 0, err
   165  		}
   166  		_, err = r.cipher.Open(b[:0], r.nonce, data, nil)
   167  		if err != nil {
   168  			return 0, err
   169  		}
   170  		increaseNonce(r.nonce)
   171  		return length, nil
   172  	} else {
   173  		_, err = io.ReadFull(r.upstream, r.buffer[:end])
   174  		if err != nil {
   175  			return 0, err
   176  		}
   177  		_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
   178  		if err != nil {
   179  			return 0, err
   180  		}
   181  		increaseNonce(r.nonce)
   182  		n = copy(b, r.buffer[:length])
   183  		r.cached = length - n
   184  		r.index = n
   185  		return
   186  	}
   187  }
   188  
   189  func (r *Reader) Discard(n int) error {
   190  	for {
   191  		if r.cached >= n {
   192  			r.cached -= n
   193  			r.index += n
   194  			return nil
   195  		} else if r.cached > 0 {
   196  			n -= r.cached
   197  			r.cached = 0
   198  			r.index = 0
   199  		}
   200  		err := r.readInternal()
   201  		if err != nil {
   202  			return err
   203  		}
   204  	}
   205  }
   206  
   207  func (r *Reader) Buffer() *buf.Buffer {
   208  	buffer := buf.With(r.buffer)
   209  	buffer.Resize(r.index, r.cached)
   210  	return buffer
   211  }
   212  
   213  func (r *Reader) Cached() int {
   214  	return r.cached
   215  }
   216  
   217  func (r *Reader) CachedSlice() []byte {
   218  	return r.buffer[r.index : r.index+r.cached]
   219  }
   220  
   221  func (r *Reader) ReadWithLengthChunk(lengthChunk []byte) error {
   222  	_, err := r.cipher.Open(r.buffer[:0], r.nonce, lengthChunk, nil)
   223  	if err != nil {
   224  		return err
   225  	}
   226  	increaseNonce(r.nonce)
   227  	length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
   228  	end := length + Overhead
   229  	_, err = io.ReadFull(r.upstream, r.buffer[:end])
   230  	if err != nil {
   231  		return err
   232  	}
   233  	_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
   234  	if err != nil {
   235  		return err
   236  	}
   237  	increaseNonce(r.nonce)
   238  	r.cached = length
   239  	r.index = 0
   240  	return nil
   241  }
   242  
   243  func (r *Reader) ReadWithLength(length uint16) error {
   244  	end := int(length) + Overhead
   245  	_, err := io.ReadFull(r.upstream, r.buffer[:end])
   246  	if err != nil {
   247  		return err
   248  	}
   249  	_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
   250  	if err != nil {
   251  		return err
   252  	}
   253  	increaseNonce(r.nonce)
   254  	r.cached = int(length)
   255  	r.index = 0
   256  	return nil
   257  }
   258  
   259  func (r *Reader) ReadExternalChunk(chunk []byte) error {
   260  	bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil)
   261  	if err != nil {
   262  		return err
   263  	}
   264  	increaseNonce(r.nonce)
   265  	r.cached = len(bb)
   266  	r.index = 0
   267  	return nil
   268  }
   269  
   270  func (r *Reader) ReadChunk(buffer *buf.Buffer, chunk []byte) error {
   271  	bb, err := r.cipher.Open(buffer.Index(buffer.Len()), r.nonce, chunk, nil)
   272  	if err != nil {
   273  		return err
   274  	}
   275  	increaseNonce(r.nonce)
   276  	buffer.Extend(len(bb))
   277  	return nil
   278  }
   279  
   280  type Writer struct {
   281  	upstream      io.Writer
   282  	cipher        cipher.AEAD
   283  	maxPacketSize int
   284  	buffer        []byte
   285  	nonce         []byte
   286  	access        sync.Mutex
   287  }
   288  
   289  func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *Writer {
   290  	return &Writer{
   291  		upstream:      upstream,
   292  		cipher:        cipher,
   293  		buffer:        make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2),
   294  		nonce:         make([]byte, cipher.NonceSize()),
   295  		maxPacketSize: maxPacketSize,
   296  	}
   297  }
   298  
   299  func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buffer []byte, nonce []byte) *Writer {
   300  	return &Writer{
   301  		upstream:      upstream,
   302  		cipher:        cipher,
   303  		maxPacketSize: maxPacketSize,
   304  		buffer:        buffer,
   305  		nonce:         nonce,
   306  	}
   307  }
   308  
   309  func (w *Writer) Upstream() any {
   310  	return w.upstream
   311  }
   312  
   313  func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
   314  	for {
   315  		offset := Overhead + PacketLengthBufferSize
   316  		readN, readErr := r.Read(w.buffer[offset : offset+w.maxPacketSize])
   317  		if readErr != nil {
   318  			return 0, readErr
   319  		}
   320  		binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(readN))
   321  		w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil)
   322  		increaseNonce(w.nonce)
   323  		packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, w.buffer[offset:offset+readN], nil)
   324  		increaseNonce(w.nonce)
   325  		_, err = w.upstream.Write(w.buffer[:offset+len(packet)])
   326  		if err != nil {
   327  			return
   328  		}
   329  		n += int64(readN)
   330  	}
   331  }
   332  
   333  func (w *Writer) Write(p []byte) (n int, err error) {
   334  	if len(p) == 0 {
   335  		return
   336  	}
   337  
   338  	for pLen := len(p); pLen > 0; {
   339  		var data []byte
   340  		if pLen > w.maxPacketSize {
   341  			data = p[:w.maxPacketSize]
   342  			p = p[w.maxPacketSize:]
   343  			pLen -= w.maxPacketSize
   344  		} else {
   345  			data = p
   346  			pLen = 0
   347  		}
   348  		w.access.Lock()
   349  		binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data)))
   350  		w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil)
   351  		increaseNonce(w.nonce)
   352  		offset := Overhead + PacketLengthBufferSize
   353  		packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, data, nil)
   354  		increaseNonce(w.nonce)
   355  		w.access.Unlock()
   356  		_, err = w.upstream.Write(w.buffer[:offset+len(packet)])
   357  		if err != nil {
   358  			return
   359  		}
   360  		n += len(data)
   361  	}
   362  
   363  	return
   364  }
   365  
   366  func (w *Writer) WriteVectorised(buffers []*buf.Buffer) error {
   367  	defer buf.ReleaseMulti(buffers)
   368  	var index int
   369  	var err error
   370  	for _, buffer := range buffers {
   371  		pLen := buffer.Len()
   372  		if pLen > w.maxPacketSize {
   373  			_, err = w.Write(buffer.Bytes())
   374  			if err != nil {
   375  				return err
   376  			}
   377  		} else {
   378  			if cap(w.buffer) < index+PacketLengthBufferSize+pLen+2*Overhead {
   379  				_, err = w.upstream.Write(w.buffer[:index])
   380  				index = 0
   381  				if err != nil {
   382  					return err
   383  				}
   384  			}
   385  			w.access.Lock()
   386  			binary.BigEndian.PutUint16(w.buffer[index:index+PacketLengthBufferSize], uint16(pLen))
   387  			w.cipher.Seal(w.buffer[index:index], w.nonce, w.buffer[index:index+PacketLengthBufferSize], nil)
   388  			increaseNonce(w.nonce)
   389  			offset := index + Overhead + PacketLengthBufferSize
   390  			w.cipher.Seal(w.buffer[offset:offset], w.nonce, buffer.Bytes(), nil)
   391  			increaseNonce(w.nonce)
   392  			w.access.Unlock()
   393  			index = offset + pLen + Overhead
   394  		}
   395  	}
   396  	if index > 0 {
   397  		_, err = w.upstream.Write(w.buffer[:index])
   398  	}
   399  	return err
   400  }
   401  
   402  func (w *Writer) Buffer() *buf.Buffer {
   403  	return buf.With(w.buffer)
   404  }
   405  
   406  func (w *Writer) WriteChunk(buffer *buf.Buffer, chunk []byte) {
   407  	bb := w.cipher.Seal(buffer.Index(buffer.Len()), w.nonce, chunk, nil)
   408  	buffer.Extend(len(bb))
   409  	increaseNonce(w.nonce)
   410  }
   411  
   412  func (w *Writer) BufferedWriter(reversed int) *BufferedWriter {
   413  	return &BufferedWriter{
   414  		upstream: w,
   415  		reversed: reversed,
   416  		data:     w.buffer[PacketLengthBufferSize+Overhead : len(w.buffer)-Overhead],
   417  	}
   418  }
   419  
   420  type BufferedWriter struct {
   421  	upstream *Writer
   422  	data     []byte
   423  	reversed int
   424  	index    int
   425  }
   426  
   427  func (w *BufferedWriter) Write(p []byte) (n int, err error) {
   428  	for {
   429  		cachedN := copy(w.data[w.reversed+w.index:], p[n:])
   430  		w.index += cachedN
   431  		if cachedN == len(p[n:]) {
   432  			n += cachedN
   433  			return
   434  		}
   435  		err = w.Flush()
   436  		if err != nil {
   437  			return
   438  		}
   439  		n += cachedN
   440  	}
   441  }
   442  
   443  func (w *BufferedWriter) Flush() error {
   444  	if w.index == 0 {
   445  		if w.reversed > 0 {
   446  			_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed])
   447  			w.reversed = 0
   448  			return err
   449  		}
   450  		return nil
   451  	}
   452  	buffer := w.upstream.buffer[w.reversed:]
   453  	binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index))
   454  	w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil)
   455  	increaseNonce(w.upstream.nonce)
   456  	offset := Overhead + PacketLengthBufferSize
   457  	packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil)
   458  	increaseNonce(w.upstream.nonce)
   459  	_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)])
   460  	w.reversed = 0
   461  	w.index = 0
   462  	return err
   463  }
   464  
   // increaseNonce treats nonce as a little-endian unsigned counter and adds one
   // in place, carrying into higher bytes. An all-0xff nonce wraps to all zeros;
   // an empty nonce is left alone.
   func increaseNonce(nonce []byte) {
   	for i := 0; i < len(nonce); i++ {
   		nonce[i]++
   		if nonce[i] > 0 {
   			break
   		}
   	}
   }