github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/cshirt2/transport.go (about)

     1  package cshirt2
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/cipher"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"time"
    12  
    13  	"github.com/geph-official/geph2/libs/erand"
    14  	pool "github.com/libp2p/go-buffer-pool"
    15  	"golang.org/x/crypto/chacha20poly1305"
    16  )
    17  
    18  type transport struct {
    19  	readCrypt  cipher.AEAD
    20  	readNonce  uint64
    21  	writeCrypt cipher.AEAD
    22  	writeNonce uint64
    23  	wireBuf    *bufio.Reader
    24  	wire       net.Conn
    25  	readbuf    bytes.Buffer
    26  
    27  	buf [128]byte
    28  }
    29  
    30  const maxWriteSize = 16384
    31  
    32  func (tp *transport) getiWriteNonce() uint64 {
    33  	n := tp.writeNonce
    34  	tp.writeNonce++
    35  	return n
    36  }
    37  
    38  func (tp *transport) getiReadNonce() uint64 {
    39  	n := tp.readNonce
    40  	tp.readNonce++
    41  	return n
    42  }
    43  
    44  func addPadding(appendTo, data []byte) []byte {
    45  	padAmount := erand.Int(200)
    46  	appendTo = append(appendTo, byte(padAmount))
    47  	appendTo = append(appendTo, make([]byte, padAmount)...)
    48  	appendTo = append(appendTo, data...)
    49  	return appendTo
    50  	//return data
    51  }
    52  
    53  func remPadding(data []byte) []byte {
    54  	if len(data) > 0 && len(data) > int(data[0]) {
    55  		return data[data[0]+1:]
    56  	}
    57  	return data
    58  }
    59  
    60  func (tp *transport) writeSegment(unpadded []byte) (err error) {
    61  	toWrite := pool.Get(len(unpadded) + 256)[:0]
    62  	defer pool.Put(toWrite)
    63  	toWrite = addPadding(toWrite, unpadded)
    64  	lengthNonce := tp.buf[:12]
    65  	binary.LittleEndian.PutUint64(lengthNonce, tp.getiWriteNonce())
    66  	bodyNonce := tp.buf[12:24]
    67  	binary.LittleEndian.PutUint64(bodyNonce, tp.getiWriteNonce())
    68  	length := tp.buf[24:26]
    69  	binary.LittleEndian.PutUint16(length, uint16(len(toWrite)+tp.readCrypt.Overhead()))
    70  	buffer := pool.Get(maxWriteSize + 128)[:0]
    71  	defer pool.Put(buffer)
    72  	buffer = tp.writeCrypt.Seal(buffer, lengthNonce, length, nil)
    73  	buffer = tp.writeCrypt.Seal(buffer, bodyNonce, toWrite, nil)
    74  	_, err = tp.wire.Write(buffer)
    75  	return
    76  }
    77  
    78  func (tp *transport) Write(p []byte) (n int, err error) {
    79  	ptr := p
    80  	for len(ptr) > maxWriteSize {
    81  		err = tp.writeSegment(ptr[:maxWriteSize])
    82  		if err != nil {
    83  			return
    84  		}
    85  		ptr = ptr[maxWriteSize:]
    86  	}
    87  	err = tp.writeSegment(ptr)
    88  	if err != nil {
    89  		return
    90  	}
    91  	n = len(p)
    92  	return
    93  }
    94  
    95  func (tp *transport) Read(p []byte) (n int, err error) {
    96  	for tp.readbuf.Len() == 0 {
    97  		cryptLength := tp.buf[64:][:2+tp.readCrypt.Overhead()]
    98  		_, err = io.ReadFull(tp.wireBuf, cryptLength)
    99  		if err != nil {
   100  			err = fmt.Errorf("can't read encrypted length: %w", err)
   101  			return
   102  		}
   103  		nonce := tp.buf[32:][:12]
   104  		binary.LittleEndian.PutUint64(nonce, tp.getiReadNonce())
   105  		// decrypt length
   106  		var length []byte
   107  		length, err = tp.readCrypt.Open(p[:0], nonce, cryptLength, nil)
   108  		if err != nil {
   109  			err = fmt.Errorf("can't decrypt length: %w", err)
   110  			return
   111  		}
   112  		// read body
   113  		ctext := pool.Get(int(binary.LittleEndian.Uint16(length)))
   114  		defer pool.Put(ctext)
   115  		_, err = io.ReadFull(tp.wireBuf, ctext)
   116  		if err != nil {
   117  			err = fmt.Errorf("can't read body: %w", err)
   118  			return
   119  		}
   120  		binary.LittleEndian.PutUint64(nonce, tp.getiReadNonce())
   121  		var ptext []byte
   122  		ptext, err = tp.readCrypt.Open(ctext[:0], nonce, ctext, nil)
   123  		if err != nil {
   124  			err = fmt.Errorf("can't decrypt body: %w", err)
   125  			return
   126  		}
   127  		tp.readbuf.Write(remPadding(ptext))
   128  	}
   129  	return tp.readbuf.Read(p)
   130  }
   131  
   132  func (tp *transport) Close() error {
   133  	return tp.wire.Close()
   134  }
   135  
   136  func (tp *transport) LocalAddr() net.Addr {
   137  	return tp.wire.LocalAddr()
   138  }
   139  
   140  func (tp *transport) RemoteAddr() net.Addr {
   141  	return tp.wire.RemoteAddr()
   142  }
   143  
   144  func (tp *transport) SetDeadline(t time.Time) error {
   145  	return tp.wire.SetDeadline(t)
   146  }
   147  
   148  func (tp *transport) SetWriteDeadline(t time.Time) error {
   149  	return tp.wire.SetWriteDeadline(t)
   150  }
   151  
   152  func (tp *transport) SetReadDeadline(t time.Time) error {
   153  	return tp.wire.SetReadDeadline(t)
   154  }
   155  
   156  func newTransport(wire net.Conn, ss []byte, isServer bool) *transport {
   157  	tp := new(transport)
   158  	readKey := mac256(ss, []byte("c2s"))
   159  	writeKey := mac256(ss, []byte("c2c"))
   160  	if !isServer {
   161  		readKey, writeKey = writeKey, readKey
   162  	}
   163  	var err error
   164  	tp.readCrypt, err = chacha20poly1305.New(readKey)
   165  	if err != nil {
   166  		panic(err)
   167  	}
   168  	tp.writeCrypt, err = chacha20poly1305.New(writeKey)
   169  	if err != nil {
   170  		panic(err)
   171  	}
   172  	tp.wire = wire
   173  	tp.wireBuf = bufio.NewReader(wire)
   174  	return tp
   175  }