github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/ss2022/stream.go (about)

     1  package ss2022
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/cipher"
     6  	"crypto/rand"
     7  	"encoding/binary"
     8  	"errors"
     9  	"io"
    10  
    11  	"github.com/database64128/shadowsocks-go/zerocopy"
    12  )
    13  
    14  const MaxPayloadSize = 0xFFFF
    15  
    16  // ShadowStreamHeadroom is the headroom required by an encrypted Shadowsocks stream.
    17  //
    18  // Front is the size of an encrypted length chunk.
    19  // Rear is the size of an AEAD tag.
    20  var ShadowStreamHeadroom = zerocopy.Headroom{
    21  	Front: 2 + 16,
    22  	Rear:  16,
    23  }
    24  
    25  // ShadowStreamReaderInfo contains information about a [ShadowStreamReader].
    26  var ShadowStreamReaderInfo = zerocopy.ReaderInfo{
    27  	Headroom:                    ShadowStreamHeadroom,
    28  	MinPayloadBufferSizePerRead: MaxPayloadSize,
    29  }
    30  
    31  // ShadowStreamWriterInfo contains information about a [ShadowStreamWriter].
    32  var ShadowStreamWriterInfo = zerocopy.WriterInfo{
    33  	Headroom:               ShadowStreamHeadroom,
    34  	MaxPayloadSizePerWrite: MaxPayloadSize,
    35  }
    36  
    37  var (
    38  	ErrZeroLengthChunk = errors.New("length in length chunk is zero")
    39  	ErrFirstRead       = errors.New("failed to read fixed-length header in one read call")
    40  	ErrRepeatedSalt    = errors.New("detected replay: repeated salt")
    41  )
    42  
    43  var ErrUnsafeStreamPrefixMismatch = errors.New("unsafe stream prefix mismatch")
    44  
    45  // ShadowStreamServerReadWriter implements Shadowsocks stream server.
    46  type ShadowStreamServerReadWriter struct {
    47  	*ShadowStreamReader
    48  	*ShadowStreamWriter
    49  	rawRW                      zerocopy.DirectReadWriteCloser
    50  	cipherConfig               UserCipherConfig
    51  	requestSalt                []byte
    52  	unsafeResponseStreamPrefix []byte
    53  }
    54  
    55  // WriteZeroCopy implements the Writer WriteZeroCopy method.
    56  func (rw *ShadowStreamServerReadWriter) WriteZeroCopy(b []byte, payloadStart, payloadLen int) (int, error) {
    57  	if rw.ShadowStreamWriter == nil { // first write
    58  		urspLen := len(rw.unsafeResponseStreamPrefix)
    59  		saltLen := len(rw.cipherConfig.PSK)
    60  		responseHeaderStart := urspLen + saltLen
    61  		responseHeaderEnd := responseHeaderStart + TCPRequestFixedLengthHeaderLength + saltLen
    62  		payloadBufStart := responseHeaderEnd + 16
    63  		bufferLen := payloadBufStart + payloadLen + 16
    64  		hb := make([]byte, bufferLen)
    65  		ursp := hb[:urspLen]
    66  		salt := hb[urspLen:responseHeaderStart]
    67  		responseHeader := hb[responseHeaderStart:responseHeaderEnd]
    68  
    69  		// Write unsafe response stream prefix.
    70  		copy(ursp, rw.unsafeResponseStreamPrefix)
    71  
    72  		// Random salt.
    73  		_, err := rand.Read(salt)
    74  		if err != nil {
    75  			return 0, err
    76  		}
    77  
    78  		// Write response header.
    79  		WriteTCPResponseHeader(responseHeader, rw.requestSalt, uint16(payloadLen))
    80  
    81  		// Create AEAD cipher.
    82  		shadowStreamCipher, err := rw.cipherConfig.ShadowStreamCipher(salt)
    83  		if err != nil {
    84  			return 0, err
    85  		}
    86  
    87  		// Create writer.
    88  		rw.ShadowStreamWriter = &ShadowStreamWriter{
    89  			writer: rw.rawRW,
    90  			ssc:    shadowStreamCipher,
    91  		}
    92  
    93  		// Seal response header.
    94  		shadowStreamCipher.EncryptInPlace(responseHeader)
    95  
    96  		// Seal payload.
    97  		dst := hb[payloadBufStart:]
    98  		plaintext := b[payloadStart : payloadStart+payloadLen]
    99  		shadowStreamCipher.EncryptTo(dst, plaintext)
   100  
   101  		// Write out.
   102  		_, err = rw.rawRW.Write(hb)
   103  		if err != nil {
   104  			return 0, err
   105  		}
   106  
   107  		return payloadLen, nil
   108  	}
   109  
   110  	return rw.ShadowStreamWriter.WriteZeroCopy(b, payloadStart, payloadLen)
   111  }
   112  
   113  // CloseRead implements the ReadWriter CloseRead method.
   114  func (rw *ShadowStreamServerReadWriter) CloseRead() error {
   115  	return rw.rawRW.CloseRead()
   116  }
   117  
   118  // CloseWrite implements the ReadWriter CloseWrite method.
   119  func (rw *ShadowStreamServerReadWriter) CloseWrite() error {
   120  	return rw.rawRW.CloseWrite()
   121  }
   122  
   123  // Close implements the ReadWriter Close method.
   124  func (rw *ShadowStreamServerReadWriter) Close() error {
   125  	return rw.rawRW.Close()
   126  }
   127  
   128  // ShadowStreamClientReadWriter implements Shadowsocks stream client.
   129  type ShadowStreamClientReadWriter struct {
   130  	*ShadowStreamReader
   131  	*ShadowStreamWriter
   132  	rawRW                      zerocopy.DirectReadWriteCloser
   133  	readOnceOrFull             func(io.Reader, []byte) (int, error)
   134  	cipherConfig               *ClientCipherConfig
   135  	requestSalt                []byte
   136  	unsafeResponseStreamPrefix []byte
   137  }
   138  
   139  // ReadZeroCopy implements the Reader ReadZeroCopy method.
   140  func (rw *ShadowStreamClientReadWriter) ReadZeroCopy(b []byte, payloadBufStart, payloadBufLen int) (int, error) {
   141  	if rw.ShadowStreamReader == nil { // first read
   142  		urspLen := len(rw.unsafeResponseStreamPrefix)
   143  		saltLen := len(rw.cipherConfig.PSK)
   144  		fixedLengthHeaderStart := urspLen + saltLen
   145  		bufferLen := fixedLengthHeaderStart + TCPRequestFixedLengthHeaderLength + saltLen + 16
   146  		hb := make([]byte, bufferLen)
   147  
   148  		// Read response header.
   149  		_, err := rw.readOnceOrFull(rw.rawRW, hb)
   150  		if err != nil {
   151  			return 0, err
   152  		}
   153  
   154  		// Check unsafe response stream prefix.
   155  		ursp := hb[:urspLen]
   156  		if !bytes.Equal(ursp, rw.unsafeResponseStreamPrefix) {
   157  			return 0, &HeaderError[[]byte]{ErrUnsafeStreamPrefixMismatch, rw.unsafeResponseStreamPrefix, ursp}
   158  		}
   159  
   160  		// Derive key and create cipher.
   161  		salt := hb[urspLen:fixedLengthHeaderStart]
   162  		ciphertext := hb[fixedLengthHeaderStart:]
   163  		shadowStreamCipher, err := rw.cipherConfig.ShadowStreamCipher(salt)
   164  		if err != nil {
   165  			return 0, err
   166  		}
   167  
   168  		// Create reader.
   169  		rw.ShadowStreamReader = &ShadowStreamReader{
   170  			reader: rw.rawRW,
   171  			ssc:    shadowStreamCipher,
   172  		}
   173  
   174  		// AEAD open.
   175  		plaintext, err := shadowStreamCipher.DecryptInPlace(ciphertext)
   176  		if err != nil {
   177  			return 0, err
   178  		}
   179  
   180  		// Parse response header.
   181  		n, err := ParseTCPResponseHeader(plaintext, rw.requestSalt)
   182  		if err != nil {
   183  			return 0, err
   184  		}
   185  
   186  		payloadBuf := b[payloadBufStart : payloadBufStart+n+16]
   187  
   188  		// Read payload chunk.
   189  		_, err = io.ReadFull(rw.rawRW, payloadBuf)
   190  		if err != nil {
   191  			return 0, err
   192  		}
   193  
   194  		// AEAD open.
   195  		_, err = shadowStreamCipher.DecryptInPlace(payloadBuf)
   196  		if err != nil {
   197  			return 0, err
   198  		}
   199  
   200  		return n, nil
   201  	}
   202  
   203  	return rw.ShadowStreamReader.ReadZeroCopy(b, payloadBufStart, payloadBufLen)
   204  }
   205  
   206  // CloseRead implements the ReadWriter CloseRead method.
   207  func (rw *ShadowStreamClientReadWriter) CloseRead() error {
   208  	return rw.rawRW.CloseRead()
   209  }
   210  
   211  // CloseWrite implements the ReadWriter CloseWrite method.
   212  func (rw *ShadowStreamClientReadWriter) CloseWrite() error {
   213  	return rw.rawRW.CloseWrite()
   214  }
   215  
   216  // Close implements the ReadWriter Close method.
   217  func (rw *ShadowStreamClientReadWriter) Close() error {
   218  	return rw.rawRW.Close()
   219  }
   220  
   221  // ShadowStreamWriter wraps an io.WriteCloser and feeds an encrypted Shadowsocks stream to it.
   222  //
   223  // Wire format:
   224  //
   225  //	+------------------------+---------------------------+
   226  //	| encrypted length chunk |  encrypted payload chunk  |
   227  //	+------------------------+---------------------------+
   228  //	|  2B length + 16B tag   | variable length + 16B tag |
   229  //	+------------------------+---------------------------+
   230  type ShadowStreamWriter struct {
   231  	writer io.WriteCloser
   232  	ssc    *ShadowStreamCipher
   233  }
   234  
   235  // WriterInfo implements the Writer WriterInfo method.
   236  func (w *ShadowStreamWriter) WriterInfo() zerocopy.WriterInfo {
   237  	return ShadowStreamWriterInfo
   238  }
   239  
   240  // WriteZeroCopy implements the Writer WriteZeroCopy method.
   241  func (w *ShadowStreamWriter) WriteZeroCopy(b []byte, payloadStart, payloadLen int) (payloadWritten int, err error) {
   242  	overhead := w.ssc.Overhead()
   243  	lengthStart := payloadStart - overhead - 2
   244  	lengthBuf := b[lengthStart : lengthStart+2]
   245  	payloadBuf := b[payloadStart : payloadStart+payloadLen]
   246  	payloadTagEnd := payloadStart + payloadLen + overhead
   247  	chunksBuf := b[lengthStart:payloadTagEnd]
   248  
   249  	// Write length.
   250  	binary.BigEndian.PutUint16(lengthBuf, uint16(payloadLen))
   251  
   252  	// Seal length chunk.
   253  	w.ssc.EncryptInPlace(lengthBuf)
   254  
   255  	// Seal payload chunk.
   256  	w.ssc.EncryptInPlace(payloadBuf)
   257  
   258  	// Write to wrapped writer.
   259  	_, err = w.writer.Write(chunksBuf)
   260  	if err != nil {
   261  		return
   262  	}
   263  	payloadWritten = payloadLen
   264  	return
   265  }
   266  
   267  // ShadowStreamReader wraps an io.ReadCloser and reads from it as an encrypted Shadowsocks stream.
   268  type ShadowStreamReader struct {
   269  	reader io.ReadCloser
   270  	ssc    *ShadowStreamCipher
   271  }
   272  
   273  // ReaderInfo implements the Reader ReaderInfo method.
   274  func (r *ShadowStreamReader) ReaderInfo() zerocopy.ReaderInfo {
   275  	return ShadowStreamReaderInfo
   276  }
   277  
   278  // ReadZeroCopy implements the Reader ReadZeroCopy method.
   279  func (r *ShadowStreamReader) ReadZeroCopy(b []byte, payloadBufStart, payloadBufLen int) (payloadLen int, err error) {
   280  	overhead := r.ssc.Overhead()
   281  	sealedLengthChunkStart := payloadBufStart - overhead - 2
   282  	sealedLengthChunkBuf := b[sealedLengthChunkStart:payloadBufStart]
   283  
   284  	// Read sealed length chunk.
   285  	_, err = io.ReadFull(r.reader, sealedLengthChunkBuf)
   286  	if err != nil {
   287  		return
   288  	}
   289  
   290  	// Open sealed length chunk.
   291  	_, err = r.ssc.DecryptInPlace(sealedLengthChunkBuf)
   292  	if err != nil {
   293  		return
   294  	}
   295  
   296  	// Validate length.
   297  	payloadLen = int(binary.BigEndian.Uint16(sealedLengthChunkBuf))
   298  	if payloadLen == 0 {
   299  		err = ErrZeroLengthChunk
   300  		return
   301  	}
   302  
   303  	// Read sealed payload chunk.
   304  	sealedPayloadChunkBuf := b[payloadBufStart : payloadBufStart+payloadLen+overhead]
   305  	_, err = io.ReadFull(r.reader, sealedPayloadChunkBuf)
   306  	if err != nil {
   307  		payloadLen = 0
   308  		return
   309  	}
   310  
   311  	// Open sealed payload chunk.
   312  	_, err = r.ssc.DecryptInPlace(sealedPayloadChunkBuf)
   313  	if err != nil {
   314  		payloadLen = 0
   315  	}
   316  
   317  	return
   318  }
   319  
   320  // ShadowStreamCipher wraps an AEAD cipher and provides methods that transparently increments
   321  // the nonce after each AEAD operation.
   322  type ShadowStreamCipher struct {
   323  	aead  cipher.AEAD
   324  	nonce []byte
   325  }
   326  
   327  // NewShadowStreamCipher wraps the given AEAD cipher into a new ShadowStreamCipher.
   328  func NewShadowStreamCipher(aead cipher.AEAD) *ShadowStreamCipher {
   329  	return &ShadowStreamCipher{
   330  		aead:  aead,
   331  		nonce: make([]byte, aead.NonceSize()),
   332  	}
   333  }
   334  
   335  // Overhead returns the tag size of the AEAD cipher.
   336  func (c *ShadowStreamCipher) Overhead() int {
   337  	return c.aead.Overhead()
   338  }
   339  
   340  // EncryptInPlace encrypts and authenticates plaintext in-place.
   341  func (c *ShadowStreamCipher) EncryptInPlace(plaintext []byte) (ciphertext []byte) {
   342  	ciphertext = c.aead.Seal(plaintext[:0], c.nonce, plaintext, nil)
   343  	increment(c.nonce)
   344  	return
   345  }
   346  
   347  // EncryptTo encrypts and authenticates the plaintext and saves the ciphertext to dst.
   348  func (c *ShadowStreamCipher) EncryptTo(dst, plaintext []byte) (ciphertext []byte) {
   349  	ciphertext = c.aead.Seal(dst[:0], c.nonce, plaintext, nil)
   350  	increment(c.nonce)
   351  	return
   352  }
   353  
   354  // DecryptInplace decrypts and authenticates ciphertext in-place.
   355  func (c *ShadowStreamCipher) DecryptInPlace(ciphertext []byte) (plaintext []byte, err error) {
   356  	plaintext, err = c.aead.Open(ciphertext[:0], c.nonce, ciphertext, nil)
   357  	if err == nil {
   358  		increment(c.nonce)
   359  	}
   360  	return
   361  }
   362  
   363  // DecryptTo decrypts and authenticates the ciphertext and saves the plaintext to dst.
   364  func (c *ShadowStreamCipher) DecryptTo(dst, ciphertext []byte) (plaintext []byte, err error) {
   365  	plaintext, err = c.aead.Open(dst[:0], c.nonce, ciphertext, nil)
   366  	if err == nil {
   367  		increment(c.nonce)
   368  	}
   369  	return
   370  }
   371  
   372  // increment increments a little-endian unsigned integer b.
   373  func increment(b []byte) {
   374  	for i := range b {
   375  		b[i]++
   376  		if b[i] != 0 {
   377  			return
   378  		}
   379  	}
   380  }
   381  
   382  // readOnceExpectFull reads exactly once from r into b and
   383  // returns an error if the read fails to fill up b.
   384  func readOnceExpectFull(r io.Reader, b []byte) (int, error) {
   385  	n, err := r.Read(b)
   386  	if err != nil {
   387  		return n, err
   388  	}
   389  	if n < len(b) {
   390  		return n, &HeaderError[int]{ErrFirstRead, len(b), n}
   391  	}
   392  	return n, nil
   393  }
   394  
   395  // readOnceOrFullFunc returns a function that either reads exactly once from r into b
   396  // or reads until b is full, depending on the value of allowSegmentedFixedLengthHeader.
   397  func readOnceOrFullFunc(allowSegmentedFixedLengthHeader bool) func(io.Reader, []byte) (int, error) {
   398  	if allowSegmentedFixedLengthHeader {
   399  		return io.ReadFull
   400  	}
   401  	return readOnceExpectFull
   402  }