github.com/linapex/ethereum-dpos-chinese@v0.0.0-20190316121959-b78b3a4a1ece/swarm/pss/handshake.go (about)

     1  
     2  //<developer>
     3  //    <name>linapex 曹一峰</name>
     4  //    <email>linapex@163.com</email>
     5  //    <wx>superexc</wx>
     6  //    <qqgroup>128148617</qqgroup>
     7  //    <url>https://jsq.ink</url>
     8  //    <role>pku engineer</role>
     9  //    <date>2019-03-16 12:09:48</date>
    10  //</624342677457997824>
    11  
    12  //
    13  //
    14  //
    15  //
    16  //
    17  //
    18  //
    19  //
    20  //
    21  //
    22  //
    23  //
    24  //
    25  //
    26  //
    27  
    28  //
    29  
    30  package pss
    31  
    32  import (
    33  	"context"
    34  	"errors"
    35  	"fmt"
    36  	"sync"
    37  	"time"
    38  
    39  	"github.com/ethereum/go-ethereum/common"
    40  	"github.com/ethereum/go-ethereum/common/hexutil"
    41  	"github.com/ethereum/go-ethereum/crypto"
    42  	"github.com/ethereum/go-ethereum/p2p"
    43  	"github.com/ethereum/go-ethereum/rlp"
    44  	"github.com/ethereum/go-ethereum/rpc"
    45  	"github.com/ethereum/go-ethereum/swarm/log"
    46  )
    47  
    48  const (
    49  	IsActiveHandshake = true
    50  )
    51  
    52  var (
    53  	ctrlSingleton *HandshakeController
    54  )
    55  
    56  const (
    57  defaultSymKeyRequestTimeout = 1000 * 8  //
    58  defaultSymKeyExpiryTimeout  = 1000 * 10 //
    59  defaultSymKeySendLimit      = 256       //
    60  defaultSymKeyCapacity       = 4         //
    61  )
    62  
// handshakeMsg is the payload of a handshake message. It is RLP-encoded by
// sendKey and decoded by handler. rlp encodes struct fields positionally, so
// the field order below is part of the wire format and must not change.
//
// From is the sender's pss base address. Limit is the number of messages each
// key in Keys may be used for. Keys holds the raw symmetric keys handed to the
// receiver. Request is the number of keys the sender asks for in return.
// Topic is the topic the keys apply to.
type handshakeMsg struct {
	From    []byte
	Limit   uint16
	Keys    [][]byte
	Request uint8
	Topic   Topic
}
    71  
    72  //
    73  type handshakeKey struct {
    74  	symKeyID  *string
    75  	pubKeyID  *string
    76  	limit     uint16
    77  	count     uint16
    78  	expiredAt time.Time
    79  }
    80  
    81  //
    82  //
    83  type handshake struct {
    84  	outKeys []handshakeKey
    85  	inKeys  []handshakeKey
    86  }
    87  
    88  //
    89  //
    90  //
    91  //
    92  //
    93  //
    94  //
    95  //
    96  //
    97  //
    98  type HandshakeParams struct {
    99  	SymKeyRequestTimeout time.Duration
   100  	SymKeyExpiryTimeout  time.Duration
   101  	SymKeySendLimit      uint16
   102  	SymKeyCapacity       uint8
   103  }
   104  
   105  //
   106  func NewHandshakeParams() *HandshakeParams {
   107  	return &HandshakeParams{
   108  		SymKeyRequestTimeout: defaultSymKeyRequestTimeout * time.Millisecond,
   109  		SymKeyExpiryTimeout:  defaultSymKeyExpiryTimeout * time.Millisecond,
   110  		SymKeySendLimit:      defaultSymKeySendLimit,
   111  		SymKeyCapacity:       defaultSymKeyCapacity,
   112  	}
   113  }
   114  
   115  //
   116  //
   117  type HandshakeController struct {
   118  	pss                  *Pss
   119  keyC                 map[string]chan []string //
   120  	lock                 sync.Mutex
   121  	symKeyRequestTimeout time.Duration
   122  	symKeyExpiryTimeout  time.Duration
   123  	symKeySendLimit      uint16
   124  	symKeyCapacity       uint8
   125  	symKeyIndex          map[string]*handshakeKey
   126  	handshakes           map[string]map[Topic]*handshake
   127  	deregisterFuncs      map[Topic]func()
   128  }
   129  
   130  //
   131  //
   132  //
   133  func SetHandshakeController(pss *Pss, params *HandshakeParams) error {
   134  	ctrl := &HandshakeController{
   135  		pss:                  pss,
   136  		keyC:                 make(map[string]chan []string),
   137  		symKeyRequestTimeout: params.SymKeyRequestTimeout,
   138  		symKeyExpiryTimeout:  params.SymKeyExpiryTimeout,
   139  		symKeySendLimit:      params.SymKeySendLimit,
   140  		symKeyCapacity:       params.SymKeyCapacity,
   141  		symKeyIndex:          make(map[string]*handshakeKey),
   142  		handshakes:           make(map[string]map[Topic]*handshake),
   143  		deregisterFuncs:      make(map[Topic]func()),
   144  	}
   145  	api := &HandshakeAPI{
   146  		namespace: "pss",
   147  		ctrl:      ctrl,
   148  	}
   149  	pss.addAPI(rpc.API{
   150  		Namespace: api.namespace,
   151  		Version:   "0.2",
   152  		Service:   api,
   153  		Public:    true,
   154  	})
   155  	ctrlSingleton = ctrl
   156  	return nil
   157  }
   158  
   159  //
   160  //
   161  func (ctl *HandshakeController) validKeys(pubkeyid string, topic *Topic, in bool) (validkeys []*string) {
   162  	ctl.lock.Lock()
   163  	defer ctl.lock.Unlock()
   164  	now := time.Now()
   165  	if _, ok := ctl.handshakes[pubkeyid]; !ok {
   166  		return []*string{}
   167  	} else if _, ok := ctl.handshakes[pubkeyid][*topic]; !ok {
   168  		return []*string{}
   169  	}
   170  	var keystore *[]handshakeKey
   171  	if in {
   172  		keystore = &(ctl.handshakes[pubkeyid][*topic].inKeys)
   173  	} else {
   174  		keystore = &(ctl.handshakes[pubkeyid][*topic].outKeys)
   175  	}
   176  
   177  	for _, key := range *keystore {
   178  		if key.limit <= key.count {
   179  			ctl.releaseKey(*key.symKeyID, topic)
   180  		} else if !key.expiredAt.IsZero() && key.expiredAt.Before(now) {
   181  			ctl.releaseKey(*key.symKeyID, topic)
   182  		} else {
   183  			validkeys = append(validkeys, key.symKeyID)
   184  		}
   185  	}
   186  	return
   187  }
   188  
// updateKeys records symkeyids as new keys exchanged with the peer pubkeyid on
// topic. They go into the inbound key set (in == true) or the outbound one,
// and each may be used for limit messages. ctl.lock is taken here, so callers
// must not hold it.
//
// Every added key is marked protected in the pss symkey pool. Afterwards all
// keys of the affected set are (re)indexed in ctl.symKeyIndex, because append
// may have moved the set to a new backing array and stale pointers would
// otherwise remain.
func (ctl *HandshakeController) updateKeys(pubkeyid string, topic *Topic, in bool, symkeyids []string, limit uint16) {
	ctl.lock.Lock()
	defer ctl.lock.Unlock()
	if _, ok := ctl.handshakes[pubkeyid]; !ok {
		ctl.handshakes[pubkeyid] = make(map[Topic]*handshake)

	}
	if ctl.handshakes[pubkeyid][*topic] == nil {
		ctl.handshakes[pubkeyid][*topic] = &handshake{}
	}
	var keystore *[]handshakeKey
	expire := time.Now()
	if in {
		keystore = &(ctl.handshakes[pubkeyid][*topic].inKeys)
	} else {
		keystore = &(ctl.handshakes[pubkeyid][*topic].outKeys)
		// NOTE(review): symKeyExpiryTimeout is already a time.Duration
		// (SetHandshakeController copies it from HandshakeParams), so this
		// multiplies by time.Millisecond a second time. The result currently
		// has no effect; see the loop below.
		expire = expire.Add(time.Millisecond * ctl.symKeyExpiryTimeout)
	}
	// NOTE(review): storekey is a copy of each element, so this assignment
	// never reaches the stored keys; existing keys are not stamped here.
	// Fixing it (index into *keystore and drop the extra Millisecond) would
	// start expiring previously stored keys whenever new ones arrive. Confirm
	// that is intended before changing it.
	for _, storekey := range *keystore {
		storekey.expiredAt = expire
	}
	for i := 0; i < len(symkeyids); i++ {
		// The record points into the caller's slice and at this call's copy of
		// pubkeyid.
		storekey := handshakeKey{
			symKeyID: &symkeyids[i],
			pubKeyID: &pubkeyid,
			limit:    limit,
		}
		*keystore = append(*keystore, storekey)
		ctl.pss.symKeyPool[*storekey.symKeyID][*topic].protected = true
	}
	// Repoint the index at every element; append may have reallocated.
	for i := 0; i < len(*keystore); i++ {
		ctl.symKeyIndex[*(*keystore)[i].symKeyID] = &((*keystore)[i])
	}
}
   225  
   226  //
   227  func (ctl *HandshakeController) releaseKey(symkeyid string, topic *Topic) bool {
   228  	if ctl.symKeyIndex[symkeyid] == nil {
   229  		log.Debug("no symkey", "symkeyid", symkeyid)
   230  		return false
   231  	}
   232  	ctl.symKeyIndex[symkeyid].expiredAt = time.Now()
   233  	log.Debug("handshake release", "symkeyid", symkeyid)
   234  	return true
   235  }
   236  
   237  //
   238  //
   239  //
   240  //
   241  //
   242  func (ctl *HandshakeController) cleanHandshake(pubkeyid string, topic *Topic, in bool, out bool) int {
   243  	ctl.lock.Lock()
   244  	defer ctl.lock.Unlock()
   245  	var deletecount int
   246  	var deletes []string
   247  	now := time.Now()
   248  	handshake := ctl.handshakes[pubkeyid][*topic]
   249  	log.Debug("handshake clean", "pubkey", pubkeyid, "topic", topic)
   250  	if in {
   251  		for i, key := range handshake.inKeys {
   252  			if key.expiredAt.Before(now) || (key.expiredAt.IsZero() && key.limit <= key.count) {
   253  				log.Trace("handshake in clean remove", "symkeyid", *key.symKeyID)
   254  				deletes = append(deletes, *key.symKeyID)
   255  				handshake.inKeys[deletecount] = handshake.inKeys[i]
   256  				deletecount++
   257  			}
   258  		}
   259  		handshake.inKeys = handshake.inKeys[:len(handshake.inKeys)-deletecount]
   260  	}
   261  	if out {
   262  		deletecount = 0
   263  		for i, key := range handshake.outKeys {
   264  			if key.expiredAt.Before(now) && (key.expiredAt.IsZero() && key.limit <= key.count) {
   265  				log.Trace("handshake out clean remove", "symkeyid", *key.symKeyID)
   266  				deletes = append(deletes, *key.symKeyID)
   267  				handshake.outKeys[deletecount] = handshake.outKeys[i]
   268  				deletecount++
   269  			}
   270  		}
   271  		handshake.outKeys = handshake.outKeys[:len(handshake.outKeys)-deletecount]
   272  	}
   273  	for _, keyid := range deletes {
   274  		delete(ctl.symKeyIndex, keyid)
   275  		ctl.pss.symKeyPool[keyid][*topic].protected = false
   276  	}
   277  	return len(deletes)
   278  }
   279  
   280  //
   281  func (ctl *HandshakeController) clean() {
   282  	peerpubkeys := ctl.handshakes
   283  	for pubkeyid, peertopics := range peerpubkeys {
   284  		for topic := range peertopics {
   285  			ctl.cleanHandshake(pubkeyid, &topic, true, true)
   286  		}
   287  	}
   288  }
   289  
   290  //
   291  //
   292  //
   293  //
   294  func (ctl *HandshakeController) handler(msg []byte, p *p2p.Peer, asymmetric bool, symkeyid string) error {
   295  	if !asymmetric {
   296  		if ctl.symKeyIndex[symkeyid] != nil {
   297  			if ctl.symKeyIndex[symkeyid].count >= ctl.symKeyIndex[symkeyid].limit {
   298  				return fmt.Errorf("discarding message using expired key: %s", symkeyid)
   299  			}
   300  			ctl.symKeyIndex[symkeyid].count++
   301  			log.Trace("increment symkey recv use", "symsymkeyid", symkeyid, "count", ctl.symKeyIndex[symkeyid].count, "limit", ctl.symKeyIndex[symkeyid].limit, "receiver", common.ToHex(crypto.FromECDSAPub(ctl.pss.PublicKey())))
   302  		}
   303  		return nil
   304  	}
   305  	keymsg := &handshakeMsg{}
   306  	err := rlp.DecodeBytes(msg, keymsg)
   307  	if err == nil {
   308  		err := ctl.handleKeys(symkeyid, keymsg)
   309  		if err != nil {
   310  			log.Error("handlekeys fail", "error", err)
   311  		}
   312  		return err
   313  	}
   314  	return nil
   315  }
   316  
   317  //
   318  //
   319  //
   320  //
   321  //
   322  //
   323  //
   324  //
   325  //
   326  //
   327  func (ctl *HandshakeController) handleKeys(pubkeyid string, keymsg *handshakeMsg) error {
   328  //
   329  	if len(keymsg.Keys) > 0 {
   330  		log.Debug("received handshake keys", "pubkeyid", pubkeyid, "from", keymsg.From, "count", len(keymsg.Keys))
   331  		var sendsymkeyids []string
   332  		for _, key := range keymsg.Keys {
   333  			sendsymkey := make([]byte, len(key))
   334  			copy(sendsymkey, key)
   335  			var address PssAddress
   336  			copy(address[:], keymsg.From)
   337  			sendsymkeyid, err := ctl.pss.setSymmetricKey(sendsymkey, keymsg.Topic, &address, false, false)
   338  			if err != nil {
   339  				return err
   340  			}
   341  			sendsymkeyids = append(sendsymkeyids, sendsymkeyid)
   342  		}
   343  		if len(sendsymkeyids) > 0 {
   344  			ctl.updateKeys(pubkeyid, &keymsg.Topic, false, sendsymkeyids, keymsg.Limit)
   345  
   346  			ctl.alertHandshake(pubkeyid, sendsymkeyids)
   347  		}
   348  	}
   349  
   350  //
   351  	if keymsg.Request > 0 {
   352  		_, err := ctl.sendKey(pubkeyid, &keymsg.Topic, keymsg.Request)
   353  		if err != nil {
   354  			return err
   355  		}
   356  	}
   357  
   358  	return nil
   359  }
   360  
   361  //
   362  //
   363  //
   364  //
   365  //
   366  //
   367  func (ctl *HandshakeController) sendKey(pubkeyid string, topic *Topic, keycount uint8) ([]string, error) {
   368  
   369  	var requestcount uint8
   370  	to := &PssAddress{}
   371  	if _, ok := ctl.pss.pubKeyPool[pubkeyid]; !ok {
   372  		return []string{}, errors.New("Invalid public key")
   373  	} else if psp, ok := ctl.pss.pubKeyPool[pubkeyid][*topic]; ok {
   374  		to = psp.address
   375  	}
   376  
   377  	recvkeys := make([][]byte, keycount)
   378  	recvkeyids := make([]string, keycount)
   379  	ctl.lock.Lock()
   380  	if _, ok := ctl.handshakes[pubkeyid]; !ok {
   381  		ctl.handshakes[pubkeyid] = make(map[Topic]*handshake)
   382  	}
   383  	ctl.lock.Unlock()
   384  
   385  //
   386  	outkeys := ctl.validKeys(pubkeyid, topic, false)
   387  	if len(outkeys) < int(ctl.symKeyCapacity) {
   388  //
   389  		requestcount = ctl.symKeyCapacity
   390  	}
   391  //
   392  	if requestcount == 0 && keycount == 0 {
   393  		return []string{}, nil
   394  	}
   395  
   396  //
   397  	for i := 0; i < len(recvkeyids); i++ {
   398  		var err error
   399  		recvkeyids[i], err = ctl.pss.GenerateSymmetricKey(*topic, to, true)
   400  		if err != nil {
   401  			return []string{}, fmt.Errorf("set receive symkey fail (pubkey %x topic %x): %v", pubkeyid, topic, err)
   402  		}
   403  		recvkeys[i], err = ctl.pss.GetSymmetricKey(recvkeyids[i])
   404  		if err != nil {
   405  			return []string{}, fmt.Errorf("GET Generated outgoing symkey fail (pubkey %x topic %x): %v", pubkeyid, topic, err)
   406  		}
   407  	}
   408  	ctl.updateKeys(pubkeyid, topic, true, recvkeyids, ctl.symKeySendLimit)
   409  
   410  //
   411  	recvkeymsg := &handshakeMsg{
   412  		From:    ctl.pss.BaseAddr(),
   413  		Keys:    recvkeys,
   414  		Request: requestcount,
   415  		Limit:   ctl.symKeySendLimit,
   416  		Topic:   *topic,
   417  	}
   418  	log.Debug("sending our symkeys", "pubkey", pubkeyid, "symkeys", recvkeyids, "limit", ctl.symKeySendLimit, "requestcount", requestcount, "keycount", len(recvkeys))
   419  	recvkeybytes, err := rlp.EncodeToBytes(recvkeymsg)
   420  	if err != nil {
   421  		return []string{}, fmt.Errorf("rlp keymsg encode fail: %v", err)
   422  	}
   423  //
   424  	err = ctl.pss.SendAsym(pubkeyid, *topic, recvkeybytes)
   425  	if err != nil {
   426  		return []string{}, fmt.Errorf("Send symkey failed: %v", err)
   427  	}
   428  	return recvkeyids, nil
   429  }
   430  
   431  //
   432  func (ctl *HandshakeController) alertHandshake(pubkeyid string, symkeys []string) chan []string {
   433  	if len(symkeys) > 0 {
   434  		if _, ok := ctl.keyC[pubkeyid]; ok {
   435  			ctl.keyC[pubkeyid] <- symkeys
   436  			close(ctl.keyC[pubkeyid])
   437  			delete(ctl.keyC, pubkeyid)
   438  		}
   439  		return nil
   440  	}
   441  	if _, ok := ctl.keyC[pubkeyid]; !ok {
   442  		ctl.keyC[pubkeyid] = make(chan []string)
   443  	}
   444  	return ctl.keyC[pubkeyid]
   445  }
   446  
// HandshakeAPI is the RPC service exposing a HandshakeController.
// SetHandshakeController creates it and registers it under namespace.
type HandshakeAPI struct {
	namespace string
	ctrl      *HandshakeController
}
   451  
   452  //
   453  //
   454  //
   455  //
   456  //
   457  //
   458  //
   459  //
   460  //
   461  //
   462  //
   463  //
   464  //
   465  //
   466  func (api *HandshakeAPI) Handshake(pubkeyid string, topic Topic, sync bool, flush bool) (keys []string, err error) {
   467  	var hsc chan []string
   468  	var keycount uint8
   469  	if flush {
   470  		keycount = api.ctrl.symKeyCapacity
   471  	} else {
   472  		validkeys := api.ctrl.validKeys(pubkeyid, &topic, false)
   473  		keycount = api.ctrl.symKeyCapacity - uint8(len(validkeys))
   474  	}
   475  	if keycount == 0 {
   476  		return keys, errors.New("Incoming symmetric key store is already full")
   477  	}
   478  	if sync {
   479  		hsc = api.ctrl.alertHandshake(pubkeyid, []string{})
   480  	}
   481  	_, err = api.ctrl.sendKey(pubkeyid, &topic, keycount)
   482  	if err != nil {
   483  		return keys, err
   484  	}
   485  	if sync {
   486  		ctx, cancel := context.WithTimeout(context.Background(), api.ctrl.symKeyRequestTimeout)
   487  		defer cancel()
   488  		select {
   489  		case keys = <-hsc:
   490  			log.Trace("sync handshake response receive", "key", keys)
   491  		case <-ctx.Done():
   492  			return []string{}, errors.New("timeout")
   493  		}
   494  	}
   495  	return keys, nil
   496  }
   497  
   498  //
   499  func (api *HandshakeAPI) AddHandshake(topic Topic) error {
   500  	api.ctrl.deregisterFuncs[topic] = api.ctrl.pss.Register(&topic, api.ctrl.handler)
   501  	return nil
   502  }
   503  
   504  //
   505  func (api *HandshakeAPI) RemoveHandshake(topic *Topic) error {
   506  	if _, ok := api.ctrl.deregisterFuncs[*topic]; ok {
   507  		api.ctrl.deregisterFuncs[*topic]()
   508  	}
   509  	return nil
   510  }
   511  
   512  //
   513  //
   514  //
   515  //
   516  //
   517  //
   518  func (api *HandshakeAPI) GetHandshakeKeys(pubkeyid string, topic Topic, in bool, out bool) (keys []string, err error) {
   519  	if in {
   520  		for _, inkey := range api.ctrl.validKeys(pubkeyid, &topic, true) {
   521  			keys = append(keys, *inkey)
   522  		}
   523  	}
   524  	if out {
   525  		for _, outkey := range api.ctrl.validKeys(pubkeyid, &topic, false) {
   526  			keys = append(keys, *outkey)
   527  		}
   528  	}
   529  	return keys, nil
   530  }
   531  
   532  //
   533  //
   534  func (api *HandshakeAPI) GetHandshakeKeyCapacity(symkeyid string) (uint16, error) {
   535  	storekey := api.ctrl.symKeyIndex[symkeyid]
   536  	if storekey == nil {
   537  		return 0, fmt.Errorf("invalid symkey id %s", symkeyid)
   538  	}
   539  	return storekey.limit - storekey.count, nil
   540  }
   541  
   542  //
   543  //
   544  func (api *HandshakeAPI) GetHandshakePublicKey(symkeyid string) (string, error) {
   545  	storekey := api.ctrl.symKeyIndex[symkeyid]
   546  	if storekey == nil {
   547  		return "", fmt.Errorf("invalid symkey id %s", symkeyid)
   548  	}
   549  	return *storekey.pubKeyID, nil
   550  }
   551  
   552  //
   553  //
   554  //
   555  //
   556  //
   557  func (api *HandshakeAPI) ReleaseHandshakeKey(pubkeyid string, topic Topic, symkeyid string, flush bool) (removed bool, err error) {
   558  	removed = api.ctrl.releaseKey(symkeyid, &topic)
   559  	if removed && flush {
   560  		api.ctrl.cleanHandshake(pubkeyid, &topic, true, true)
   561  	}
   562  	return
   563  }
   564  
   565  //
   566  //
   567  //
   568  //
   569  func (api *HandshakeAPI) SendSym(symkeyid string, topic Topic, msg hexutil.Bytes) (err error) {
   570  	err = api.ctrl.pss.SendSym(symkeyid, topic, msg[:])
   571  	if api.ctrl.symKeyIndex[symkeyid] != nil {
   572  		if api.ctrl.symKeyIndex[symkeyid].count >= api.ctrl.symKeyIndex[symkeyid].limit {
   573  			return errors.New("attempted send with expired key")
   574  		}
   575  		api.ctrl.symKeyIndex[symkeyid].count++
   576  		log.Trace("increment symkey send use", "symkeyid", symkeyid, "count", api.ctrl.symKeyIndex[symkeyid].count, "limit", api.ctrl.symKeyIndex[symkeyid].limit, "receiver", common.ToHex(crypto.FromECDSAPub(api.ctrl.pss.PublicKey())))
   577  	}
   578  	return
   579  }
   580