github.com/xraypb/xray-core@v1.6.6/transport/internet/grpc/dial.go (about)

     1  package grpc
     2  
     3  import (
     4  	"context"
     5  	gonet "net"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/xraypb/xray-core/common"
    10  	"github.com/xraypb/xray-core/common/net"
    11  	"github.com/xraypb/xray-core/common/session"
    12  	"github.com/xraypb/xray-core/transport/internet"
    13  	"github.com/xraypb/xray-core/transport/internet/grpc/encoding"
    14  	"github.com/xraypb/xray-core/transport/internet/stat"
    15  	"github.com/xraypb/xray-core/transport/internet/tls"
    16  	"google.golang.org/grpc"
    17  	"google.golang.org/grpc/backoff"
    18  	"google.golang.org/grpc/connectivity"
    19  	"google.golang.org/grpc/credentials"
    20  	"google.golang.org/grpc/keepalive"
    21  )
    22  
    // Dial is the transport dialer registered for the gRPC transport.
// It logs the attempt (with session.ExportIDToError(ctx) attached), opens a
// tunnelled connection through dialgRPC and exposes it as a stat.Connection.
// Any dialgRPC error is wrapped as "failed to dial gRPC".
func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
	newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))

	tunnel, dialErr := dialgRPC(ctx, dest, streamSettings)
	if dialErr == nil {
		return stat.Connection(tunnel), nil
	}
	return nil, newError("failed to dial gRPC").Base(dialErr)
}
    32  
    33  func init() {
    34  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
    35  }
    36  
    37  type dialerConf struct {
    38  	net.Destination
    39  	*internet.MemoryStreamConfig
    40  }
    41  
    42  var (
    43  	globalDialerMap    map[dialerConf]*grpc.ClientConn
    44  	globalDialerAccess sync.Mutex
    45  )
    46  
    // dialgRPC opens one tunnel stream over a (possibly shared) gRPC client
// connection obtained from getGrpcClient.
//
// When the transport Config has MultiMode set, the stream is opened with
// TunMultiCustomName and wrapped by encoding.NewMultiHunkConn; otherwise it is
// opened with TunCustomName and wrapped by encoding.NewHunkConn. In both cases
// the service name is Config.getNormalizedName(). Every failure is wrapped as
// "Cannot dial gRPC".
func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
	grpcSettings := streamSettings.ProtocolSettings.(*Config)

	conn, err := getGrpcClient(ctx, dest, streamSettings)
	if err != nil {
		return nil, newError("Cannot dial gRPC").Base(err)
	}
	// The generated client is asserted to GRPCServiceClientX once; its
	// Tun*CustomName methods accept the configured service name.
	client := encoding.NewGRPCServiceClient(conn).(encoding.GRPCServiceClientX)

	if grpcSettings.MultiMode {
		// Attach the session ID like Dial does, so this debug line can be
		// correlated with the rest of the connection's log output.
		newError("using gRPC multi mode").AtDebug().WriteToLog(session.ExportIDToError(ctx))
		multiService, err := client.TunMultiCustomName(ctx, grpcSettings.getNormalizedName())
		if err != nil {
			return nil, newError("Cannot dial gRPC").Base(err)
		}
		return encoding.NewMultiHunkConn(multiService, nil), nil
	}

	service, err := client.TunCustomName(ctx, grpcSettings.getNormalizedName())
	if err != nil {
		return nil, newError("Cannot dial gRPC").Base(err)
	}
	return encoding.NewHunkConn(service, nil), nil
}
    71  
    // getGrpcClient returns a gRPC client connection for dest, reusing a cached
// one when possible.
//
// Connections are cached in globalDialerMap under dialerConf{dest, streamSettings},
// and the whole lookup-or-create sequence runs while holding globalDialerAccess.
// A cached connection is reused unless its state is connectivity.Shutdown, in
// which case a new one is dialed and replaces it.
//
// A new connection is configured from streamSettings:
//   - TLS: uTLS (tls.NewGrpcUtls) when tlsConfig.Fingerprint is a key of
//     tls.Fingerprints, standard gRPC TLS otherwise, and grpc.WithInsecure()
//     when no TLS config is present;
//   - keepalive parameters when any of IdleTimeout, HealthCheckTimeout or
//     PermitWithoutStream is set (timeouts are in seconds);
//   - the initial window size when InitialWindowsSize > 0;
//   - a context dialer that dials TCP via internet.DialSystem using
//     streamSettings.SocketSettings.
//
// NOTE: the context dialer copies the session ID and outbound from ctx, i.e.
// from the call that created the connection. Later callers that reuse the
// cached connection do not change what gRPC's (re)connect attempts carry.
//
// Only a successfully created connection is cached: grpc.Dial returns a nil
// *ClientConn on error, and caching that nil would make the next lookup call
// GetState on a nil pointer.
func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*grpc.ClientConn, error) {
	globalDialerAccess.Lock()
	defer globalDialerAccess.Unlock()

	if globalDialerMap == nil {
		globalDialerMap = make(map[dialerConf]*grpc.ClientConn)
	}

	key := dialerConf{dest, streamSettings}
	if client, found := globalDialerMap[key]; found && client.GetState() != connectivity.Shutdown {
		return client, nil
	}

	// Only needed when a new connection has to be created.
	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
	sockopt := streamSettings.SocketSettings
	grpcSettings := streamSettings.ProtocolSettings.(*Config)

	dialOptions := []grpc.DialOption{
		grpc.WithConnectParams(grpc.ConnectParams{
			Backoff: backoff.Config{
				BaseDelay:  500 * time.Millisecond,
				Multiplier: 1.5,
				Jitter:     0.2,
				MaxDelay:   19 * time.Second,
			},
			MinConnectTimeout: 5 * time.Second,
		}),
		grpc.WithContextDialer(func(gctx context.Context, s string) (gonet.Conn, error) {
			// Copy the creating call's session ID and outbound into gRPC's
			// dial context so the underlying system dial can see them.
			gctx = session.ContextWithID(gctx, session.IDFromContext(ctx))
			gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx))

			rawHost, rawPort, err := net.SplitHostPort(s)
			if err != nil {
				return nil, err
			}
			select {
			case <-gctx.Done():
				return nil, gctx.Err()
			default:
			}

			if len(rawPort) == 0 {
				rawPort = "443"
			}
			port, err := net.PortFromString(rawPort)
			if err != nil {
				return nil, err
			}
			address := net.ParseAddress(rawHost)
			return internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt)
		}),
	}

	if tlsConfig != nil {
		var transportCredential credentials.TransportCredentials
		if fingerprint, exists := tls.Fingerprints[tlsConfig.Fingerprint]; exists {
			transportCredential = tls.NewGrpcUtls(tlsConfig.GetTLSConfig(), fingerprint)
		} else { // Fallback to normal gRPC TLS
			transportCredential = credentials.NewTLS(tlsConfig.GetTLSConfig())
		}
		dialOptions = append(dialOptions, grpc.WithTransportCredentials(transportCredential))
	} else {
		dialOptions = append(dialOptions, grpc.WithInsecure())
	}

	if grpcSettings.IdleTimeout > 0 || grpcSettings.HealthCheckTimeout > 0 || grpcSettings.PermitWithoutStream {
		dialOptions = append(dialOptions, grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time:                time.Second * time.Duration(grpcSettings.IdleTimeout),
			Timeout:             time.Second * time.Duration(grpcSettings.HealthCheckTimeout),
			PermitWithoutStream: grpcSettings.PermitWithoutStream,
		}))
	}

	if grpcSettings.InitialWindowsSize > 0 {
		dialOptions = append(dialOptions, grpc.WithInitialWindowSize(grpcSettings.InitialWindowsSize))
	}

	var grpcDestHost string
	if dest.Address.Family().IsDomain() {
		grpcDestHost = dest.Address.Domain()
	} else {
		grpcDestHost = dest.Address.IP().String()
	}

	conn, err := grpc.Dial(
		gonet.JoinHostPort(grpcDestHost, dest.Port.String()),
		dialOptions...,
	)
	if err != nil {
		// Never cache a failed dial: conn is nil here.
		return nil, err
	}
	globalDialerMap[key] = conn
	return conn, nil
}