gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/credentials/alts/internal/conn/record.go (about)

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package conn contains an implementation of a secure channel created by gRPC
    20  // handshakers.
    21  package conn
    22  
    23  import (
    24  	"encoding/binary"
    25  	"fmt"
    26  	"math"
    27  	"net"
    28  
    29  	core "gitee.com/ks-custle/core-gm/grpc/credentials/alts/internal"
    30  )
    31  
    32  // ALTSRecordCrypto is the interface for gRPC ALTS record protocol.
    33  type ALTSRecordCrypto interface {
    34  	// Encrypt encrypts the plaintext and computes the tag (if any) of dst
    35  	// and plaintext. dst and plaintext may fully overlap or not at all.
    36  	Encrypt(dst, plaintext []byte) ([]byte, error)
    37  	// EncryptionOverhead returns the tag size (if any) in bytes.
    38  	EncryptionOverhead() int
    39  	// Decrypt decrypts ciphertext and verify the tag (if any). dst and
    40  	// ciphertext may alias exactly or not at all. To reuse ciphertext's
    41  	// storage for the decrypted output, use ciphertext[:0] as dst.
    42  	Decrypt(dst, ciphertext []byte) ([]byte, error)
    43  }
    44  
    45  // ALTSRecordFunc is a function type for factory functions that create
    46  // ALTSRecordCrypto instances.
    47  type ALTSRecordFunc func(s core.Side, keyData []byte) (ALTSRecordCrypto, error)
    48  
    49  const (
    50  	// MsgLenFieldSize is the byte size of the frame length field of a
    51  	// framed message.
    52  	MsgLenFieldSize = 4
    53  	// The byte size of the message type field of a framed message.
    54  	msgTypeFieldSize = 4
    55  	// The bytes size limit for a ALTS record message.
    56  	altsRecordLengthLimit = 1024 * 1024 // 1 MiB
    57  	// The default bytes size of a ALTS record message.
    58  	altsRecordDefaultLength = 4 * 1024 // 4KiB
    59  	// Message type value included in ALTS record framing.
    60  	altsRecordMsgType = uint32(0x06)
    61  	// The initial write buffer size.
    62  	altsWriteBufferInitialSize = 32 * 1024 // 32KiB
    63  	// The maximum write buffer size. This *must* be multiple of
    64  	// altsRecordDefaultLength.
    65  	altsWriteBufferMaxSize = 512 * 1024 // 512KiB
    66  )
    67  
    68  var (
    69  	protocols = make(map[string]ALTSRecordFunc)
    70  )
    71  
    72  // RegisterProtocol register a ALTS record encryption protocol.
    73  func RegisterProtocol(protocol string, f ALTSRecordFunc) error {
    74  	if _, ok := protocols[protocol]; ok {
    75  		return fmt.Errorf("protocol %v is already registered", protocol)
    76  	}
    77  	protocols[protocol] = f
    78  	return nil
    79  }
    80  
    81  // conn represents a secured connection. It implements the net.Conn interface.
    82  type conn struct {
    83  	net.Conn
    84  	crypto ALTSRecordCrypto
    85  	// buf holds data that has been read from the connection and decrypted,
    86  	// but has not yet been returned by Read.
    87  	buf                []byte
    88  	payloadLengthLimit int
    89  	// protected holds data read from the network but have not yet been
    90  	// decrypted. This data might not compose a complete frame.
    91  	protected []byte
    92  	// writeBuf is a buffer used to contain encrypted frames before being
    93  	// written to the network.
    94  	writeBuf []byte
    95  	// nextFrame stores the next frame (in protected buffer) info.
    96  	nextFrame []byte
    97  	// overhead is the calculated overhead of each frame.
    98  	overhead int
    99  }
   100  
   101  // NewConn creates a new secure channel instance given the other party role and
   102  // handshaking result.
   103  func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, protected []byte) (net.Conn, error) {
   104  	newCrypto := protocols[recordProtocol]
   105  	if newCrypto == nil {
   106  		return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol)
   107  	}
   108  	crypto, err := newCrypto(side, key)
   109  	if err != nil {
   110  		return nil, fmt.Errorf("protocol %q: %v", recordProtocol, err)
   111  	}
   112  	overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
   113  	payloadLengthLimit := altsRecordDefaultLength - overhead
   114  	var protectedBuf []byte
   115  	if protected == nil {
   116  		// We pre-allocate protected to be of size
   117  		// 2*altsRecordDefaultLength-1 during initialization. We only
   118  		// read from the network into protected when protected does not
   119  		// contain a complete frame, which is at most
   120  		// altsRecordDefaultLength-1 (bytes). And we read at most
   121  		// altsRecordDefaultLength (bytes) data into protected at one
   122  		// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
   123  		// to buffer data read from the network.
   124  		protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1)
   125  	} else {
   126  		protectedBuf = make([]byte, len(protected))
   127  		copy(protectedBuf, protected)
   128  	}
   129  
   130  	altsConn := &conn{
   131  		Conn:               c,
   132  		crypto:             crypto,
   133  		payloadLengthLimit: payloadLengthLimit,
   134  		protected:          protectedBuf,
   135  		writeBuf:           make([]byte, altsWriteBufferInitialSize),
   136  		nextFrame:          protectedBuf,
   137  		overhead:           overhead,
   138  	}
   139  	return altsConn, nil
   140  }
   141  
   142  // Read reads and decrypts a frame from the underlying connection, and copies the
   143  // decrypted payload into b. If the size of the payload is greater than len(b),
   144  // Read retains the remaining bytes in an internal buffer, and subsequent calls
   145  // to Read will read from this buffer until it is exhausted.
   146  func (p *conn) Read(b []byte) (n int, err error) {
   147  	if len(p.buf) == 0 {
   148  		var framedMsg []byte
   149  		framedMsg, p.nextFrame, err = ParseFramedMsg(p.nextFrame, altsRecordLengthLimit)
   150  		if err != nil {
   151  			return n, err
   152  		}
   153  		// Check whether the next frame to be decrypted has been
   154  		// completely received yet.
   155  		if len(framedMsg) == 0 {
   156  			copy(p.protected, p.nextFrame)
   157  			p.protected = p.protected[:len(p.nextFrame)]
   158  			// Always copy next incomplete frame to the beginning of
   159  			// the protected buffer and reset nextFrame to it.
   160  			p.nextFrame = p.protected
   161  		}
   162  		// Check whether a complete frame has been received yet.
   163  		for len(framedMsg) == 0 {
   164  			if len(p.protected) == cap(p.protected) {
   165  				tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
   166  				copy(tmp, p.protected)
   167  				p.protected = tmp
   168  			}
   169  			n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
   170  			if err != nil {
   171  				return 0, err
   172  			}
   173  			p.protected = p.protected[:len(p.protected)+n]
   174  			framedMsg, p.nextFrame, err = ParseFramedMsg(p.protected, altsRecordLengthLimit)
   175  			if err != nil {
   176  				return 0, err
   177  			}
   178  		}
   179  		// Now we have a complete frame, decrypted it.
   180  		msg := framedMsg[MsgLenFieldSize:]
   181  		msgType := binary.LittleEndian.Uint32(msg[:msgTypeFieldSize])
   182  		if msgType&0xff != altsRecordMsgType {
   183  			return 0, fmt.Errorf("received frame with incorrect message type %v, expected lower byte %v",
   184  				msgType, altsRecordMsgType)
   185  		}
   186  		ciphertext := msg[msgTypeFieldSize:]
   187  
   188  		// Decrypt requires that if the dst and ciphertext alias, they
   189  		// must alias exactly. Code here used to use msg[:0], but msg
   190  		// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than
   191  		// ciphertext, so they alias inexactly. Using ciphertext[:0]
   192  		// arranges the appropriate aliasing without needing to copy
   193  		// ciphertext or use a separate destination buffer. For more info
   194  		// check: https://golang.org/pkg/crypto/cipher/#AEAD.
   195  		p.buf, err = p.crypto.Decrypt(ciphertext[:0], ciphertext)
   196  		if err != nil {
   197  			return 0, err
   198  		}
   199  	}
   200  
   201  	n = copy(b, p.buf)
   202  	p.buf = p.buf[n:]
   203  	return n, nil
   204  }
   205  
   206  // Write encrypts, frames, and writes bytes from b to the underlying connection.
   207  func (p *conn) Write(b []byte) (n int, err error) {
   208  	n = len(b)
   209  	// Calculate the output buffer size with framing and encryption overhead.
   210  	numOfFrames := int(math.Ceil(float64(len(b)) / float64(p.payloadLengthLimit)))
   211  	size := len(b) + numOfFrames*p.overhead
   212  	// If writeBuf is too small, increase its size up to the maximum size.
   213  	partialBSize := len(b)
   214  	if size > altsWriteBufferMaxSize {
   215  		size = altsWriteBufferMaxSize
   216  		const numOfFramesInMaxWriteBuf = altsWriteBufferMaxSize / altsRecordDefaultLength
   217  		partialBSize = numOfFramesInMaxWriteBuf * p.payloadLengthLimit
   218  	}
   219  	if len(p.writeBuf) < size {
   220  		p.writeBuf = make([]byte, size)
   221  	}
   222  
   223  	for partialBStart := 0; partialBStart < len(b); partialBStart += partialBSize {
   224  		partialBEnd := partialBStart + partialBSize
   225  		if partialBEnd > len(b) {
   226  			partialBEnd = len(b)
   227  		}
   228  		partialB := b[partialBStart:partialBEnd]
   229  		writeBufIndex := 0
   230  		for len(partialB) > 0 {
   231  			payloadLen := len(partialB)
   232  			if payloadLen > p.payloadLengthLimit {
   233  				payloadLen = p.payloadLengthLimit
   234  			}
   235  			buf := partialB[:payloadLen]
   236  			partialB = partialB[payloadLen:]
   237  
   238  			// Write buffer contains: length, type, payload, and tag
   239  			// if any.
   240  
   241  			// 1. Fill in type field.
   242  			msg := p.writeBuf[writeBufIndex+MsgLenFieldSize:]
   243  			binary.LittleEndian.PutUint32(msg, altsRecordMsgType)
   244  
   245  			// 2. Encrypt the payload and create a tag if any.
   246  			msg, err = p.crypto.Encrypt(msg[:msgTypeFieldSize], buf)
   247  			if err != nil {
   248  				return n, err
   249  			}
   250  
   251  			// 3. Fill in the size field.
   252  			binary.LittleEndian.PutUint32(p.writeBuf[writeBufIndex:], uint32(len(msg)))
   253  
   254  			// 4. Increase writeBufIndex.
   255  			writeBufIndex += len(buf) + p.overhead
   256  		}
   257  		nn, err := p.Conn.Write(p.writeBuf[:writeBufIndex])
   258  		if err != nil {
   259  			// We need to calculate the actual data size that was
   260  			// written. This means we need to remove header,
   261  			// encryption overheads, and any partially-written
   262  			// frame data.
   263  			numOfWrittenFrames := int(math.Floor(float64(nn) / float64(altsRecordDefaultLength)))
   264  			return partialBStart + numOfWrittenFrames*p.payloadLengthLimit, err
   265  		}
   266  	}
   267  	return n, nil
   268  }
   269  
   270  func min(a, b int) int {
   271  	if a < b {
   272  		return a
   273  	}
   274  	return b
   275  }