github.com/xraypb/xray-core@v1.6.6/common/crypto/auth.go (about)

     1  package crypto
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"crypto/rand"
     6  	"io"
     7  
     8  	"github.com/xraypb/xray-core/common"
     9  	"github.com/xraypb/xray-core/common/buf"
    10  	"github.com/xraypb/xray-core/common/bytespool"
    11  	"github.com/xraypb/xray-core/common/protocol"
    12  )
    13  
    14  type BytesGenerator func() []byte
    15  
    16  func GenerateEmptyBytes() BytesGenerator {
    17  	var b [1]byte
    18  	return func() []byte {
    19  		return b[:0]
    20  	}
    21  }
    22  
    23  func GenerateStaticBytes(content []byte) BytesGenerator {
    24  	return func() []byte {
    25  		return content
    26  	}
    27  }
    28  
    29  func GenerateIncreasingNonce(nonce []byte) BytesGenerator {
    30  	c := append([]byte(nil), nonce...)
    31  	return func() []byte {
    32  		for i := range c {
    33  			c[i]++
    34  			if c[i] != 0 {
    35  				break
    36  			}
    37  		}
    38  		return c
    39  	}
    40  }
    41  
    42  func GenerateAEADNonceWithSize(nonceSize int) BytesGenerator {
    43  	c := make([]byte, nonceSize)
    44  	for i := 0; i < nonceSize; i++ {
    45  		c[i] = 0xFF
    46  	}
    47  	return GenerateIncreasingNonce(c)
    48  }
    49  
    50  type Authenticator interface {
    51  	NonceSize() int
    52  	Overhead() int
    53  	Open(dst, cipherText []byte) ([]byte, error)
    54  	Seal(dst, plainText []byte) ([]byte, error)
    55  }
    56  
    57  type AEADAuthenticator struct {
    58  	cipher.AEAD
    59  	NonceGenerator          BytesGenerator
    60  	AdditionalDataGenerator BytesGenerator
    61  }
    62  
    63  func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
    64  	iv := v.NonceGenerator()
    65  	if len(iv) != v.AEAD.NonceSize() {
    66  		return nil, newError("invalid AEAD nonce size: ", len(iv))
    67  	}
    68  
    69  	var additionalData []byte
    70  	if v.AdditionalDataGenerator != nil {
    71  		additionalData = v.AdditionalDataGenerator()
    72  	}
    73  	return v.AEAD.Open(dst, iv, cipherText, additionalData)
    74  }
    75  
    76  func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
    77  	iv := v.NonceGenerator()
    78  	if len(iv) != v.AEAD.NonceSize() {
    79  		return nil, newError("invalid AEAD nonce size: ", len(iv))
    80  	}
    81  
    82  	var additionalData []byte
    83  	if v.AdditionalDataGenerator != nil {
    84  		additionalData = v.AdditionalDataGenerator()
    85  	}
    86  	return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
    87  }
    88  
    89  type AuthenticationReader struct {
    90  	auth         Authenticator
    91  	reader       *buf.BufferedReader
    92  	sizeParser   ChunkSizeDecoder
    93  	sizeBytes    []byte
    94  	transferType protocol.TransferType
    95  	padding      PaddingLengthGenerator
    96  	size         uint16
    97  	paddingLen   uint16
    98  	hasSize      bool
    99  	done         bool
   100  }
   101  
   102  func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType, paddingLen PaddingLengthGenerator) *AuthenticationReader {
   103  	r := &AuthenticationReader{
   104  		auth:         auth,
   105  		sizeParser:   sizeParser,
   106  		transferType: transferType,
   107  		padding:      paddingLen,
   108  		sizeBytes:    make([]byte, sizeParser.SizeBytes()),
   109  	}
   110  	if breader, ok := reader.(*buf.BufferedReader); ok {
   111  		r.reader = breader
   112  	} else {
   113  		r.reader = &buf.BufferedReader{Reader: buf.NewReader(reader)}
   114  	}
   115  	return r
   116  }
   117  
   118  func (r *AuthenticationReader) readSize() (uint16, uint16, error) {
   119  	if r.hasSize {
   120  		r.hasSize = false
   121  		return r.size, r.paddingLen, nil
   122  	}
   123  	if _, err := io.ReadFull(r.reader, r.sizeBytes); err != nil {
   124  		return 0, 0, err
   125  	}
   126  	var padding uint16
   127  	if r.padding != nil {
   128  		padding = r.padding.NextPaddingLen()
   129  	}
   130  	size, err := r.sizeParser.Decode(r.sizeBytes)
   131  	return size, padding, err
   132  }
   133  
   134  var errSoft = newError("waiting for more data")
   135  
   136  func (r *AuthenticationReader) readBuffer(size int32, padding int32) (*buf.Buffer, error) {
   137  	b := buf.New()
   138  	if _, err := b.ReadFullFrom(r.reader, size); err != nil {
   139  		b.Release()
   140  		return nil, err
   141  	}
   142  	size -= padding
   143  	rb, err := r.auth.Open(b.BytesTo(0), b.BytesTo(size))
   144  	if err != nil {
   145  		b.Release()
   146  		return nil, err
   147  	}
   148  	b.Resize(0, int32(len(rb)))
   149  	return b, nil
   150  }
   151  
   152  func (r *AuthenticationReader) readInternal(soft bool, mb *buf.MultiBuffer) error {
   153  	if soft && r.reader.BufferedBytes() < r.sizeParser.SizeBytes() {
   154  		return errSoft
   155  	}
   156  
   157  	if r.done {
   158  		return io.EOF
   159  	}
   160  
   161  	size, padding, err := r.readSize()
   162  	if err != nil {
   163  		return err
   164  	}
   165  
   166  	if size == uint16(r.auth.Overhead())+padding {
   167  		r.done = true
   168  		return io.EOF
   169  	}
   170  
   171  	if soft && int32(size) > r.reader.BufferedBytes() {
   172  		r.size = size
   173  		r.paddingLen = padding
   174  		r.hasSize = true
   175  		return errSoft
   176  	}
   177  
   178  	if size <= buf.Size {
   179  		b, err := r.readBuffer(int32(size), int32(padding))
   180  		if err != nil {
   181  			return nil
   182  		}
   183  		*mb = append(*mb, b)
   184  		return nil
   185  	}
   186  
   187  	payload := bytespool.Alloc(int32(size))
   188  	defer bytespool.Free(payload)
   189  
   190  	if _, err := io.ReadFull(r.reader, payload[:size]); err != nil {
   191  		return err
   192  	}
   193  
   194  	size -= padding
   195  
   196  	rb, err := r.auth.Open(payload[:0], payload[:size])
   197  	if err != nil {
   198  		return err
   199  	}
   200  
   201  	*mb = buf.MergeBytes(*mb, rb)
   202  	return nil
   203  }
   204  
   205  func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   206  	const readSize = 16
   207  	mb := make(buf.MultiBuffer, 0, readSize)
   208  	if err := r.readInternal(false, &mb); err != nil {
   209  		buf.ReleaseMulti(mb)
   210  		return nil, err
   211  	}
   212  
   213  	for i := 1; i < readSize; i++ {
   214  		err := r.readInternal(true, &mb)
   215  		if err == errSoft || err == io.EOF {
   216  			break
   217  		}
   218  		if err != nil {
   219  			buf.ReleaseMulti(mb)
   220  			return nil, err
   221  		}
   222  	}
   223  
   224  	return mb, nil
   225  }
   226  
   227  type AuthenticationWriter struct {
   228  	auth         Authenticator
   229  	writer       buf.Writer
   230  	sizeParser   ChunkSizeEncoder
   231  	transferType protocol.TransferType
   232  	padding      PaddingLengthGenerator
   233  }
   234  
   235  func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType, padding PaddingLengthGenerator) *AuthenticationWriter {
   236  	w := &AuthenticationWriter{
   237  		auth:         auth,
   238  		writer:       buf.NewWriter(writer),
   239  		sizeParser:   sizeParser,
   240  		transferType: transferType,
   241  	}
   242  	if padding != nil {
   243  		w.padding = padding
   244  	}
   245  	return w
   246  }
   247  
   248  func (w *AuthenticationWriter) seal(b []byte) (*buf.Buffer, error) {
   249  	encryptedSize := int32(len(b) + w.auth.Overhead())
   250  	var paddingSize int32
   251  	if w.padding != nil {
   252  		paddingSize = int32(w.padding.NextPaddingLen())
   253  	}
   254  
   255  	sizeBytes := w.sizeParser.SizeBytes()
   256  	totalSize := sizeBytes + encryptedSize + paddingSize
   257  	if totalSize > buf.Size {
   258  		return nil, newError("size too large: ", totalSize)
   259  	}
   260  
   261  	eb := buf.New()
   262  	w.sizeParser.Encode(uint16(encryptedSize+paddingSize), eb.Extend(sizeBytes))
   263  	if _, err := w.auth.Seal(eb.Extend(encryptedSize)[:0], b); err != nil {
   264  		eb.Release()
   265  		return nil, err
   266  	}
   267  	if paddingSize > 0 {
   268  		// These paddings will send in clear text.
   269  		// To avoid leakage of PRNG internal state, a cryptographically secure PRNG should be used.
   270  		paddingBytes := eb.Extend(paddingSize)
   271  		common.Must2(rand.Read(paddingBytes))
   272  	}
   273  
   274  	return eb, nil
   275  }
   276  
   277  func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
   278  	defer buf.ReleaseMulti(mb)
   279  
   280  	var maxPadding int32
   281  	if w.padding != nil {
   282  		maxPadding = int32(w.padding.MaxPaddingLen())
   283  	}
   284  
   285  	payloadSize := buf.Size - int32(w.auth.Overhead()) - w.sizeParser.SizeBytes() - maxPadding
   286  	mb2Write := make(buf.MultiBuffer, 0, len(mb)+10)
   287  
   288  	temp := buf.New()
   289  	defer temp.Release()
   290  
   291  	rawBytes := temp.Extend(payloadSize)
   292  
   293  	for {
   294  		nb, nBytes := buf.SplitBytes(mb, rawBytes)
   295  		mb = nb
   296  
   297  		eb, err := w.seal(rawBytes[:nBytes])
   298  		if err != nil {
   299  			buf.ReleaseMulti(mb2Write)
   300  			return err
   301  		}
   302  		mb2Write = append(mb2Write, eb)
   303  		if mb.IsEmpty() {
   304  			break
   305  		}
   306  	}
   307  
   308  	return w.writer.WriteMultiBuffer(mb2Write)
   309  }
   310  
   311  func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error {
   312  	defer buf.ReleaseMulti(mb)
   313  
   314  	mb2Write := make(buf.MultiBuffer, 0, len(mb)+1)
   315  
   316  	for _, b := range mb {
   317  		if b.IsEmpty() {
   318  			continue
   319  		}
   320  
   321  		eb, err := w.seal(b.Bytes())
   322  		if err != nil {
   323  			continue
   324  		}
   325  
   326  		mb2Write = append(mb2Write, eb)
   327  	}
   328  
   329  	if mb2Write.IsEmpty() {
   330  		return nil
   331  	}
   332  
   333  	return w.writer.WriteMultiBuffer(mb2Write)
   334  }
   335  
   336  // WriteMultiBuffer implements buf.Writer.
   337  func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   338  	if mb.IsEmpty() {
   339  		eb, err := w.seal([]byte{})
   340  		common.Must(err)
   341  		return w.writer.WriteMultiBuffer(buf.MultiBuffer{eb})
   342  	}
   343  
   344  	if w.transferType == protocol.TransferTypeStream {
   345  		return w.writeStream(mb)
   346  	}
   347  
   348  	return w.writePacket(mb)
   349  }