github.com/annchain/OG@v0.0.9/p2p/raw_transport.go (about)

     1  // Copyright © 2019 Annchain Authors <EMAIL ADDRESS>
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package p2p
    15  
    16  import (
    17  	"bytes"
    18  	"crypto/ecdsa"
    19  	"crypto/rand"
    20  	"encoding/hex"
    21  	"errors"
    22  	"fmt"
    23  	"github.com/annchain/OG/arefactor/common/goroutine"
    24  	"github.com/annchain/OG/deprecated/ogcrypto"
    25  	"github.com/annchain/OG/deprecated/ogcrypto/ecies"
    26  	"github.com/annchain/OG/p2p/ioperformance"
    27  	"github.com/golang/snappy"
    28  	"io"
    29  	"io/ioutil"
    30  	"net"
    31  	"sync"
    32  	"time"
    33  )
    34  
// rawTransport is a transport over a plain net.Conn. Message I/O goes through
// rw, a rawFrameRW that doEncHandshake installs once the raw handshake
// succeeds; until then rw is nil.
//
//go:generate msgp
type rawTransport struct {
	fd net.Conn // underlying connection; read/write deadlines are set on it directly
	// NOTE(review): this embedded MsgReadWriter is never assigned in this file,
	// and ReadMsg/WriteMsg are declared on *rawTransport itself, so it appears
	// to be unused — confirm before removing.
	MsgReadWriter
	rmu, wmu sync.Mutex  // rmu serializes ReadMsg; wmu serializes WriteMsg, close and installing rw
	rw       *rawFrameRW // frame reader/writer; nil until doEncHandshake succeeds
}
    42  
    43  func newrawTransport(fd net.Conn) transport {
    44  	fd.SetDeadline(time.Now().Add(handshakeTimeout))
    45  	return &rawTransport{fd: fd}
    46  }
    47  
    48  func (t *rawTransport) ReadMsg() (Msg, error) {
    49  	t.rmu.Lock()
    50  	defer t.rmu.Unlock()
    51  	t.fd.SetReadDeadline(time.Now().Add(frameReadTimeout))
    52  	return t.rw.ReadMsg()
    53  }
    54  
    55  func (t *rawTransport) WriteMsg(msg Msg) error {
    56  	t.wmu.Lock()
    57  	defer t.wmu.Unlock()
    58  	t.fd.SetWriteDeadline(time.Now().Add(frameWriteTimeout))
    59  	return t.rw.WriteMsg(msg)
    60  }
    61  
    62  func (t *rawTransport) close(err error) {
    63  	t.wmu.Lock()
    64  	defer t.wmu.Unlock()
    65  	// Tell the remote end why we're disconnecting if possible.
    66  	if t.rw != nil {
    67  		if r, ok := err.(DiscReason); ok && r != DiscNetworkError {
    68  			// rawTransport tries to send DiscReason to disconnected peer
    69  			// if the connection is net.Pipe (in-memory simulation)
    70  			// it hangs forever, since net.Pipe does not implement
    71  			// a write deadline. Because of this only try to send
    72  			// the disconnect reason message if there is no error.
    73  			if err := t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)); err == nil {
    74  				b, _ := r.MarshalMsg(nil)
    75  				Send(t.rw, discMsg, b)
    76  			}
    77  		}
    78  	}
    79  	t.fd.Close()
    80  }
    81  
    82  func (t *rawTransport) doProtoHandshake(our *ProtoHandshake) (their *ProtoHandshake, err error) {
    83  	// Writing our handshake happens concurrently, we prefer
    84  	// returning the handshake read error. If the remote side
    85  	// disconnects us early with a valid reason, we should return it
    86  	// as the error so it can be tracked elsewhere.
    87  	werr := make(chan error, 1)
    88  	b, _ := our.MarshalMsg(nil)
    89  	goroutine.New(func() { werr <- Send(t.rw, handshakeMsg, b) })
    90  	if their, err = readProtocolHandshake(t.rw); err != nil {
    91  		<-werr // make sure the write terminates too
    92  		return nil, err
    93  	}
    94  	if err := <-werr; err != nil {
    95  		return nil, fmt.Errorf("write error: %v", err)
    96  	}
    97  	// If the protocol version supports Snappy encoding, upgrade immediately
    98  	t.rw.snappy = their.Version >= snappyProtocolVersion
    99  
   100  	return their, nil
   101  }
   102  
// RawHandshakeMsg is the first message of the raw handshake, sent by the
// dialing side. It is msgp-encoded as a tuple (a positional array, so field
// order is part of the wire format), ECIES-encrypted to the receiver's public
// key and prefixed with a 3-byte length.
//
//msgp:tuple RawHandshakeMsg
type RawHandshakeMsg struct {
	Signature [sigLen]byte // initiator's signature over Nonce; the receiver recovers the initiator's key from it
	Nonce     [shaLen]byte // random nonce chosen by the initiator
	Version   uint32       // handshake version; initiatorRawencHandshake sends 1
}
   109  
// RawHandshakeResponseMsg is the receiver's reply in the raw handshake,
// ECIES-encrypted to the public key recovered from the initiator's
// RawHandshakeMsg and prefixed with a 3-byte length. It is a reduced form of
// an RLPx-style handshake response: it carries only a fresh nonce and a
// version, and unlike RawHandshakeMsg it is not signed. The commented-out
// fields below are not part of the encoding.
//
//msgp:tuple RawHandshakeResponseMsg
type RawHandshakeResponseMsg struct {
	//RemotePubkey [pubLen]byte
	Nonce [shaLen]byte // random nonce chosen by the receiver
	//Signature    [sigLen]byte
	Version uint32 // handshake version; receiverRawHandshake sends 4
}
   118  
// rawHandshake holds the per-connection state of the raw handshake: which
// side started it, the remote public key, and the nonces exchanged.
type rawHandshake struct {
	initiator            bool             // true on the dialing side
	remote               *ecies.PublicKey // remote node's public key, in ECIES form
	initNonce, respNonce []byte           // initiator's and responder's nonces
}
   125  
   126  func initiatorRawencHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remote *ecdsa.PublicKey) (err error) {
   127  	h := &rawHandshake{initiator: true, remote: ecies.ImportECDSAPublic(remote)}
   128  	h.initNonce = make([]byte, shaLen)
   129  	_, err = rand.Read(h.initNonce)
   130  	if err != nil {
   131  		return err
   132  	}
   133  	msg := &RawHandshakeMsg{
   134  		Version: 1,
   135  	}
   136  	copy(msg.Nonce[:], h.initNonce)
   137  	signature, err := ogcrypto.Sign(h.initNonce, prv)
   138  	if err != nil {
   139  		return err
   140  	}
   141  	copy(msg.Signature[:], signature)
   142  	copy(msg.Nonce[:], h.initNonce)
   143  	buf := new(bytes.Buffer)
   144  	b, err := msg.MarshalMsg(nil)
   145  	if err != nil {
   146  		log.WithError(err).Debug("marshal failed")
   147  		return err
   148  	}
   149  	buf.Write(b)
   150  	enc, err := ecies.Encrypt(rand.Reader, h.remote, buf.Bytes(), nil, nil)
   151  	if err != nil {
   152  		log.WithError(err).Debug("enc failed")
   153  		return err
   154  	}
   155  	head := make([]byte, 3)
   156  	putInt24(uint32(len(enc)), head)
   157  	if _, err = conn.Write(head); err != nil {
   158  		log.WithError(err).Debug("write failed")
   159  		return err
   160  	}
   161  	//log.WithField("write len ", len(enc)).WithField("buf size ", len(b)).WithField("msg size ", msg.Msgsize()).Debug("write")
   162  	if _, err = conn.Write(enc); err != nil {
   163  		log.WithError(err).Debug("write failed")
   164  		return err
   165  	}
   166  	log.WithField("nonce ", hex.EncodeToString(msg.Nonce[:])).WithField(
   167  		"sig ", hex.EncodeToString(msg.Signature[:])).Trace("write msg")
   168  	authRespMsg := new(RawHandshakeResponseMsg)
   169  	err = readRawHandshakeMsgResp(authRespMsg, prv, conn)
   170  	if err != nil {
   171  		log.WithError(err).Debug("read response failed")
   172  		return err
   173  	}
   174  	h.respNonce = authRespMsg.Nonce[:]
   175  	return nil
   176  }
   177  
   178  func readRawHandshakeMsgResp(msg *RawHandshakeResponseMsg, prv *ecdsa.PrivateKey, r io.Reader) error {
   179  	head := make([]byte, 3)
   180  	if _, err := io.ReadFull(r, head); err != nil {
   181  		log.WithError(err).Debug("read failed")
   182  		return err
   183  	}
   184  	size := readInt24(head)
   185  	if size == 0 {
   186  		return fmt.Errorf("size error")
   187  	}
   188  
   189  	buf := make([]byte, size)
   190  	if n, err := io.ReadFull(r, buf); err != nil {
   191  		log.WithError(err).WithField("n ", n).WithField("size", size).Debug("read failed")
   192  		return err
   193  	}
   194  	key := ecies.ImportECDSA(prv)
   195  	dec, err := key.Decrypt(buf, nil, nil)
   196  	if err != nil {
   197  		log.WithError(err).Debug("dec failed")
   198  		return err
   199  	}
   200  	_, err = msg.UnmarshalMsg(dec)
   201  	//	fmt.Println("is plain",msg)
   202  	return nil
   203  }
   204  
   205  func readRawHandshakeMsg(msg *RawHandshakeMsg, prv *ecdsa.PrivateKey, r io.Reader) error {
   206  	head := make([]byte, 3)
   207  	if _, err := io.ReadFull(r, head); err != nil {
   208  		log.WithError(err).Debug("read failed")
   209  		return err
   210  	}
   211  	size := readInt24(head)
   212  	if size == 0 {
   213  		return fmt.Errorf("size error")
   214  	}
   215  	buf := make([]byte, size)
   216  	if n, err := io.ReadFull(r, buf); err != nil {
   217  		log.WithError(err).WithField("n ", n).WithField("size", size).Debug("read failed")
   218  		return err
   219  	}
   220  
   221  	key := ecies.ImportECDSA(prv)
   222  	dec, err := key.Decrypt(buf, nil, nil)
   223  	if err != nil {
   224  		log.WithField("len ", len(buf)).WithField("hex ", hex.EncodeToString(buf)).WithError(err).Debug("dec failed")
   225  		return err
   226  	}
   227  	_, err = msg.UnmarshalMsg(dec)
   228  	//	fmt.Println("is plain",msg)
   229  	return nil
   230  }
   231  
   232  func receiverRawHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey) (*ecdsa.PublicKey, error) {
   233  	authMsg := new(RawHandshakeMsg)
   234  	err := readRawHandshakeMsg(authMsg, prv, conn)
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  	//log.WithField("nonce ", hex.EncodeToString(authMsg.Nonce[:])).WithField(
   239  	//"sig ", hex.EncodeToString(authMsg.Signature[:])).Trace("read msg")
   240  	rPub, err := ogcrypto.Ecrecover(authMsg.Nonce[:], authMsg.Signature[:])
   241  	if err != nil {
   242  		return nil, fmt.Errorf("sig invalid %v", err)
   243  	}
   244  	h := new(rawHandshake)
   245  	rpub, err := importPublicKey(rPub[1:])
   246  	if err != nil {
   247  		log.WithError(err).Debug("handle authMsg failed")
   248  		return nil, err
   249  	}
   250  	h.initNonce = authMsg.Nonce[:]
   251  	h.remote = rpub
   252  
   253  	h.respNonce = make([]byte, shaLen)
   254  	if _, err = rand.Read(h.respNonce); err != nil {
   255  		return nil, err
   256  	}
   257  	msg := new(RawHandshakeResponseMsg)
   258  	copy(msg.Nonce[:], h.respNonce)
   259  	msg.Version = 4
   260  	buf := new(bytes.Buffer)
   261  	b, err := msg.MarshalMsg(nil)
   262  	if err != nil {
   263  		return nil, err
   264  	}
   265  	buf.Write(b)
   266  	enc, err := ecies.Encrypt(rand.Reader, h.remote, buf.Bytes(), nil, nil)
   267  	if err != nil {
   268  		log.WithError(err).Debug("enc failed")
   269  		return nil, err
   270  	}
   271  	head := make([]byte, 3)
   272  	putInt24(uint32(len(enc)), head)
   273  	if _, err = conn.Write(head); err != nil {
   274  		return nil, err
   275  	}
   276  	if _, err = conn.Write(enc); err != nil {
   277  		return nil, err
   278  	}
   279  	return h.remote.ExportECDSA(), nil
   280  }
   281  
   282  // doEncHandshake runs the protocol handshake using authenticated
   283  // messages. the protocol handshake is the first authenticated message
   284  // and also verifies whether the encryption handshake 'worked' and the
   285  // remote side actually provided the right public key.
   286  func (t *rawTransport) doEncHandshake(prv *ecdsa.PrivateKey, dial *ecdsa.PublicKey) (*ecdsa.PublicKey, error) {
   287  	var (
   288  		pub *ecdsa.PublicKey
   289  		err error
   290  	)
   291  	if dial == nil {
   292  		pub, err = receiverRawHandshake(t.fd, prv)
   293  	} else {
   294  		err = initiatorRawencHandshake(t.fd, prv, dial)
   295  		pub = dial
   296  	}
   297  	if err != nil {
   298  		log.WithError(err).Debug("handshake error")
   299  		return nil, err
   300  	}
   301  	t.wmu.Lock()
   302  	t.rw = newRawFrameRW(t.fd)
   303  	t.wmu.Unlock()
   304  	return pub, nil
   305  }
   306  
   307  // rawFrameRW implements a simplified version of RLPx framing.
   308  // chunked messages are not supported and all headers are equal to
   309  // zeroHeader.
   310  //
   311  // rawFrameRW is not safe for concurrent use from multiple goroutines.
   312  type rawFrameRW struct {
   313  	conn   io.ReadWriter
   314  	snappy bool
   315  }
   316  
   317  func newRawFrameRW(conn io.ReadWriter) *rawFrameRW {
   318  	return &rawFrameRW{
   319  		conn: conn,
   320  	}
   321  }
   322  
   323  func (rw *rawFrameRW) WriteMsg(msg Msg) error {
   324  	//ptype, _ := rlp.EncodeToBytes(msg.Code)
   325  	ptype, _ := msg.Code.MarshalMsg(nil)
   326  	var payload []byte
   327  	// if snappy is enabled, compress message now
   328  	if rw.snappy {
   329  		if msg.Size > maxUint24 {
   330  			return errPlainMessageTooLarge
   331  		}
   332  		payload, _ = msg.GetPayLoad()
   333  		payload = snappy.Encode(nil, payload)
   334  
   335  		msg.Payload = bytes.NewReader(payload)
   336  		msg.Size = uint32(len(payload))
   337  	} else {
   338  		payload, _ = msg.GetPayLoad()
   339  	}
   340  
   341  	// write header
   342  	headbuf := make([]byte, 6)
   343  	fsize := uint32(len(ptype)) + msg.Size
   344  	if fsize > maxUint24 {
   345  		return errors.New("message size overflows uint24")
   346  	}
   347  	putInt24(fsize, headbuf) // TODO: check overflow
   348  	copy(headbuf[3:], zeroHeader)
   349  	// write header
   350  	if _, err := rw.conn.Write(headbuf); err != nil {
   351  		return err
   352  	}
   353  	if _, err := rw.conn.Write(ptype); err != nil {
   354  		return err
   355  	}
   356  	if _, err := rw.conn.Write(payload); err != nil {
   357  		return err
   358  	}
   359  
   360  	ioperformance.AddSendSize(len(headbuf) + len(ptype) + int(msg.Size))
   361  	return nil
   362  }
   363  
   364  func (rw *rawFrameRW) ReadMsg() (msg Msg, err error) {
   365  	// read the header
   366  	headbuf := make([]byte, 6)
   367  	if _, err := io.ReadFull(rw.conn, headbuf); err != nil {
   368  		return msg, err
   369  	}
   370  	// verify header
   371  	fsize := readInt24(headbuf)
   372  	// ignore protocol type for now
   373  	framebuf := make([]byte, fsize)
   374  	if _, err := io.ReadFull(rw.conn, framebuf); err != nil {
   375  		return msg, err
   376  	}
   377  	out, err := msg.Code.UnmarshalMsg(framebuf[:fsize])
   378  	if err != nil {
   379  		return msg, err
   380  	}
   381  	content := bytes.NewReader(out)
   382  	msg.Size = uint32(content.Len())
   383  	msg.Payload = content
   384  	// if snappy is enabled, verify and decompress message
   385  	if rw.snappy {
   386  		payload, err := ioutil.ReadAll(msg.Payload)
   387  		if err != nil {
   388  			return msg, err
   389  		}
   390  		size, err := snappy.DecodedLen(payload)
   391  		if err != nil {
   392  			return msg, err
   393  		}
   394  		if size > int(maxUint24) {
   395  			return msg, errPlainMessageTooLarge
   396  		}
   397  		payload, err = snappy.Decode(nil, payload)
   398  		if err != nil {
   399  			return msg, err
   400  		}
   401  		msg.Size, msg.Payload = uint32(size), bytes.NewReader(payload)
   402  	}
   403  
   404  	ioperformance.AddRecvSize(int(msg.Size) + len(headbuf) + 3)
   405  
   406  	return msg, nil
   407  }