github.com/slackhq/nebula@v1.9.0/connection_state.go (about)

     1  package nebula
     2  
     3  import (
     4  	"crypto/rand"
     5  	"encoding/json"
     6  	"sync"
     7  	"sync/atomic"
     8  
     9  	"github.com/flynn/noise"
    10  	"github.com/sirupsen/logrus"
    11  	"github.com/slackhq/nebula/cert"
    12  	"github.com/slackhq/nebula/noiseutil"
    13  )
    14  
    15  const ReplayWindow = 1024
    16  
    17  type ConnectionState struct {
    18  	eKey           *NebulaCipherState
    19  	dKey           *NebulaCipherState
    20  	H              *noise.HandshakeState
    21  	myCert         *cert.NebulaCertificate
    22  	peerCert       *cert.NebulaCertificate
    23  	initiator      bool
    24  	messageCounter atomic.Uint64
    25  	window         *Bits
    26  	writeLock      sync.Mutex
    27  }
    28  
    29  func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
    30  	var dhFunc noise.DHFunc
    31  	switch certState.Certificate.Details.Curve {
    32  	case cert.Curve_CURVE25519:
    33  		dhFunc = noise.DH25519
    34  	case cert.Curve_P256:
    35  		dhFunc = noiseutil.DHP256
    36  	default:
    37  		l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
    38  		return nil
    39  	}
    40  
    41  	var cs noise.CipherSuite
    42  	if cipher == "chachapoly" {
    43  		cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
    44  	} else {
    45  		cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
    46  	}
    47  
    48  	static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
    49  
    50  	b := NewBits(ReplayWindow)
    51  	// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
    52  	b.Update(l, 0)
    53  
    54  	hs, err := noise.NewHandshakeState(noise.Config{
    55  		CipherSuite:           cs,
    56  		Random:                rand.Reader,
    57  		Pattern:               pattern,
    58  		Initiator:             initiator,
    59  		StaticKeypair:         static,
    60  		PresharedKey:          psk,
    61  		PresharedKeyPlacement: pskStage,
    62  	})
    63  	if err != nil {
    64  		return nil
    65  	}
    66  
    67  	// The queue and ready params prevent a counter race that would happen when
    68  	// sending stored packets and simultaneously accepting new traffic.
    69  	ci := &ConnectionState{
    70  		H:         hs,
    71  		initiator: initiator,
    72  		window:    b,
    73  		myCert:    certState.Certificate,
    74  	}
    75  
    76  	return ci
    77  }
    78  
    79  func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
    80  	return json.Marshal(m{
    81  		"certificate":     cs.peerCert,
    82  		"initiator":       cs.initiator,
    83  		"message_counter": cs.messageCounter.Load(),
    84  	})
    85  }