github.com/sagernet/sing-shadowsocks@v0.2.6/shadowaead/aead.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"encoding/binary"
     6  	"io"
     7  	"sync"
     8  
     9  	"github.com/sagernet/sing/common/buf"
    10  )
    11  
    12  // https://shadowsocks.org/en/wiki/AEAD-Ciphers.html
    13  const (
    14  	MaxPacketSize          = 16*1024 - 1
    15  	PacketLengthBufferSize = 2
    16  )
    17  
    18  const (
    19  	// Overhead
    20  	// crypto/cipher.gcmTagSize
    21  	// golang.org/x/crypto/chacha20poly1305.Overhead
    22  	Overhead = 16
    23  )
    24  
    25  type Reader struct {
    26  	upstream io.Reader
    27  	cipher   cipher.AEAD
    28  	buffer   []byte
    29  	nonce    []byte
    30  	index    int
    31  	cached   int
    32  }
    33  
    34  func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reader {
    35  	return &Reader{
    36  		upstream: upstream,
    37  		cipher:   cipher,
    38  		buffer:   make([]byte, maxPacketSize+Overhead),
    39  		nonce:    make([]byte, cipher.NonceSize()),
    40  	}
    41  }
    42  
    43  func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce []byte) *Reader {
    44  	return &Reader{
    45  		upstream: upstream,
    46  		cipher:   cipher,
    47  		buffer:   buffer,
    48  		nonce:    nonce,
    49  	}
    50  }
    51  
    52  func (r *Reader) Upstream() any {
    53  	return r.upstream
    54  }
    55  
    56  func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) {
    57  	if r.cached > 0 {
    58  		writeN, writeErr := writer.Write(r.buffer[r.index : r.index+r.cached])
    59  		if writeErr != nil {
    60  			return int64(writeN), writeErr
    61  		}
    62  		n += int64(writeN)
    63  	}
    64  	for {
    65  		start := PacketLengthBufferSize + Overhead
    66  		_, err = io.ReadFull(r.upstream, r.buffer[:start])
    67  		if err != nil {
    68  			return
    69  		}
    70  		_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
    71  		if err != nil {
    72  			return
    73  		}
    74  		increaseNonce(r.nonce)
    75  		length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
    76  		end := length + Overhead
    77  		_, err = io.ReadFull(r.upstream, r.buffer[:end])
    78  		if err != nil {
    79  			return
    80  		}
    81  		_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
    82  		if err != nil {
    83  			return
    84  		}
    85  		increaseNonce(r.nonce)
    86  		writeN, writeErr := writer.Write(r.buffer[:length])
    87  		if writeErr != nil {
    88  			return int64(writeN), writeErr
    89  		}
    90  		n += int64(writeN)
    91  	}
    92  }
    93  
    94  func (r *Reader) readInternal() (err error) {
    95  	start := PacketLengthBufferSize + Overhead
    96  	_, err = io.ReadFull(r.upstream, r.buffer[:start])
    97  	if err != nil {
    98  		return err
    99  	}
   100  	_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
   101  	if err != nil {
   102  		return err
   103  	}
   104  	increaseNonce(r.nonce)
   105  	length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
   106  	end := length + Overhead
   107  	_, err = io.ReadFull(r.upstream, r.buffer[:end])
   108  	if err != nil {
   109  		return err
   110  	}
   111  	_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
   112  	if err != nil {
   113  		return err
   114  	}
   115  	increaseNonce(r.nonce)
   116  	r.cached = length
   117  	r.index = 0
   118  	return nil
   119  }
   120  
   121  func (r *Reader) ReadByte() (byte, error) {
   122  	if r.cached == 0 {
   123  		err := r.readInternal()
   124  		if err != nil {
   125  			return 0, err
   126  		}
   127  	}
   128  	index := r.index
   129  	r.index++
   130  	r.cached--
   131  	return r.buffer[index], nil
   132  }
   133  
   134  func (r *Reader) Read(b []byte) (n int, err error) {
   135  	if r.cached > 0 {
   136  		n = copy(b, r.buffer[r.index:r.index+r.cached])
   137  		r.cached -= n
   138  		r.index += n
   139  		return
   140  	}
   141  	start := PacketLengthBufferSize + Overhead
   142  	_, err = io.ReadFull(r.upstream, r.buffer[:start])
   143  	if err != nil {
   144  		return 0, err
   145  	}
   146  	_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil)
   147  	if err != nil {
   148  		return 0, err
   149  	}
   150  	increaseNonce(r.nonce)
   151  	length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
   152  	end := length + Overhead
   153  
   154  	if len(b) >= end {
   155  		data := b[:end]
   156  		_, err = io.ReadFull(r.upstream, data)
   157  		if err != nil {
   158  			return 0, err
   159  		}
   160  		_, err = r.cipher.Open(b[:0], r.nonce, data, nil)
   161  		if err != nil {
   162  			return 0, err
   163  		}
   164  		increaseNonce(r.nonce)
   165  		return length, nil
   166  	} else {
   167  		_, err = io.ReadFull(r.upstream, r.buffer[:end])
   168  		if err != nil {
   169  			return 0, err
   170  		}
   171  		_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
   172  		if err != nil {
   173  			return 0, err
   174  		}
   175  		increaseNonce(r.nonce)
   176  		n = copy(b, r.buffer[:length])
   177  		r.cached = length - n
   178  		r.index = n
   179  		return
   180  	}
   181  }
   182  
   183  func (r *Reader) Discard(n int) error {
   184  	for {
   185  		if r.cached >= n {
   186  			r.cached -= n
   187  			r.index += n
   188  			return nil
   189  		} else if r.cached > 0 {
   190  			n -= r.cached
   191  			r.cached = 0
   192  			r.index = 0
   193  		}
   194  		err := r.readInternal()
   195  		if err != nil {
   196  			return err
   197  		}
   198  	}
   199  }
   200  
   201  func (r *Reader) Buffer() *buf.Buffer {
   202  	buffer := buf.With(r.buffer)
   203  	buffer.Resize(r.index, r.cached)
   204  	return buffer
   205  }
   206  
   207  func (r *Reader) Cached() int {
   208  	return r.cached
   209  }
   210  
   211  func (r *Reader) CachedSlice() []byte {
   212  	return r.buffer[r.index : r.index+r.cached]
   213  }
   214  
   215  func (r *Reader) ReadWithLengthChunk(lengthChunk []byte) error {
   216  	_, err := r.cipher.Open(r.buffer[:0], r.nonce, lengthChunk, nil)
   217  	if err != nil {
   218  		return err
   219  	}
   220  	increaseNonce(r.nonce)
   221  	length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize]))
   222  	end := length + Overhead
   223  	_, err = io.ReadFull(r.upstream, r.buffer[:end])
   224  	if err != nil {
   225  		return err
   226  	}
   227  	_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
   228  	if err != nil {
   229  		return err
   230  	}
   231  	increaseNonce(r.nonce)
   232  	r.cached = length
   233  	r.index = 0
   234  	return nil
   235  }
   236  
   237  func (r *Reader) ReadWithLength(length uint16) error {
   238  	end := int(length) + Overhead
   239  	_, err := io.ReadFull(r.upstream, r.buffer[:end])
   240  	if err != nil {
   241  		return err
   242  	}
   243  	_, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil)
   244  	if err != nil {
   245  		return err
   246  	}
   247  	increaseNonce(r.nonce)
   248  	r.cached = int(length)
   249  	r.index = 0
   250  	return nil
   251  }
   252  
   253  func (r *Reader) ReadExternalChunk(chunk []byte) error {
   254  	bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil)
   255  	if err != nil {
   256  		return err
   257  	}
   258  	increaseNonce(r.nonce)
   259  	r.cached = len(bb)
   260  	r.index = 0
   261  	return nil
   262  }
   263  
   264  func (r *Reader) ReadChunk(buffer *buf.Buffer, chunk []byte) error {
   265  	bb, err := r.cipher.Open(buffer.Index(buffer.Len()), r.nonce, chunk, nil)
   266  	if err != nil {
   267  		return err
   268  	}
   269  	increaseNonce(r.nonce)
   270  	buffer.Extend(len(bb))
   271  	return nil
   272  }
   273  
   274  type Writer struct {
   275  	upstream      io.Writer
   276  	cipher        cipher.AEAD
   277  	maxPacketSize int
   278  	buffer        []byte
   279  	nonce         []byte
   280  	access        sync.Mutex
   281  }
   282  
   283  func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *Writer {
   284  	return &Writer{
   285  		upstream:      upstream,
   286  		cipher:        cipher,
   287  		buffer:        make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2),
   288  		nonce:         make([]byte, cipher.NonceSize()),
   289  		maxPacketSize: maxPacketSize,
   290  	}
   291  }
   292  
   293  func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buffer []byte, nonce []byte) *Writer {
   294  	return &Writer{
   295  		upstream:      upstream,
   296  		cipher:        cipher,
   297  		maxPacketSize: maxPacketSize,
   298  		buffer:        buffer,
   299  		nonce:         nonce,
   300  	}
   301  }
   302  
   303  func (w *Writer) Upstream() any {
   304  	return w.upstream
   305  }
   306  
   307  func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
   308  	for {
   309  		offset := Overhead + PacketLengthBufferSize
   310  		readN, readErr := r.Read(w.buffer[offset : offset+w.maxPacketSize])
   311  		if readErr != nil {
   312  			return 0, readErr
   313  		}
   314  		binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(readN))
   315  		w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil)
   316  		increaseNonce(w.nonce)
   317  		packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, w.buffer[offset:offset+readN], nil)
   318  		increaseNonce(w.nonce)
   319  		_, err = w.upstream.Write(w.buffer[:offset+len(packet)])
   320  		if err != nil {
   321  			return
   322  		}
   323  		n += int64(readN)
   324  	}
   325  }
   326  
   327  func (w *Writer) Write(p []byte) (n int, err error) {
   328  	if len(p) == 0 {
   329  		return
   330  	}
   331  
   332  	for pLen := len(p); pLen > 0; {
   333  		var data []byte
   334  		if pLen > w.maxPacketSize {
   335  			data = p[:w.maxPacketSize]
   336  			p = p[w.maxPacketSize:]
   337  			pLen -= w.maxPacketSize
   338  		} else {
   339  			data = p
   340  			pLen = 0
   341  		}
   342  		w.access.Lock()
   343  		binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data)))
   344  		w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil)
   345  		increaseNonce(w.nonce)
   346  		offset := Overhead + PacketLengthBufferSize
   347  		packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, data, nil)
   348  		increaseNonce(w.nonce)
   349  		w.access.Unlock()
   350  		_, err = w.upstream.Write(w.buffer[:offset+len(packet)])
   351  		if err != nil {
   352  			return
   353  		}
   354  		n += len(data)
   355  	}
   356  
   357  	return
   358  }
   359  
   360  func (w *Writer) WriteVectorised(buffers []*buf.Buffer) error {
   361  	defer buf.ReleaseMulti(buffers)
   362  	var index int
   363  	var err error
   364  	for _, buffer := range buffers {
   365  		pLen := buffer.Len()
   366  		if pLen > w.maxPacketSize {
   367  			_, err = w.Write(buffer.Bytes())
   368  			if err != nil {
   369  				return err
   370  			}
   371  		} else {
   372  			if cap(w.buffer) < index+PacketLengthBufferSize+pLen+2*Overhead {
   373  				_, err = w.upstream.Write(w.buffer[:index])
   374  				index = 0
   375  				if err != nil {
   376  					return err
   377  				}
   378  			}
   379  			w.access.Lock()
   380  			binary.BigEndian.PutUint16(w.buffer[index:index+PacketLengthBufferSize], uint16(pLen))
   381  			w.cipher.Seal(w.buffer[index:index], w.nonce, w.buffer[index:index+PacketLengthBufferSize], nil)
   382  			increaseNonce(w.nonce)
   383  			offset := index + Overhead + PacketLengthBufferSize
   384  			w.cipher.Seal(w.buffer[offset:offset], w.nonce, buffer.Bytes(), nil)
   385  			increaseNonce(w.nonce)
   386  			w.access.Unlock()
   387  			index = offset + pLen + Overhead
   388  		}
   389  	}
   390  	if index > 0 {
   391  		_, err = w.upstream.Write(w.buffer[:index])
   392  	}
   393  	return err
   394  }
   395  
   396  func (w *Writer) Buffer() *buf.Buffer {
   397  	return buf.With(w.buffer)
   398  }
   399  
   400  func (w *Writer) WriteChunk(buffer *buf.Buffer, chunk []byte) {
   401  	bb := w.cipher.Seal(buffer.Index(buffer.Len()), w.nonce, chunk, nil)
   402  	buffer.Extend(len(bb))
   403  	increaseNonce(w.nonce)
   404  }
   405  
   406  func (w *Writer) BufferedWriter(reversed int) *BufferedWriter {
   407  	return &BufferedWriter{
   408  		upstream: w,
   409  		reversed: reversed,
   410  		data:     w.buffer[PacketLengthBufferSize+Overhead : len(w.buffer)-Overhead],
   411  	}
   412  }
   413  
   414  type BufferedWriter struct {
   415  	upstream *Writer
   416  	data     []byte
   417  	reversed int
   418  	index    int
   419  }
   420  
   421  func (w *BufferedWriter) Write(p []byte) (n int, err error) {
   422  	for {
   423  		cachedN := copy(w.data[w.reversed+w.index:], p[n:])
   424  		w.index += cachedN
   425  		if cachedN == len(p[n:]) {
   426  			n += cachedN
   427  			return
   428  		}
   429  		err = w.Flush()
   430  		if err != nil {
   431  			return
   432  		}
   433  		n += cachedN
   434  	}
   435  }
   436  
   437  func (w *BufferedWriter) Flush() error {
   438  	if w.index == 0 {
   439  		if w.reversed > 0 {
   440  			_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed])
   441  			w.reversed = 0
   442  			return err
   443  		}
   444  		return nil
   445  	}
   446  	buffer := w.upstream.buffer[w.reversed:]
   447  	binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index))
   448  	w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil)
   449  	increaseNonce(w.upstream.nonce)
   450  	offset := Overhead + PacketLengthBufferSize
   451  	packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil)
   452  	increaseNonce(w.upstream.nonce)
   453  	_, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)])
   454  	w.reversed = 0
   455  	w.index = 0
   456  	return err
   457  }
   458  
   459  func increaseNonce(nonce []byte) {
   460  	for i := range nonce {
   461  		nonce[i]++
   462  		if nonce[i] != 0 {
   463  			return
   464  		}
   465  	}
   466  }