google.golang.org/grpc@v1.74.2/credentials/alts/internal/conn/record.go (about) 1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 // Package conn contains an implementation of a secure channel created by gRPC 20 // handshakers. 21 package conn 22 23 import ( 24 "encoding/binary" 25 "fmt" 26 "math" 27 "net" 28 29 core "google.golang.org/grpc/credentials/alts/internal" 30 ) 31 32 // ALTSRecordCrypto is the interface for gRPC ALTS record protocol. 33 type ALTSRecordCrypto interface { 34 // Encrypt encrypts the plaintext, computes the tag (if any) of dst and 35 // plaintext, and appends the result to dst, returning the updated slice. 36 // dst and plaintext may fully overlap or not at all. 37 Encrypt(dst, plaintext []byte) ([]byte, error) 38 // EncryptionOverhead returns the tag size (if any) in bytes. 39 EncryptionOverhead() int 40 // Decrypt decrypts ciphertext and verifies the tag (if any). If successful, 41 // this function appends the resulting plaintext to dst, returning the 42 // updated slice. dst and ciphertext may alias exactly or not at all. To 43 // reuse ciphertext's storage for the decrypted output, use ciphertext[:0] 44 // as dst. Even if the function fails, the contents of dst, up to its 45 // capacity, may be overwritten. 46 Decrypt(dst, ciphertext []byte) ([]byte, error) 47 } 48 49 // ALTSRecordFunc is a function type for factory functions that create 50 // ALTSRecordCrypto instances. 51 type ALTSRecordFunc func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) 52 53 const ( 54 // MsgLenFieldSize is the byte size of the frame length field of a 55 // framed message. 56 MsgLenFieldSize = 4 57 // The byte size of the message type field of a framed message. 58 msgTypeFieldSize = 4 59 // The bytes size limit for a ALTS record message. 60 altsRecordLengthLimit = 1024 * 1024 // 1 MiB 61 // The default bytes size of a ALTS record message. 62 altsRecordDefaultLength = 4 * 1024 // 4KiB 63 // Message type value included in ALTS record framing. 64 altsRecordMsgType = uint32(0x06) 65 // The initial write buffer size. 66 altsWriteBufferInitialSize = 32 * 1024 // 32KiB 67 // The maximum write buffer size. This *must* be multiple of 68 // altsRecordDefaultLength. 69 altsWriteBufferMaxSize = 512 * 1024 // 512KiB 70 // The initial buffer used to read from the network. 71 altsReadBufferInitialSize = 32 * 1024 // 32KiB 72 ) 73 74 var ( 75 protocols = make(map[string]ALTSRecordFunc) 76 ) 77 78 // RegisterProtocol register a ALTS record encryption protocol. 79 func RegisterProtocol(protocol string, f ALTSRecordFunc) error { 80 if _, ok := protocols[protocol]; ok { 81 return fmt.Errorf("protocol %v is already registered", protocol) 82 } 83 protocols[protocol] = f 84 return nil 85 } 86 87 // conn represents a secured connection. It implements the net.Conn interface. 88 type conn struct { 89 net.Conn 90 crypto ALTSRecordCrypto 91 // buf holds data that has been read from the connection and decrypted, 92 // but has not yet been returned by Read. It is a sub-slice of protected. 93 buf []byte 94 payloadLengthLimit int 95 // protected holds data read from the network but have not yet been 96 // decrypted. This data might not compose a complete frame. 97 protected []byte 98 // writeBuf is a buffer used to contain encrypted frames before being 99 // written to the network. 100 writeBuf []byte 101 // nextFrame stores the next frame (in protected buffer) info. 102 nextFrame []byte 103 // overhead is the calculated overhead of each frame. 104 overhead int 105 } 106 107 // NewConn creates a new secure channel instance given the other party role and 108 // handshaking result. 109 func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, protected []byte) (net.Conn, error) { 110 newCrypto := protocols[recordProtocol] 111 if newCrypto == nil { 112 return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol) 113 } 114 crypto, err := newCrypto(side, key) 115 if err != nil { 116 return nil, fmt.Errorf("protocol %q: %v", recordProtocol, err) 117 } 118 overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead() 119 payloadLengthLimit := altsRecordDefaultLength - overhead 120 // We pre-allocate protected to be of size 32KB during initialization. 121 // We increase the size of the buffer by the required amount if it can't 122 // hold a complete encrypted record. 123 protectedBuf := make([]byte, max(altsReadBufferInitialSize, len(protected))) 124 // Copy additional data from hanshaker service. 125 copy(protectedBuf, protected) 126 protectedBuf = protectedBuf[:len(protected)] 127 128 altsConn := &conn{ 129 Conn: c, 130 crypto: crypto, 131 payloadLengthLimit: payloadLengthLimit, 132 protected: protectedBuf, 133 writeBuf: make([]byte, altsWriteBufferInitialSize), 134 nextFrame: protectedBuf, 135 overhead: overhead, 136 } 137 return altsConn, nil 138 } 139 140 // Read reads and decrypts a frame from the underlying connection, and copies the 141 // decrypted payload into b. If the size of the payload is greater than len(b), 142 // Read retains the remaining bytes in an internal buffer, and subsequent calls 143 // to Read will read from this buffer until it is exhausted. 144 func (p *conn) Read(b []byte) (n int, err error) { 145 if len(p.buf) == 0 { 146 var framedMsg []byte 147 framedMsg, p.nextFrame, err = ParseFramedMsg(p.nextFrame, altsRecordLengthLimit) 148 if err != nil { 149 return n, err 150 } 151 // Check whether the next frame to be decrypted has been 152 // completely received yet. 153 if len(framedMsg) == 0 { 154 copy(p.protected, p.nextFrame) 155 p.protected = p.protected[:len(p.nextFrame)] 156 // Always copy next incomplete frame to the beginning of 157 // the protected buffer and reset nextFrame to it. 158 p.nextFrame = p.protected 159 } 160 // Check whether a complete frame has been received yet. 161 for len(framedMsg) == 0 { 162 if len(p.protected) == cap(p.protected) { 163 // We can parse the length header to know exactly how large 164 // the buffer needs to be to hold the entire frame. 165 length, didParse := parseMessageLength(p.protected) 166 if !didParse { 167 // The protected buffer is initialized with a capacity of 168 // larger than 4B. It should always hold the message length 169 // header. 170 panic(fmt.Sprintf("protected buffer length shorter than expected: %d vs %d", len(p.protected), MsgLenFieldSize)) 171 } 172 oldProtectedBuf := p.protected 173 // The new buffer must be able to hold the message length header 174 // and the entire message. 175 requiredCapacity := int(length) + MsgLenFieldSize 176 p.protected = make([]byte, requiredCapacity) 177 // Copy the contents of the old buffer and set the length of the 178 // new buffer to the number of bytes already read. 179 copy(p.protected, oldProtectedBuf) 180 p.protected = p.protected[:len(oldProtectedBuf)] 181 } 182 n, err = p.Conn.Read(p.protected[len(p.protected):cap(p.protected)]) 183 if err != nil { 184 return 0, err 185 } 186 p.protected = p.protected[:len(p.protected)+n] 187 framedMsg, p.nextFrame, err = ParseFramedMsg(p.protected, altsRecordLengthLimit) 188 if err != nil { 189 return 0, err 190 } 191 } 192 // Now we have a complete frame, decrypted it. 193 msg := framedMsg[MsgLenFieldSize:] 194 msgType := binary.LittleEndian.Uint32(msg[:msgTypeFieldSize]) 195 if msgType&0xff != altsRecordMsgType { 196 return 0, fmt.Errorf("received frame with incorrect message type %v, expected lower byte %v", 197 msgType, altsRecordMsgType) 198 } 199 ciphertext := msg[msgTypeFieldSize:] 200 201 // Decrypt directly into the buffer, avoiding a copy from p.buf if 202 // possible. 203 if len(b) >= len(ciphertext) { 204 dec, err := p.crypto.Decrypt(b[:0], ciphertext) 205 if err != nil { 206 return 0, err 207 } 208 return len(dec), nil 209 } 210 // Decrypt requires that if the dst and ciphertext alias, they 211 // must alias exactly. Code here used to use msg[:0], but msg 212 // starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than 213 // ciphertext, so they alias inexactly. Using ciphertext[:0] 214 // arranges the appropriate aliasing without needing to copy 215 // ciphertext or use a separate destination buffer. For more info 216 // check: https://golang.org/pkg/crypto/cipher/#AEAD. 217 p.buf, err = p.crypto.Decrypt(ciphertext[:0], ciphertext) 218 if err != nil { 219 return 0, err 220 } 221 } 222 223 n = copy(b, p.buf) 224 p.buf = p.buf[n:] 225 return n, nil 226 } 227 228 // Write encrypts, frames, and writes bytes from b to the underlying connection. 229 func (p *conn) Write(b []byte) (n int, err error) { 230 n = len(b) 231 // Calculate the output buffer size with framing and encryption overhead. 232 numOfFrames := int(math.Ceil(float64(len(b)) / float64(p.payloadLengthLimit))) 233 size := len(b) + numOfFrames*p.overhead 234 // If writeBuf is too small, increase its size up to the maximum size. 235 partialBSize := len(b) 236 if size > altsWriteBufferMaxSize { 237 size = altsWriteBufferMaxSize 238 const numOfFramesInMaxWriteBuf = altsWriteBufferMaxSize / altsRecordDefaultLength 239 partialBSize = numOfFramesInMaxWriteBuf * p.payloadLengthLimit 240 } 241 if len(p.writeBuf) < size { 242 p.writeBuf = make([]byte, size) 243 } 244 245 for partialBStart := 0; partialBStart < len(b); partialBStart += partialBSize { 246 partialBEnd := partialBStart + partialBSize 247 if partialBEnd > len(b) { 248 partialBEnd = len(b) 249 } 250 partialB := b[partialBStart:partialBEnd] 251 writeBufIndex := 0 252 for len(partialB) > 0 { 253 payloadLen := len(partialB) 254 if payloadLen > p.payloadLengthLimit { 255 payloadLen = p.payloadLengthLimit 256 } 257 buf := partialB[:payloadLen] 258 partialB = partialB[payloadLen:] 259 260 // Write buffer contains: length, type, payload, and tag 261 // if any. 262 263 // 1. Fill in type field. 264 msg := p.writeBuf[writeBufIndex+MsgLenFieldSize:] 265 binary.LittleEndian.PutUint32(msg, altsRecordMsgType) 266 267 // 2. Encrypt the payload and create a tag if any. 268 msg, err = p.crypto.Encrypt(msg[:msgTypeFieldSize], buf) 269 if err != nil { 270 return n, err 271 } 272 273 // 3. Fill in the size field. 274 binary.LittleEndian.PutUint32(p.writeBuf[writeBufIndex:], uint32(len(msg))) 275 276 // 4. Increase writeBufIndex. 277 writeBufIndex += len(buf) + p.overhead 278 } 279 nn, err := p.Conn.Write(p.writeBuf[:writeBufIndex]) 280 if err != nil { 281 // We need to calculate the actual data size that was 282 // written. This means we need to remove header, 283 // encryption overheads, and any partially-written 284 // frame data. 285 numOfWrittenFrames := int(math.Floor(float64(nn) / float64(altsRecordDefaultLength))) 286 return partialBStart + numOfWrittenFrames*p.payloadLengthLimit, err 287 } 288 } 289 return n, nil 290 }