github.com/aergoio/aergo@v1.3.1/p2p/v030/v030io.go (about)

     1  /*
     2   * @file
     3   * @copyright defined in aergo/LICENSE.txt
     4   */
     5  
     6  package v030
     7  
     8  import (
     9  	"bufio"
    10  	"encoding/binary"
    11  	"fmt"
    12  	"io"
    13  
    14  	"github.com/aergoio/aergo/p2p/p2pcommon"
    15  )
    16  
    17  const msgHeaderLength int = 48
    18  
    19  type V030ReadWriter struct {
    20  	r        *bufio.Reader
    21  	readBuf  [msgHeaderLength]byte
    22  	w        *bufio.Writer
    23  	writeBuf [msgHeaderLength]byte
    24  	c        io.Closer
    25  
    26  	ls []p2pcommon.MsgIOListener
    27  }
    28  
    29  func NewV030MsgPipe(s io.ReadWriteCloser) *V030ReadWriter {
    30  	return NewV030ReadWriter(s, s, s)
    31  }
    32  func NewV030ReadWriter(r io.Reader, w io.Writer, c io.Closer) *V030ReadWriter {
    33  	br, ok := r.(*bufio.Reader)
    34  	if !ok {
    35  		br = bufio.NewReader(r)
    36  	}
    37  	bw, ok := w.(*bufio.Writer)
    38  	if !ok {
    39  		bw = bufio.NewWriter(w)
    40  	}
    41  	return &V030ReadWriter{
    42  		r: br,
    43  		w: bw,
    44  		c: c,
    45  	}
    46  }
    47  
    48  func (rw *V030ReadWriter) Close() error {
    49  	return rw.c.Close()
    50  }
    51  
    52  func (rw *V030ReadWriter) AddIOListener(l p2pcommon.MsgIOListener) {
    53  	rw.ls = append(rw.ls, l)
    54  }
    55  
    56  // ReadMsg() must be used in single thread
    57  func (rw *V030ReadWriter) ReadMsg() (p2pcommon.Message, error) {
    58  	readN := 0
    59  	// fill data
    60  	read, err := rw.readToLen(rw.readBuf[:], msgHeaderLength)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	readN += read
    65  	if read != msgHeaderLength {
    66  		return nil, fmt.Errorf("invalid msgHeader")
    67  	}
    68  
    69  	msg, bodyLen := parseHeader(rw.readBuf)
    70  	if bodyLen > p2pcommon.MaxPayloadLength {
    71  		return nil, fmt.Errorf("too big payload")
    72  	}
    73  	payload := make([]byte, bodyLen)
    74  	read, err = rw.readToLen(payload, int(bodyLen))
    75  	if err != nil {
    76  		return nil, fmt.Errorf("failed to read paylod of msg %s %s : %s", msg.Subprotocol().String(), msg.ID(), err.Error())
    77  	}
    78  	readN += read
    79  	if read != int(bodyLen) {
    80  		return nil, fmt.Errorf("failed to read paylod of msg %s %s : payload length mismatch", msg.Subprotocol().String(), msg.ID())
    81  	}
    82  
    83  	msg.SetPayload(payload)
    84  	for _, l := range rw.ls {
    85  		l.OnRead(msg.Subprotocol(), readN)
    86  	}
    87  	return msg, nil
    88  }
    89  
    90  func (rw *V030ReadWriter) readToLen(bf []byte, max int) (int, error) {
    91  	remain := max
    92  	offset := 0
    93  	for remain > 0 {
    94  		read, err := rw.r.Read(bf[offset:])
    95  		if err != nil {
    96  			return offset, err
    97  		}
    98  		remain -= read
    99  		offset += read
   100  	}
   101  	return offset, nil
   102  }
   103  
   104  // WriteMsg() must be used in single thread
   105  func (rw *V030ReadWriter) WriteMsg(msg p2pcommon.Message) error {
   106  	writeN := 0
   107  	if msg.Length() != uint32(len(msg.Payload())) {
   108  		return fmt.Errorf("Invalid payload size")
   109  	}
   110  	if msg.Length() > p2pcommon.MaxPayloadLength {
   111  		return fmt.Errorf("too big payload")
   112  	}
   113  
   114  	rw.marshalHeader(msg)
   115  	written, err := rw.w.Write(rw.writeBuf[:])
   116  	if err != nil {
   117  		return err
   118  	}
   119  	writeN += written
   120  	if written != msgHeaderLength {
   121  		return fmt.Errorf("header is not written")
   122  	}
   123  	written, err = rw.w.Write(msg.Payload())
   124  	if err != nil {
   125  		return err
   126  	}
   127  	writeN += written
   128  	if written != int(msg.Length()) {
   129  		return fmt.Errorf("wrong write")
   130  	}
   131  	for _, l := range rw.ls {
   132  		l.OnWrite(msg.Subprotocol(), writeN)
   133  	}
   134  	return rw.w.Flush()
   135  }
   136  
   137  func parseHeader(buf [msgHeaderLength]byte) (*p2pcommon.MessageValue, uint32) {
   138  	subProtocol := p2pcommon.SubProtocol(binary.BigEndian.Uint32(buf[0:4]))
   139  	length := binary.BigEndian.Uint32(buf[4:8])
   140  	timestamp := int64(binary.BigEndian.Uint64(buf[8:16]))
   141  	msgID := p2pcommon.MustParseBytes(buf[16:32])
   142  	orgID := p2pcommon.MustParseBytes(buf[32:48])
   143  	return p2pcommon.NewLiteMessageValue(subProtocol, msgID, orgID, timestamp), length
   144  }
   145  
   146  func (rw *V030ReadWriter) marshalHeader(m p2pcommon.Message) {
   147  	binary.BigEndian.PutUint32(rw.writeBuf[0:4], m.Subprotocol().Uint32())
   148  	binary.BigEndian.PutUint32(rw.writeBuf[4:8], m.Length())
   149  	binary.BigEndian.PutUint64(rw.writeBuf[8:16], uint64(m.Timestamp()))
   150  
   151  	msgID := m.ID()
   152  	copy(rw.writeBuf[16:32], msgID[:])
   153  	msgID = m.OriginalID()
   154  	copy(rw.writeBuf[32:48], msgID[:])
   155  }