github.com/sagernet/sing-box@v1.9.0-rc.20/transport/v2raygrpc/tls_credentials.go (about)

     1  package v2raygrpc
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"os"
     7  
     8  	"github.com/sagernet/sing-box/common/tls"
     9  	internal_credentials "github.com/sagernet/sing-box/transport/v2raygrpc/credentials"
    10  
    11  	"google.golang.org/grpc/credentials"
    12  )
    13  
    14  type TLSTransportCredentials struct {
    15  	config tls.Config
    16  }
    17  
    18  func NewTLSTransportCredentials(config tls.Config) credentials.TransportCredentials {
    19  	return &TLSTransportCredentials{config}
    20  }
    21  
    22  func (c *TLSTransportCredentials) Info() credentials.ProtocolInfo {
    23  	return credentials.ProtocolInfo{
    24  		SecurityProtocol: "tls",
    25  		SecurityVersion:  "1.2",
    26  		ServerName:       c.config.ServerName(),
    27  	}
    28  }
    29  
    30  func (c *TLSTransportCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
    31  	cfg := c.config.Clone()
    32  	if cfg.ServerName() == "" {
    33  		serverName, _, err := net.SplitHostPort(authority)
    34  		if err != nil {
    35  			serverName = authority
    36  		}
    37  		cfg.SetServerName(serverName)
    38  	}
    39  	conn, err := tls.ClientHandshake(ctx, rawConn, cfg)
    40  	if err != nil {
    41  		return nil, nil, err
    42  	}
    43  	tlsInfo := credentials.TLSInfo{
    44  		State: conn.ConnectionState(),
    45  		CommonAuthInfo: credentials.CommonAuthInfo{
    46  			SecurityLevel: credentials.PrivacyAndIntegrity,
    47  		},
    48  	}
    49  	id := internal_credentials.SPIFFEIDFromState(conn.ConnectionState())
    50  	if id != nil {
    51  		tlsInfo.SPIFFEID = id
    52  	}
    53  	return internal_credentials.WrapSyscallConn(rawConn, conn), tlsInfo, nil
    54  }
    55  
    56  func (c *TLSTransportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
    57  	serverConfig, isServer := c.config.(tls.ServerConfig)
    58  	if !isServer {
    59  		return nil, nil, os.ErrInvalid
    60  	}
    61  	conn, err := tls.ServerHandshake(context.Background(), rawConn, serverConfig)
    62  	if err != nil {
    63  		rawConn.Close()
    64  		return nil, nil, err
    65  	}
    66  	tlsInfo := credentials.TLSInfo{
    67  		State: conn.ConnectionState(),
    68  		CommonAuthInfo: credentials.CommonAuthInfo{
    69  			SecurityLevel: credentials.PrivacyAndIntegrity,
    70  		},
    71  	}
    72  	id := internal_credentials.SPIFFEIDFromState(conn.ConnectionState())
    73  	if id != nil {
    74  		tlsInfo.SPIFFEID = id
    75  	}
    76  	return internal_credentials.WrapSyscallConn(rawConn, conn), tlsInfo, nil
    77  }
    78  
    79  func (c *TLSTransportCredentials) Clone() credentials.TransportCredentials {
    80  	return NewTLSTransportCredentials(c.config)
    81  }
    82  
    83  func (c *TLSTransportCredentials) OverrideServerName(serverNameOverride string) error {
    84  	c.config.SetServerName(serverNameOverride)
    85  	return nil
    86  }