github.com/EagleQL/Xray-core@v1.4.3/transport/internet/grpc/dial.go (about)

     1  package grpc
     2  
     3  import (
     4  	"context"
     5  	gonet "net"
     6  	"sync"
     7  	"time"
     8  
     9  	"google.golang.org/grpc"
    10  	"google.golang.org/grpc/backoff"
    11  	"google.golang.org/grpc/connectivity"
    12  	"google.golang.org/grpc/credentials"
    13  
    14  	"github.com/xtls/xray-core/common"
    15  	"github.com/xtls/xray-core/common/net"
    16  	"github.com/xtls/xray-core/common/session"
    17  	"github.com/xtls/xray-core/transport/internet"
    18  	"github.com/xtls/xray-core/transport/internet/grpc/encoding"
    19  	"github.com/xtls/xray-core/transport/internet/tls"
    20  )
    21  
    22  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
    23  	newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
    24  
    25  	conn, err := dialgRPC(ctx, dest, streamSettings)
    26  	if err != nil {
    27  		return nil, newError("failed to dial gRPC").Base(err)
    28  	}
    29  	return internet.Connection(conn), nil
    30  }
    31  
    32  func init() {
    33  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
    34  }
    35  
// dialerConf is the cache key for globalDialerMap: one *grpc.ClientConn is
// kept per distinct combination of destination, socket options and TLS
// configuration.
//
// The socket and TLS settings are embedded as pointers, so two keys are equal
// only when they refer to the same config objects, not merely to configs with
// equal contents.
type dialerConf struct {
	net.Destination
	*internet.SocketConfig
	*tls.Config
}
    41  
    42  var (
    43  	globalDialerMap    map[dialerConf]*grpc.ClientConn
    44  	globalDialerAccess sync.Mutex
    45  )
    46  
    47  func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
    48  	grpcSettings := streamSettings.ProtocolSettings.(*Config)
    49  
    50  	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
    51  
    52  	conn, err := getGrpcClient(ctx, dest, tlsConfig, streamSettings.SocketSettings)
    53  
    54  	if err != nil {
    55  		return nil, newError("Cannot dial gRPC").Base(err)
    56  	}
    57  	client := encoding.NewGRPCServiceClient(conn)
    58  	if grpcSettings.MultiMode {
    59  		newError("using gRPC multi mode").AtDebug().WriteToLog()
    60  		grpcService, err := client.(encoding.GRPCServiceClientX).TunMultiCustomName(ctx, grpcSettings.ServiceName)
    61  		if err != nil {
    62  			return nil, newError("Cannot dial gRPC").Base(err)
    63  		}
    64  		return encoding.NewMultiHunkConn(grpcService, nil), nil
    65  	}
    66  
    67  	grpcService, err := client.(encoding.GRPCServiceClientX).TunCustomName(ctx, grpcSettings.ServiceName)
    68  	if err != nil {
    69  		return nil, newError("Cannot dial gRPC").Base(err)
    70  	}
    71  
    72  	return encoding.NewHunkConn(grpcService, nil), nil
    73  }
    74  
    75  func getGrpcClient(ctx context.Context, dest net.Destination, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (*grpc.ClientConn, error) {
    76  	globalDialerAccess.Lock()
    77  	defer globalDialerAccess.Unlock()
    78  
    79  	if globalDialerMap == nil {
    80  		globalDialerMap = make(map[dialerConf]*grpc.ClientConn)
    81  	}
    82  
    83  	if client, found := globalDialerMap[dialerConf{dest, sockopt, tlsConfig}]; found && client.GetState() != connectivity.Shutdown {
    84  		return client, nil
    85  	}
    86  
    87  	dialOption := grpc.WithInsecure()
    88  
    89  	if tlsConfig != nil {
    90  		dialOption = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig.GetTLSConfig()))
    91  	}
    92  
    93  	var grpcDestHost string
    94  	if dest.Address.Family().IsDomain() {
    95  		grpcDestHost = dest.Address.Domain()
    96  	} else {
    97  		grpcDestHost = dest.Address.IP().String()
    98  	}
    99  	conn, err := grpc.Dial(
   100  		gonet.JoinHostPort(grpcDestHost, dest.Port.String()),
   101  		dialOption,
   102  		grpc.WithConnectParams(grpc.ConnectParams{
   103  			Backoff: backoff.Config{
   104  				BaseDelay:  500 * time.Millisecond,
   105  				Multiplier: 1.5,
   106  				Jitter:     0.2,
   107  				MaxDelay:   19 * time.Second,
   108  			},
   109  			MinConnectTimeout: 5 * time.Second,
   110  		}),
   111  		grpc.WithContextDialer(func(gctx context.Context, s string) (gonet.Conn, error) {
   112  			gctx = session.ContextWithID(gctx, session.IDFromContext(ctx))
   113  			gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx))
   114  
   115  			rawHost, rawPort, err := net.SplitHostPort(s)
   116  			select {
   117  			case <-gctx.Done():
   118  				return nil, gctx.Err()
   119  			default:
   120  			}
   121  
   122  			if err != nil {
   123  				return nil, err
   124  			}
   125  			if len(rawPort) == 0 {
   126  				rawPort = "443"
   127  			}
   128  			port, err := net.PortFromString(rawPort)
   129  			if err != nil {
   130  				return nil, err
   131  			}
   132  			address := net.ParseAddress(rawHost)
   133  			return internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt)
   134  		}),
   135  	)
   136  	globalDialerMap[dialerConf{dest, sockopt, tlsConfig}] = conn
   137  	return conn, err
   138  }