github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/tinyss/socket.go (about)

     1  package tinyss
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/cipher"
     7  	"crypto/hmac"
     8  	"crypto/sha256"
     9  	"encoding/binary"
    10  	"io"
    11  	"net"
    12  	"time"
    13  
    14  	"github.com/geph-official/geph2/libs/c25519"
    15  	pool "github.com/libp2p/go-buffer-pool"
    16  	"golang.org/x/crypto/chacha20poly1305"
    17  	"golang.org/x/crypto/curve25519"
    18  )
    19  
    20  // Socket represents a TinySS connection; it implements net.Conn but with more methods.
    21  type Socket struct {
    22  	rxctr   uint64
    23  	rxerr   error
    24  	rxcrypt cipher.AEAD
    25  	rxbuf   bytes.Buffer
    26  
    27  	txctr   uint64
    28  	txcrypt cipher.AEAD
    29  
    30  	plain         net.Conn
    31  	plainBuffered *bufio.Reader
    32  	sharedsec     []byte
    33  
    34  	nextprot byte
    35  }
    36  
    37  func hm(m, k []byte) []byte {
    38  	h := hmac.New(sha256.New, k)
    39  	h.Write(m)
    40  	return h.Sum(nil)
    41  }
    42  
    43  func aead(key []byte) cipher.AEAD {
    44  	k, e := chacha20poly1305.New(key)
    45  	if e != nil {
    46  		panic(e)
    47  	}
    48  	return k
    49  }
    50  
    51  func newSocket(plain net.Conn, repk, lesk [32]byte) (sok *Socket) {
    52  	// calc
    53  	var lepk [32]byte
    54  	curve25519.ScalarBaseMult(&lepk, &lesk)
    55  	// calculate shared secrets
    56  	var sharedsec [32]byte
    57  	curve25519.ScalarMult(&sharedsec, &lesk, &repk)
    58  	s1 := hm(sharedsec[:], []byte("tinyss-s1"))
    59  	s2 := hm(sharedsec[:], []byte("tinyss-s2"))
    60  	// derive keys
    61  	var rxkey []byte
    62  	var txkey []byte
    63  	if bytes.Compare(lepk[:], repk[:]) < 0 {
    64  		rxkey = s1
    65  		txkey = s2
    66  	} else {
    67  		txkey = s1
    68  		rxkey = s2
    69  	}
    70  	// create socket
    71  	sok = &Socket{
    72  		rxcrypt:       aead(rxkey),
    73  		txcrypt:       aead(txkey),
    74  		plain:         plain,
    75  		sharedsec:     sharedsec[:],
    76  		plainBuffered: bufio.NewReader(plain),
    77  	}
    78  
    79  	return
    80  }
    81  
    82  var decctr1 uint64
    83  
    84  // NextProt returns the "next protocol" signal given by the remote.
    85  func (sk *Socket) NextProt() byte {
    86  	return sk.nextprot
    87  }
    88  
    89  // Read reads into the given byte slice.
    90  func (sk *Socket) Read(p []byte) (n int, err error) {
    91  	// if any in buffer, read from buffer
    92  	if sk.rxbuf.Len() > 0 {
    93  		return sk.rxbuf.Read(p)
    94  	}
    95  	// if error exists, return it
    96  	err = sk.rxerr
    97  	if err != nil {
    98  		return
    99  	}
   100  	// otherwise wait for record
   101  	lenbts := pool.GlobalPool.Get(2)
   102  	defer pool.GlobalPool.Put(lenbts)
   103  	_, err = io.ReadFull(sk.plainBuffered, lenbts)
   104  	if err != nil {
   105  		sk.rxerr = err
   106  		return
   107  	}
   108  	ciph := pool.GlobalPool.Get(int(binary.BigEndian.Uint16(lenbts)))
   109  	defer pool.GlobalPool.Put(ciph)
   110  	_, err = io.ReadFull(sk.plainBuffered, ciph)
   111  	if err != nil {
   112  		sk.rxerr = err
   113  		return
   114  	}
   115  	// decrypt the ciphertext
   116  	nonce := pool.GlobalPool.Get(sk.rxcrypt.NonceSize())
   117  	for i := range nonce {
   118  		nonce[i] = 0
   119  	}
   120  	defer pool.GlobalPool.Put(nonce)
   121  	binary.BigEndian.PutUint64(nonce, sk.rxctr)
   122  	sk.rxctr++
   123  	data, err := sk.rxcrypt.Open(ciph[:0], nonce, ciph, nil)
   124  	if err != nil {
   125  		sk.rxerr = err
   126  		return
   127  	}
   128  	// copy the data into the buffer
   129  	n = copy(p, data)
   130  	if n < len(data) {
   131  		sk.rxbuf.Write(data[n:])
   132  	}
   133  	return
   134  }
   135  
   136  // Write writes out the given byte slice. No guarantees are made regarding the number of low-level segments sent over the wire.
   137  func (sk *Socket) Write(p []byte) (n int, err error) {
   138  	if len(p) > 32768 {
   139  		// recurse
   140  		var n1 int
   141  		var n2 int
   142  		n1, err = sk.Write(p[:32768])
   143  		if err != nil {
   144  			return
   145  		}
   146  		n2, err = sk.Write(p[32768:])
   147  		if err != nil {
   148  			return
   149  		}
   150  		n = n1 + n2
   151  		return
   152  	}
   153  	// main work here
   154  	backing := pool.GlobalPool.Get(sk.txcrypt.Overhead() + 2 + len(p))
   155  	nonce := pool.GlobalPool.Get(sk.txcrypt.NonceSize())
   156  	defer pool.GlobalPool.Put(nonce)
   157  	for i := range nonce {
   158  		nonce[i] = 0
   159  	}
   160  	binary.BigEndian.PutUint64(nonce, sk.txctr)
   161  	sk.txctr++
   162  	ciph := sk.txcrypt.Seal(backing[2:][:0], nonce, p, nil)
   163  	binary.BigEndian.PutUint16(backing[:2], uint16(len(ciph)))
   164  	_, err = sk.plain.Write(backing)
   165  	n = len(p)
   166  	return
   167  }
   168  
   169  // Close closes the socket.
   170  func (sk *Socket) Close() error {
   171  	return sk.plain.Close()
   172  }
   173  
   174  // LocalAddr returns the local address.
   175  func (sk *Socket) LocalAddr() net.Addr {
   176  	return sk.plain.LocalAddr()
   177  }
   178  
   179  // RemoteAddr returns the remote address.
   180  func (sk *Socket) RemoteAddr() net.Addr {
   181  	return sk.plain.RemoteAddr()
   182  }
   183  
   184  // SetDeadline sets the deadline.
   185  func (sk *Socket) SetDeadline(t time.Time) error {
   186  	return sk.plain.SetDeadline(t)
   187  }
   188  
   189  // SetReadDeadline sets the read deadline.
   190  func (sk *Socket) SetReadDeadline(t time.Time) error {
   191  	return sk.plain.SetReadDeadline(t)
   192  }
   193  
   194  // SetWriteDeadline sets the write deadline.
   195  func (sk *Socket) SetWriteDeadline(t time.Time) error {
   196  	return sk.plain.SetWriteDeadline(t)
   197  }
   198  
   199  // SharedSec returns the shared secret. Use this to authenticate the connection (through signing etc).
   200  func (sk *Socket) SharedSec() []byte {
   201  	return sk.sharedsec
   202  }
   203  
   204  // Handshake upgrades a plaintext socket to a MiniSS socket, given our secret key.
   205  func Handshake(plain net.Conn, nextProtocol byte) (sok *Socket, err error) {
   206  	// generate ephemeral key
   207  	myesk := c25519.GenSK()
   208  	// in another thread, send over hello
   209  	wet := make(chan bool)
   210  	go func() {
   211  		var msgb bytes.Buffer
   212  		// if nextProtocol isn't zero, we send a different protocol header
   213  		if nextProtocol == 0 {
   214  			msgb.Write([]byte("TinySS-1"))
   215  		} else {
   216  			msgb.Write([]byte("TinySS-2"))
   217  		}
   218  		var pub [32]byte
   219  		curve25519.ScalarBaseMult(&pub, &myesk)
   220  		msgb.Write(pub[:])
   221  		io.Copy(plain, &msgb)
   222  		close(wet)
   223  	}()
   224  	// read hello
   225  	bts := make([]byte, 32+8)
   226  	_, err = io.ReadFull(plain, bts)
   227  	if err != nil {
   228  		return
   229  	}
   230  	// check version
   231  	if string(bts[:7]) != "TinySS-" {
   232  		err = io.ErrClosedPipe
   233  		return
   234  	}
   235  	<-wet
   236  	// read rest of hello
   237  	var repk [32]byte
   238  	copy(repk[:], bts[8:][:32])
   239  	ns := newSocket(plain, repk, myesk)
   240  	wait := make(chan bool)
   241  	if nextProtocol != 0 {
   242  		go func() {
   243  			binary.Write(ns, binary.BigEndian, nextProtocol)
   244  			close(wait)
   245  		}()
   246  	}
   247  	switch string(bts[:8]) {
   248  	case "TinySS-1":
   249  	case "TinySS-2":
   250  		// then we wait for their next protocol
   251  		var theirNextProt byte
   252  		err = binary.Read(ns, binary.BigEndian, &theirNextProt)
   253  		if err != nil {
   254  			return
   255  		}
   256  		ns.nextprot = theirNextProt
   257  	}
   258  	sok = ns
   259  	if nextProtocol != 0 {
   260  		<-wait
   261  	}
   262  	return
   263  }