gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/credentials/alts/internal/conn/record.go (about) 1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 // Package conn contains an implementation of a secure channel created by gRPC 20 // handshakers. 21 package conn 22 23 import ( 24 "encoding/binary" 25 "fmt" 26 "math" 27 "net" 28 29 core "gitee.com/ks-custle/core-gm/grpc/credentials/alts/internal" 30 ) 31 32 // ALTSRecordCrypto is the interface for gRPC ALTS record protocol. 33 type ALTSRecordCrypto interface { 34 // Encrypt encrypts the plaintext and computes the tag (if any) of dst 35 // and plaintext. dst and plaintext may fully overlap or not at all. 36 Encrypt(dst, plaintext []byte) ([]byte, error) 37 // EncryptionOverhead returns the tag size (if any) in bytes. 38 EncryptionOverhead() int 39 // Decrypt decrypts ciphertext and verify the tag (if any). dst and 40 // ciphertext may alias exactly or not at all. To reuse ciphertext's 41 // storage for the decrypted output, use ciphertext[:0] as dst. 42 Decrypt(dst, ciphertext []byte) ([]byte, error) 43 } 44 45 // ALTSRecordFunc is a function type for factory functions that create 46 // ALTSRecordCrypto instances. 47 type ALTSRecordFunc func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) 48 49 const ( 50 // MsgLenFieldSize is the byte size of the frame length field of a 51 // framed message. 52 MsgLenFieldSize = 4 53 // The byte size of the message type field of a framed message. 54 msgTypeFieldSize = 4 55 // The bytes size limit for a ALTS record message. 56 altsRecordLengthLimit = 1024 * 1024 // 1 MiB 57 // The default bytes size of a ALTS record message. 58 altsRecordDefaultLength = 4 * 1024 // 4KiB 59 // Message type value included in ALTS record framing. 60 altsRecordMsgType = uint32(0x06) 61 // The initial write buffer size. 62 altsWriteBufferInitialSize = 32 * 1024 // 32KiB 63 // The maximum write buffer size. This *must* be multiple of 64 // altsRecordDefaultLength. 65 altsWriteBufferMaxSize = 512 * 1024 // 512KiB 66 ) 67 68 var ( 69 protocols = make(map[string]ALTSRecordFunc) 70 ) 71 72 // RegisterProtocol register a ALTS record encryption protocol. 73 func RegisterProtocol(protocol string, f ALTSRecordFunc) error { 74 if _, ok := protocols[protocol]; ok { 75 return fmt.Errorf("protocol %v is already registered", protocol) 76 } 77 protocols[protocol] = f 78 return nil 79 } 80 81 // conn represents a secured connection. It implements the net.Conn interface. 82 type conn struct { 83 net.Conn 84 crypto ALTSRecordCrypto 85 // buf holds data that has been read from the connection and decrypted, 86 // but has not yet been returned by Read. 87 buf []byte 88 payloadLengthLimit int 89 // protected holds data read from the network but have not yet been 90 // decrypted. This data might not compose a complete frame. 91 protected []byte 92 // writeBuf is a buffer used to contain encrypted frames before being 93 // written to the network. 94 writeBuf []byte 95 // nextFrame stores the next frame (in protected buffer) info. 96 nextFrame []byte 97 // overhead is the calculated overhead of each frame. 98 overhead int 99 } 100 101 // NewConn creates a new secure channel instance given the other party role and 102 // handshaking result. 103 func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, protected []byte) (net.Conn, error) { 104 newCrypto := protocols[recordProtocol] 105 if newCrypto == nil { 106 return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol) 107 } 108 crypto, err := newCrypto(side, key) 109 if err != nil { 110 return nil, fmt.Errorf("protocol %q: %v", recordProtocol, err) 111 } 112 overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead() 113 payloadLengthLimit := altsRecordDefaultLength - overhead 114 var protectedBuf []byte 115 if protected == nil { 116 // We pre-allocate protected to be of size 117 // 2*altsRecordDefaultLength-1 during initialization. We only 118 // read from the network into protected when protected does not 119 // contain a complete frame, which is at most 120 // altsRecordDefaultLength-1 (bytes). And we read at most 121 // altsRecordDefaultLength (bytes) data into protected at one 122 // time. Therefore, 2*altsRecordDefaultLength-1 is large enough 123 // to buffer data read from the network. 124 protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1) 125 } else { 126 protectedBuf = make([]byte, len(protected)) 127 copy(protectedBuf, protected) 128 } 129 130 altsConn := &conn{ 131 Conn: c, 132 crypto: crypto, 133 payloadLengthLimit: payloadLengthLimit, 134 protected: protectedBuf, 135 writeBuf: make([]byte, altsWriteBufferInitialSize), 136 nextFrame: protectedBuf, 137 overhead: overhead, 138 } 139 return altsConn, nil 140 } 141 142 // Read reads and decrypts a frame from the underlying connection, and copies the 143 // decrypted payload into b. If the size of the payload is greater than len(b), 144 // Read retains the remaining bytes in an internal buffer, and subsequent calls 145 // to Read will read from this buffer until it is exhausted. 146 func (p *conn) Read(b []byte) (n int, err error) { 147 if len(p.buf) == 0 { 148 var framedMsg []byte 149 framedMsg, p.nextFrame, err = ParseFramedMsg(p.nextFrame, altsRecordLengthLimit) 150 if err != nil { 151 return n, err 152 } 153 // Check whether the next frame to be decrypted has been 154 // completely received yet. 155 if len(framedMsg) == 0 { 156 copy(p.protected, p.nextFrame) 157 p.protected = p.protected[:len(p.nextFrame)] 158 // Always copy next incomplete frame to the beginning of 159 // the protected buffer and reset nextFrame to it. 160 p.nextFrame = p.protected 161 } 162 // Check whether a complete frame has been received yet. 163 for len(framedMsg) == 0 { 164 if len(p.protected) == cap(p.protected) { 165 tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength) 166 copy(tmp, p.protected) 167 p.protected = tmp 168 } 169 n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)]) 170 if err != nil { 171 return 0, err 172 } 173 p.protected = p.protected[:len(p.protected)+n] 174 framedMsg, p.nextFrame, err = ParseFramedMsg(p.protected, altsRecordLengthLimit) 175 if err != nil { 176 return 0, err 177 } 178 } 179 // Now we have a complete frame, decrypted it. 180 msg := framedMsg[MsgLenFieldSize:] 181 msgType := binary.LittleEndian.Uint32(msg[:msgTypeFieldSize]) 182 if msgType&0xff != altsRecordMsgType { 183 return 0, fmt.Errorf("received frame with incorrect message type %v, expected lower byte %v", 184 msgType, altsRecordMsgType) 185 } 186 ciphertext := msg[msgTypeFieldSize:] 187 188 // Decrypt requires that if the dst and ciphertext alias, they 189 // must alias exactly. Code here used to use msg[:0], but msg 190 // starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than 191 // ciphertext, so they alias inexactly. Using ciphertext[:0] 192 // arranges the appropriate aliasing without needing to copy 193 // ciphertext or use a separate destination buffer. For more info 194 // check: https://golang.org/pkg/crypto/cipher/#AEAD. 195 p.buf, err = p.crypto.Decrypt(ciphertext[:0], ciphertext) 196 if err != nil { 197 return 0, err 198 } 199 } 200 201 n = copy(b, p.buf) 202 p.buf = p.buf[n:] 203 return n, nil 204 } 205 206 // Write encrypts, frames, and writes bytes from b to the underlying connection. 207 func (p *conn) Write(b []byte) (n int, err error) { 208 n = len(b) 209 // Calculate the output buffer size with framing and encryption overhead. 210 numOfFrames := int(math.Ceil(float64(len(b)) / float64(p.payloadLengthLimit))) 211 size := len(b) + numOfFrames*p.overhead 212 // If writeBuf is too small, increase its size up to the maximum size. 213 partialBSize := len(b) 214 if size > altsWriteBufferMaxSize { 215 size = altsWriteBufferMaxSize 216 const numOfFramesInMaxWriteBuf = altsWriteBufferMaxSize / altsRecordDefaultLength 217 partialBSize = numOfFramesInMaxWriteBuf * p.payloadLengthLimit 218 } 219 if len(p.writeBuf) < size { 220 p.writeBuf = make([]byte, size) 221 } 222 223 for partialBStart := 0; partialBStart < len(b); partialBStart += partialBSize { 224 partialBEnd := partialBStart + partialBSize 225 if partialBEnd > len(b) { 226 partialBEnd = len(b) 227 } 228 partialB := b[partialBStart:partialBEnd] 229 writeBufIndex := 0 230 for len(partialB) > 0 { 231 payloadLen := len(partialB) 232 if payloadLen > p.payloadLengthLimit { 233 payloadLen = p.payloadLengthLimit 234 } 235 buf := partialB[:payloadLen] 236 partialB = partialB[payloadLen:] 237 238 // Write buffer contains: length, type, payload, and tag 239 // if any. 240 241 // 1. Fill in type field. 242 msg := p.writeBuf[writeBufIndex+MsgLenFieldSize:] 243 binary.LittleEndian.PutUint32(msg, altsRecordMsgType) 244 245 // 2. Encrypt the payload and create a tag if any. 246 msg, err = p.crypto.Encrypt(msg[:msgTypeFieldSize], buf) 247 if err != nil { 248 return n, err 249 } 250 251 // 3. Fill in the size field. 252 binary.LittleEndian.PutUint32(p.writeBuf[writeBufIndex:], uint32(len(msg))) 253 254 // 4. Increase writeBufIndex. 255 writeBufIndex += len(buf) + p.overhead 256 } 257 nn, err := p.Conn.Write(p.writeBuf[:writeBufIndex]) 258 if err != nil { 259 // We need to calculate the actual data size that was 260 // written. This means we need to remove header, 261 // encryption overheads, and any partially-written 262 // frame data. 263 numOfWrittenFrames := int(math.Floor(float64(nn) / float64(altsRecordDefaultLength))) 264 return partialBStart + numOfWrittenFrames*p.payloadLengthLimit, err 265 } 266 } 267 return n, nil 268 } 269 270 func min(a, b int) int { 271 if a < b { 272 return a 273 } 274 return b 275 }