github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/allegrosql/server/packetio.go (about)

     1  // Copyright 2020 The Go-MyALLEGROSQL-Driver Authors. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     5  // You can obtain one at http://mozilla.org/MPL/2.0/.
     6  
     7  // The MIT License (MIT)
     8  //
     9  // Copyright (c) 2020 wandoulabs
    10  // Copyright (c) 2020 siddontang
    11  //
    12  // Permission is hereby granted, free of charge, to any person obtaining a copy of
    13  // this software and associated documentation files (the "Software"), to deal in
    14  // the Software without restriction, including without limitation the rights to
    15  // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
    16  // the Software, and to permit persons to whom the Software is furnished to do so,
    17  // subject to the following conditions:
    18  //
    19  // The above copyright notice and this permission notice shall be included in all
    20  // copies or substantial portions of the Software.
    21  
    22  // Copyright 2020 WHTCORPS INC, Inc.
    23  //
    24  // Licensed under the Apache License, Version 2.0 (the "License");
    25  // you may not use this file except in compliance with the License.
    26  // You may obtain a copy of the License at
    27  //
    28  //     http://www.apache.org/licenses/LICENSE-2.0
    29  //
    30  // Unless required by applicable law or agreed to in writing, software
    31  // distributed under the License is distributed on an "AS IS" BASIS,
    32  // See the License for the specific language governing permissions and
    33  // limitations under the License.
    34  
    35  package server
    36  
    37  import (
    38  	"bufio"
    39  	"io"
    40  	"time"
    41  
    42  	"github.com/whtcorpsinc/errors"
    43  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    44  	"github.com/whtcorpsinc/BerolinaSQL/terror"
    45  	"github.com/whtcorpsinc/milevadb/metrics"
    46  )
    47  
    48  const defaultWriterSize = 16 * 1024
    49  
    50  var (
    51  	readPacketBytes  = metrics.PacketIOHistogram.WithLabelValues("read")
    52  	writePacketBytes = metrics.PacketIOHistogram.WithLabelValues("write")
    53  )
    54  
// packetIO is a helper to read and write data in packet format.
//
// Every wire packet starts with a 4-byte header: a 3-byte little-endian
// payload length followed by a 1-byte sequence number.
type packetIO struct {
	// bufReadConn is the connection packets are read from; it is also the
	// destination that bufWriter writes to.
	bufReadConn *bufferedReadConn
	// bufWriter buffers outgoing packets; call flush to push buffered bytes
	// out to bufReadConn.
	bufWriter   *bufio.Writer
	// sequence is the sequence number expected on the next packet read and
	// stamped on the next packet written; it is incremented once per packet.
	sequence    uint8
	// readTimeout, when positive, is applied as a fresh read deadline before
	// each header read and each payload read. Non-positive disables deadlines.
	readTimeout time.Duration
}
    62  
    63  func newPacketIO(bufReadConn *bufferedReadConn) *packetIO {
    64  	p := &packetIO{sequence: 0}
    65  	p.setBufferedReadConn(bufReadConn)
    66  	return p
    67  }
    68  
    69  func (p *packetIO) setBufferedReadConn(bufReadConn *bufferedReadConn) {
    70  	p.bufReadConn = bufReadConn
    71  	p.bufWriter = bufio.NewWriterSize(bufReadConn, defaultWriterSize)
    72  }
    73  
    74  func (p *packetIO) setReadTimeout(timeout time.Duration) {
    75  	p.readTimeout = timeout
    76  }
    77  
    78  func (p *packetIO) readOnePacket() ([]byte, error) {
    79  	var header [4]byte
    80  	if p.readTimeout > 0 {
    81  		if err := p.bufReadConn.SetReadDeadline(time.Now().Add(p.readTimeout)); err != nil {
    82  			return nil, err
    83  		}
    84  	}
    85  	if _, err := io.ReadFull(p.bufReadConn, header[:]); err != nil {
    86  		return nil, errors.Trace(err)
    87  	}
    88  
    89  	sequence := header[3]
    90  	if sequence != p.sequence {
    91  		return nil, errInvalidSequence.GenWithStack("invalid sequence %d != %d", sequence, p.sequence)
    92  	}
    93  
    94  	p.sequence++
    95  
    96  	length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
    97  
    98  	data := make([]byte, length)
    99  	if p.readTimeout > 0 {
   100  		if err := p.bufReadConn.SetReadDeadline(time.Now().Add(p.readTimeout)); err != nil {
   101  			return nil, err
   102  		}
   103  	}
   104  	if _, err := io.ReadFull(p.bufReadConn, data); err != nil {
   105  		return nil, errors.Trace(err)
   106  	}
   107  	return data, nil
   108  }
   109  
   110  func (p *packetIO) readPacket() ([]byte, error) {
   111  	data, err := p.readOnePacket()
   112  	if err != nil {
   113  		return nil, errors.Trace(err)
   114  	}
   115  
   116  	if len(data) < allegrosql.MaxPayloadLen {
   117  		readPacketBytes.Observe(float64(len(data)))
   118  		return data, nil
   119  	}
   120  
   121  	// handle multi-packet
   122  	for {
   123  		buf, err := p.readOnePacket()
   124  		if err != nil {
   125  			return nil, errors.Trace(err)
   126  		}
   127  
   128  		data = append(data, buf...)
   129  
   130  		if len(buf) < allegrosql.MaxPayloadLen {
   131  			break
   132  		}
   133  	}
   134  
   135  	readPacketBytes.Observe(float64(len(data)))
   136  	return data, nil
   137  }
   138  
   139  // writePacket writes data that already have header
   140  func (p *packetIO) writePacket(data []byte) error {
   141  	length := len(data) - 4
   142  	writePacketBytes.Observe(float64(len(data)))
   143  
   144  	for length >= allegrosql.MaxPayloadLen {
   145  		data[0] = 0xff
   146  		data[1] = 0xff
   147  		data[2] = 0xff
   148  
   149  		data[3] = p.sequence
   150  
   151  		if n, err := p.bufWriter.Write(data[:4+allegrosql.MaxPayloadLen]); err != nil {
   152  			return errors.Trace(allegrosql.ErrBadConn)
   153  		} else if n != (4 + allegrosql.MaxPayloadLen) {
   154  			return errors.Trace(allegrosql.ErrBadConn)
   155  		} else {
   156  			p.sequence++
   157  			length -= allegrosql.MaxPayloadLen
   158  			data = data[allegrosql.MaxPayloadLen:]
   159  		}
   160  	}
   161  
   162  	data[0] = byte(length)
   163  	data[1] = byte(length >> 8)
   164  	data[2] = byte(length >> 16)
   165  	data[3] = p.sequence
   166  
   167  	if n, err := p.bufWriter.Write(data); err != nil {
   168  		terror.Log(errors.Trace(err))
   169  		return errors.Trace(allegrosql.ErrBadConn)
   170  	} else if n != len(data) {
   171  		return errors.Trace(allegrosql.ErrBadConn)
   172  	} else {
   173  		p.sequence++
   174  		return nil
   175  	}
   176  }
   177  
   178  func (p *packetIO) flush() error {
   179  	err := p.bufWriter.Flush()
   180  	if err != nil {
   181  		return errors.Trace(err)
   182  	}
   183  	return err
   184  }