github.com/ava-labs/avalanchego@v1.11.11/network/peer/upgrader.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package peer
     5  
     6  import (
     7  	"crypto/tls"
     8  	"errors"
     9  	"net"
    10  
    11  	"github.com/prometheus/client_golang/prometheus"
    12  
    13  	"github.com/ava-labs/avalanchego/ids"
    14  	"github.com/ava-labs/avalanchego/staking"
    15  )
    16  
    17  var (
    18  	errNoCert = errors.New("tls handshake finished with no peer certificate")
    19  
    20  	_ Upgrader = (*tlsServerUpgrader)(nil)
    21  	_ Upgrader = (*tlsClientUpgrader)(nil)
    22  )
    23  
    24  type Upgrader interface {
    25  	// Must be thread safe
    26  	Upgrade(net.Conn) (ids.NodeID, net.Conn, *staking.Certificate, error)
    27  }
    28  
    29  type tlsServerUpgrader struct {
    30  	config       *tls.Config
    31  	invalidCerts prometheus.Counter
    32  }
    33  
    34  func NewTLSServerUpgrader(config *tls.Config, invalidCerts prometheus.Counter) Upgrader {
    35  	return &tlsServerUpgrader{
    36  		config:       config,
    37  		invalidCerts: invalidCerts,
    38  	}
    39  }
    40  
    41  func (t *tlsServerUpgrader) Upgrade(conn net.Conn) (ids.NodeID, net.Conn, *staking.Certificate, error) {
    42  	return connToIDAndCert(tls.Server(conn, t.config), t.invalidCerts)
    43  }
    44  
    45  type tlsClientUpgrader struct {
    46  	config       *tls.Config
    47  	invalidCerts prometheus.Counter
    48  }
    49  
    50  func NewTLSClientUpgrader(config *tls.Config, invalidCerts prometheus.Counter) Upgrader {
    51  	return &tlsClientUpgrader{
    52  		config:       config,
    53  		invalidCerts: invalidCerts,
    54  	}
    55  }
    56  
    57  func (t *tlsClientUpgrader) Upgrade(conn net.Conn) (ids.NodeID, net.Conn, *staking.Certificate, error) {
    58  	return connToIDAndCert(tls.Client(conn, t.config), t.invalidCerts)
    59  }
    60  
    61  func connToIDAndCert(conn *tls.Conn, invalidCerts prometheus.Counter) (ids.NodeID, net.Conn, *staking.Certificate, error) {
    62  	if err := conn.Handshake(); err != nil {
    63  		return ids.EmptyNodeID, nil, nil, err
    64  	}
    65  
    66  	state := conn.ConnectionState()
    67  	if len(state.PeerCertificates) == 0 {
    68  		return ids.EmptyNodeID, nil, nil, errNoCert
    69  	}
    70  
    71  	tlsCert := state.PeerCertificates[0]
    72  	peerCert, err := staking.ParseCertificate(tlsCert.Raw)
    73  	if err != nil {
    74  		invalidCerts.Inc()
    75  		return ids.EmptyNodeID, nil, nil, err
    76  	}
    77  
    78  	nodeID := ids.NodeIDFromCert(peerCert)
    79  	return nodeID, conn, peerCert, nil
    80  }