google.golang.org/grpc@v1.74.2/credentials/alts/internal/conn/record.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package conn contains an implementation of a secure channel created by gRPC
    20  // handshakers.
    21  package conn
    22  
    23  import (
    24  	"encoding/binary"
    25  	"fmt"
    26  	"math"
    27  	"net"
    28  
    29  	core "google.golang.org/grpc/credentials/alts/internal"
    30  )
    31  
    32  // ALTSRecordCrypto is the interface for gRPC ALTS record protocol.
    33  type ALTSRecordCrypto interface {
    34  	// Encrypt encrypts the plaintext, computes the tag (if any) of dst and
    35  	// plaintext, and appends the result to dst, returning the updated slice.
    36  	// dst and plaintext may fully overlap or not at all.
    37  	Encrypt(dst, plaintext []byte) ([]byte, error)
    38  	// EncryptionOverhead returns the tag size (if any) in bytes.
    39  	EncryptionOverhead() int
    40  	// Decrypt decrypts ciphertext and verifies the tag (if any). If successful,
    41  	// this function appends the resulting plaintext to dst, returning the
    42  	// updated slice. dst and ciphertext may alias exactly or not at all. To
    43  	// reuse ciphertext's storage for the decrypted output, use ciphertext[:0]
    44  	// as dst. Even if the function fails, the contents of dst, up to its
    45  	// capacity, may be overwritten.
    46  	Decrypt(dst, ciphertext []byte) ([]byte, error)
    47  }
    48  
    49  // ALTSRecordFunc is a function type for factory functions that create
    50  // ALTSRecordCrypto instances.
    51  type ALTSRecordFunc func(s core.Side, keyData []byte) (ALTSRecordCrypto, error)
    52  
    53  const (
    54  	// MsgLenFieldSize is the byte size of the frame length field of a
    55  	// framed message.
    56  	MsgLenFieldSize = 4
    57  	// The byte size of the message type field of a framed message.
    58  	msgTypeFieldSize = 4
    59  	// The bytes size limit for a ALTS record message.
    60  	altsRecordLengthLimit = 1024 * 1024 // 1 MiB
    61  	// The default bytes size of a ALTS record message.
    62  	altsRecordDefaultLength = 4 * 1024 // 4KiB
    63  	// Message type value included in ALTS record framing.
    64  	altsRecordMsgType = uint32(0x06)
    65  	// The initial write buffer size.
    66  	altsWriteBufferInitialSize = 32 * 1024 // 32KiB
    67  	// The maximum write buffer size. This *must* be multiple of
    68  	// altsRecordDefaultLength.
    69  	altsWriteBufferMaxSize = 512 * 1024 // 512KiB
    70  	// The initial buffer used to read from the network.
    71  	altsReadBufferInitialSize = 32 * 1024 // 32KiB
    72  )
    73  
    74  var (
    75  	protocols = make(map[string]ALTSRecordFunc)
    76  )
    77  
    78  // RegisterProtocol register a ALTS record encryption protocol.
    79  func RegisterProtocol(protocol string, f ALTSRecordFunc) error {
    80  	if _, ok := protocols[protocol]; ok {
    81  		return fmt.Errorf("protocol %v is already registered", protocol)
    82  	}
    83  	protocols[protocol] = f
    84  	return nil
    85  }
    86  
    87  // conn represents a secured connection. It implements the net.Conn interface.
    88  type conn struct {
    89  	net.Conn
    90  	crypto ALTSRecordCrypto
    91  	// buf holds data that has been read from the connection and decrypted,
    92  	// but has not yet been returned by Read. It is a sub-slice of protected.
    93  	buf                []byte
    94  	payloadLengthLimit int
    95  	// protected holds data read from the network but have not yet been
    96  	// decrypted. This data might not compose a complete frame.
    97  	protected []byte
    98  	// writeBuf is a buffer used to contain encrypted frames before being
    99  	// written to the network.
   100  	writeBuf []byte
   101  	// nextFrame stores the next frame (in protected buffer) info.
   102  	nextFrame []byte
   103  	// overhead is the calculated overhead of each frame.
   104  	overhead int
   105  }
   106  
   107  // NewConn creates a new secure channel instance given the other party role and
   108  // handshaking result.
   109  func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, protected []byte) (net.Conn, error) {
   110  	newCrypto := protocols[recordProtocol]
   111  	if newCrypto == nil {
   112  		return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol)
   113  	}
   114  	crypto, err := newCrypto(side, key)
   115  	if err != nil {
   116  		return nil, fmt.Errorf("protocol %q: %v", recordProtocol, err)
   117  	}
   118  	overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
   119  	payloadLengthLimit := altsRecordDefaultLength - overhead
   120  	// We pre-allocate protected to be of size 32KB during initialization.
   121  	// We increase the size of the buffer by the required amount if it can't
   122  	// hold a complete encrypted record.
   123  	protectedBuf := make([]byte, max(altsReadBufferInitialSize, len(protected)))
   124  	// Copy additional data from hanshaker service.
   125  	copy(protectedBuf, protected)
   126  	protectedBuf = protectedBuf[:len(protected)]
   127  
   128  	altsConn := &conn{
   129  		Conn:               c,
   130  		crypto:             crypto,
   131  		payloadLengthLimit: payloadLengthLimit,
   132  		protected:          protectedBuf,
   133  		writeBuf:           make([]byte, altsWriteBufferInitialSize),
   134  		nextFrame:          protectedBuf,
   135  		overhead:           overhead,
   136  	}
   137  	return altsConn, nil
   138  }
   139  
   140  // Read reads and decrypts a frame from the underlying connection, and copies the
   141  // decrypted payload into b. If the size of the payload is greater than len(b),
   142  // Read retains the remaining bytes in an internal buffer, and subsequent calls
   143  // to Read will read from this buffer until it is exhausted.
   144  func (p *conn) Read(b []byte) (n int, err error) {
   145  	if len(p.buf) == 0 {
   146  		var framedMsg []byte
   147  		framedMsg, p.nextFrame, err = ParseFramedMsg(p.nextFrame, altsRecordLengthLimit)
   148  		if err != nil {
   149  			return n, err
   150  		}
   151  		// Check whether the next frame to be decrypted has been
   152  		// completely received yet.
   153  		if len(framedMsg) == 0 {
   154  			copy(p.protected, p.nextFrame)
   155  			p.protected = p.protected[:len(p.nextFrame)]
   156  			// Always copy next incomplete frame to the beginning of
   157  			// the protected buffer and reset nextFrame to it.
   158  			p.nextFrame = p.protected
   159  		}
   160  		// Check whether a complete frame has been received yet.
   161  		for len(framedMsg) == 0 {
   162  			if len(p.protected) == cap(p.protected) {
   163  				// We can parse the length header to know exactly how large
   164  				// the buffer needs to be to hold the entire frame.
   165  				length, didParse := parseMessageLength(p.protected)
   166  				if !didParse {
   167  					// The protected buffer is initialized with a capacity of
   168  					// larger than 4B. It should always hold the message length
   169  					// header.
   170  					panic(fmt.Sprintf("protected buffer length shorter than expected: %d vs %d", len(p.protected), MsgLenFieldSize))
   171  				}
   172  				oldProtectedBuf := p.protected
   173  				// The new buffer must be able to hold the message length header
   174  				// and the entire message.
   175  				requiredCapacity := int(length) + MsgLenFieldSize
   176  				p.protected = make([]byte, requiredCapacity)
   177  				// Copy the contents of the old buffer and set the length of the
   178  				// new buffer to the number of bytes already read.
   179  				copy(p.protected, oldProtectedBuf)
   180  				p.protected = p.protected[:len(oldProtectedBuf)]
   181  			}
   182  			n, err = p.Conn.Read(p.protected[len(p.protected):cap(p.protected)])
   183  			if err != nil {
   184  				return 0, err
   185  			}
   186  			p.protected = p.protected[:len(p.protected)+n]
   187  			framedMsg, p.nextFrame, err = ParseFramedMsg(p.protected, altsRecordLengthLimit)
   188  			if err != nil {
   189  				return 0, err
   190  			}
   191  		}
   192  		// Now we have a complete frame, decrypted it.
   193  		msg := framedMsg[MsgLenFieldSize:]
   194  		msgType := binary.LittleEndian.Uint32(msg[:msgTypeFieldSize])
   195  		if msgType&0xff != altsRecordMsgType {
   196  			return 0, fmt.Errorf("received frame with incorrect message type %v, expected lower byte %v",
   197  				msgType, altsRecordMsgType)
   198  		}
   199  		ciphertext := msg[msgTypeFieldSize:]
   200  
   201  		// Decrypt directly into the buffer, avoiding a copy from p.buf if
   202  		// possible.
   203  		if len(b) >= len(ciphertext) {
   204  			dec, err := p.crypto.Decrypt(b[:0], ciphertext)
   205  			if err != nil {
   206  				return 0, err
   207  			}
   208  			return len(dec), nil
   209  		}
   210  		// Decrypt requires that if the dst and ciphertext alias, they
   211  		// must alias exactly. Code here used to use msg[:0], but msg
   212  		// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than
   213  		// ciphertext, so they alias inexactly. Using ciphertext[:0]
   214  		// arranges the appropriate aliasing without needing to copy
   215  		// ciphertext or use a separate destination buffer. For more info
   216  		// check: https://golang.org/pkg/crypto/cipher/#AEAD.
   217  		p.buf, err = p.crypto.Decrypt(ciphertext[:0], ciphertext)
   218  		if err != nil {
   219  			return 0, err
   220  		}
   221  	}
   222  
   223  	n = copy(b, p.buf)
   224  	p.buf = p.buf[n:]
   225  	return n, nil
   226  }
   227  
   228  // Write encrypts, frames, and writes bytes from b to the underlying connection.
   229  func (p *conn) Write(b []byte) (n int, err error) {
   230  	n = len(b)
   231  	// Calculate the output buffer size with framing and encryption overhead.
   232  	numOfFrames := int(math.Ceil(float64(len(b)) / float64(p.payloadLengthLimit)))
   233  	size := len(b) + numOfFrames*p.overhead
   234  	// If writeBuf is too small, increase its size up to the maximum size.
   235  	partialBSize := len(b)
   236  	if size > altsWriteBufferMaxSize {
   237  		size = altsWriteBufferMaxSize
   238  		const numOfFramesInMaxWriteBuf = altsWriteBufferMaxSize / altsRecordDefaultLength
   239  		partialBSize = numOfFramesInMaxWriteBuf * p.payloadLengthLimit
   240  	}
   241  	if len(p.writeBuf) < size {
   242  		p.writeBuf = make([]byte, size)
   243  	}
   244  
   245  	for partialBStart := 0; partialBStart < len(b); partialBStart += partialBSize {
   246  		partialBEnd := partialBStart + partialBSize
   247  		if partialBEnd > len(b) {
   248  			partialBEnd = len(b)
   249  		}
   250  		partialB := b[partialBStart:partialBEnd]
   251  		writeBufIndex := 0
   252  		for len(partialB) > 0 {
   253  			payloadLen := len(partialB)
   254  			if payloadLen > p.payloadLengthLimit {
   255  				payloadLen = p.payloadLengthLimit
   256  			}
   257  			buf := partialB[:payloadLen]
   258  			partialB = partialB[payloadLen:]
   259  
   260  			// Write buffer contains: length, type, payload, and tag
   261  			// if any.
   262  
   263  			// 1. Fill in type field.
   264  			msg := p.writeBuf[writeBufIndex+MsgLenFieldSize:]
   265  			binary.LittleEndian.PutUint32(msg, altsRecordMsgType)
   266  
   267  			// 2. Encrypt the payload and create a tag if any.
   268  			msg, err = p.crypto.Encrypt(msg[:msgTypeFieldSize], buf)
   269  			if err != nil {
   270  				return n, err
   271  			}
   272  
   273  			// 3. Fill in the size field.
   274  			binary.LittleEndian.PutUint32(p.writeBuf[writeBufIndex:], uint32(len(msg)))
   275  
   276  			// 4. Increase writeBufIndex.
   277  			writeBufIndex += len(buf) + p.overhead
   278  		}
   279  		nn, err := p.Conn.Write(p.writeBuf[:writeBufIndex])
   280  		if err != nil {
   281  			// We need to calculate the actual data size that was
   282  			// written. This means we need to remove header,
   283  			// encryption overheads, and any partially-written
   284  			// frame data.
   285  			numOfWrittenFrames := int(math.Floor(float64(nn) / float64(altsRecordDefaultLength)))
   286  			return partialBStart + numOfWrittenFrames*p.payloadLengthLimit, err
   287  		}
   288  	}
   289  	return n, nil
   290  }