github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/allegrosql/server/packetio.go (about) 1 // Copyright 2020 The Go-MyALLEGROSQL-Driver Authors. All rights reserved. 2 // 3 // This Source Code Form is subject to the terms of the Mozilla Public 4 // License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 // You can obtain one at http://mozilla.org/MPL/2.0/. 6 7 // The MIT License (MIT) 8 // 9 // Copyright (c) 2020 wandoulabs 10 // Copyright (c) 2020 siddontang 11 // 12 // Permission is hereby granted, free of charge, to any person obtaining a copy of 13 // this software and associated documentation files (the "Software"), to deal in 14 // the Software without restriction, including without limitation the rights to 15 // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 16 // the Software, and to permit persons to whom the Software is furnished to do so, 17 // subject to the following conditions: 18 // 19 // The above copyright notice and this permission notice shall be included in all 20 // copies or substantial portions of the Software. 21 22 // Copyright 2020 WHTCORPS INC, Inc. 23 // 24 // Licensed under the Apache License, Version 2.0 (the "License"); 25 // you may not use this file except in compliance with the License. 26 // You may obtain a copy of the License at 27 // 28 // http://www.apache.org/licenses/LICENSE-2.0 29 // 30 // Unless required by applicable law or agreed to in writing, software 31 // distributed under the License is distributed on an "AS IS" BASIS, 32 // See the License for the specific language governing permissions and 33 // limitations under the License. 34 35 package server 36 37 import ( 38 "bufio" 39 "io" 40 "time" 41 42 "github.com/whtcorpsinc/errors" 43 "github.com/whtcorpsinc/BerolinaSQL/allegrosql" 44 "github.com/whtcorpsinc/BerolinaSQL/terror" 45 "github.com/whtcorpsinc/milevadb/metrics" 46 ) 47 48 const defaultWriterSize = 16 * 1024 49 50 var ( 51 readPacketBytes = metrics.PacketIOHistogram.WithLabelValues("read") 52 writePacketBytes = metrics.PacketIOHistogram.WithLabelValues("write") 53 ) 54 55 // packetIO is a helper to read and write data in packet format. 56 type packetIO struct { 57 bufReadConn *bufferedReadConn 58 bufWriter *bufio.Writer 59 sequence uint8 60 readTimeout time.Duration 61 } 62 63 func newPacketIO(bufReadConn *bufferedReadConn) *packetIO { 64 p := &packetIO{sequence: 0} 65 p.setBufferedReadConn(bufReadConn) 66 return p 67 } 68 69 func (p *packetIO) setBufferedReadConn(bufReadConn *bufferedReadConn) { 70 p.bufReadConn = bufReadConn 71 p.bufWriter = bufio.NewWriterSize(bufReadConn, defaultWriterSize) 72 } 73 74 func (p *packetIO) setReadTimeout(timeout time.Duration) { 75 p.readTimeout = timeout 76 } 77 78 func (p *packetIO) readOnePacket() ([]byte, error) { 79 var header [4]byte 80 if p.readTimeout > 0 { 81 if err := p.bufReadConn.SetReadDeadline(time.Now().Add(p.readTimeout)); err != nil { 82 return nil, err 83 } 84 } 85 if _, err := io.ReadFull(p.bufReadConn, header[:]); err != nil { 86 return nil, errors.Trace(err) 87 } 88 89 sequence := header[3] 90 if sequence != p.sequence { 91 return nil, errInvalidSequence.GenWithStack("invalid sequence %d != %d", sequence, p.sequence) 92 } 93 94 p.sequence++ 95 96 length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) 97 98 data := make([]byte, length) 99 if p.readTimeout > 0 { 100 if err := p.bufReadConn.SetReadDeadline(time.Now().Add(p.readTimeout)); err != nil { 101 return nil, err 102 } 103 } 104 if _, err := io.ReadFull(p.bufReadConn, data); err != nil { 105 return nil, errors.Trace(err) 106 } 107 return data, nil 108 } 109 110 func (p *packetIO) readPacket() ([]byte, error) { 111 data, err := p.readOnePacket() 112 if err != nil { 113 return nil, errors.Trace(err) 114 } 115 116 if len(data) < allegrosql.MaxPayloadLen { 117 readPacketBytes.Observe(float64(len(data))) 118 return data, nil 119 } 120 121 // handle multi-packet 122 for { 123 buf, err := p.readOnePacket() 124 if err != nil { 125 return nil, errors.Trace(err) 126 } 127 128 data = append(data, buf...) 129 130 if len(buf) < allegrosql.MaxPayloadLen { 131 break 132 } 133 } 134 135 readPacketBytes.Observe(float64(len(data))) 136 return data, nil 137 } 138 139 // writePacket writes data that already have header 140 func (p *packetIO) writePacket(data []byte) error { 141 length := len(data) - 4 142 writePacketBytes.Observe(float64(len(data))) 143 144 for length >= allegrosql.MaxPayloadLen { 145 data[0] = 0xff 146 data[1] = 0xff 147 data[2] = 0xff 148 149 data[3] = p.sequence 150 151 if n, err := p.bufWriter.Write(data[:4+allegrosql.MaxPayloadLen]); err != nil { 152 return errors.Trace(allegrosql.ErrBadConn) 153 } else if n != (4 + allegrosql.MaxPayloadLen) { 154 return errors.Trace(allegrosql.ErrBadConn) 155 } else { 156 p.sequence++ 157 length -= allegrosql.MaxPayloadLen 158 data = data[allegrosql.MaxPayloadLen:] 159 } 160 } 161 162 data[0] = byte(length) 163 data[1] = byte(length >> 8) 164 data[2] = byte(length >> 16) 165 data[3] = p.sequence 166 167 if n, err := p.bufWriter.Write(data); err != nil { 168 terror.Log(errors.Trace(err)) 169 return errors.Trace(allegrosql.ErrBadConn) 170 } else if n != len(data) { 171 return errors.Trace(allegrosql.ErrBadConn) 172 } else { 173 p.sequence++ 174 return nil 175 } 176 } 177 178 func (p *packetIO) flush() error { 179 err := p.bufWriter.Flush() 180 if err != nil { 181 return errors.Trace(err) 182 } 183 return err 184 }