github.com/xraypb/Xray-core@v1.8.1/transport/internet/grpc/dial.go (about)

     1  package grpc
     2  
     3  import (
     4  	"context"
     5  	gonet "net"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/xraypb/Xray-core/common"
    10  	"github.com/xraypb/Xray-core/common/net"
    11  	"github.com/xraypb/Xray-core/common/session"
    12  	"github.com/xraypb/Xray-core/transport/internet"
    13  	"github.com/xraypb/Xray-core/transport/internet/grpc/encoding"
    14  	"github.com/xraypb/Xray-core/transport/internet/reality"
    15  	"github.com/xraypb/Xray-core/transport/internet/stat"
    16  	"github.com/xraypb/Xray-core/transport/internet/tls"
    17  	"google.golang.org/grpc"
    18  	"google.golang.org/grpc/backoff"
    19  	"google.golang.org/grpc/connectivity"
    20  	"google.golang.org/grpc/credentials"
    21  	"google.golang.org/grpc/keepalive"
    22  )
    23  
    24  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
    25  	newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
    26  
    27  	conn, err := dialgRPC(ctx, dest, streamSettings)
    28  	if err != nil {
    29  		return nil, newError("failed to dial gRPC").Base(err)
    30  	}
    31  	return stat.Connection(conn), nil
    32  }
    33  
    34  func init() {
    35  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
    36  }
    37  
// dialerConf is the key type of globalDialerMap: a destination together with
// the stream settings used to reach it. The settings are embedded by pointer,
// so two keys are equal only when they reference the same
// *internet.MemoryStreamConfig value, not merely settings with equal contents.
type dialerConf struct {
	net.Destination
	*internet.MemoryStreamConfig
}
    42  
var (
	// globalDialerMap caches one gRPC ClientConn per dialerConf so that
	// repeated dials with the same destination and settings share a
	// connection. It is allocated lazily by getGrpcClient.
	globalDialerMap    map[dialerConf]*grpc.ClientConn
	// globalDialerAccess guards globalDialerMap; getGrpcClient holds it for
	// its entire lookup-or-dial sequence.
	globalDialerAccess sync.Mutex
)
    47  
    48  func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
    49  	grpcSettings := streamSettings.ProtocolSettings.(*Config)
    50  
    51  	conn, err := getGrpcClient(ctx, dest, streamSettings)
    52  	if err != nil {
    53  		return nil, newError("Cannot dial gRPC").Base(err)
    54  	}
    55  	client := encoding.NewGRPCServiceClient(conn)
    56  	if grpcSettings.MultiMode {
    57  		newError("using gRPC multi mode service name: `" + grpcSettings.getServiceName() + "` stream name: `" + grpcSettings.getTunMultiStreamName() + "`").AtDebug().WriteToLog()
    58  		grpcService, err := client.(encoding.GRPCServiceClientX).TunMultiCustomName(ctx, grpcSettings.getServiceName(), grpcSettings.getTunMultiStreamName())
    59  		if err != nil {
    60  			return nil, newError("Cannot dial gRPC").Base(err)
    61  		}
    62  		return encoding.NewMultiHunkConn(grpcService, nil), nil
    63  	}
    64  
    65  	newError("using gRPC tun mode service name: `" + grpcSettings.getServiceName() + "` stream name: `" + grpcSettings.getTunStreamName() + "`").AtDebug().WriteToLog()
    66  	grpcService, err := client.(encoding.GRPCServiceClientX).TunCustomName(ctx, grpcSettings.getServiceName(), grpcSettings.getTunStreamName())
    67  	if err != nil {
    68  		return nil, newError("Cannot dial gRPC").Base(err)
    69  	}
    70  
    71  	return encoding.NewHunkConn(grpcService, nil), nil
    72  }
    73  
    74  func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*grpc.ClientConn, error) {
    75  	globalDialerAccess.Lock()
    76  	defer globalDialerAccess.Unlock()
    77  
    78  	if globalDialerMap == nil {
    79  		globalDialerMap = make(map[dialerConf]*grpc.ClientConn)
    80  	}
    81  	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
    82  	realityConfig := reality.ConfigFromStreamSettings(streamSettings)
    83  	sockopt := streamSettings.SocketSettings
    84  	grpcSettings := streamSettings.ProtocolSettings.(*Config)
    85  
    86  	if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found && client.GetState() != connectivity.Shutdown {
    87  		return client, nil
    88  	}
    89  
    90  	dialOptions := []grpc.DialOption{
    91  		grpc.WithConnectParams(grpc.ConnectParams{
    92  			Backoff: backoff.Config{
    93  				BaseDelay:  500 * time.Millisecond,
    94  				Multiplier: 1.5,
    95  				Jitter:     0.2,
    96  				MaxDelay:   19 * time.Second,
    97  			},
    98  			MinConnectTimeout: 5 * time.Second,
    99  		}),
   100  		grpc.WithContextDialer(func(gctx context.Context, s string) (gonet.Conn, error) {
   101  			gctx = session.ContextWithID(gctx, session.IDFromContext(ctx))
   102  			gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx))
   103  
   104  			rawHost, rawPort, err := net.SplitHostPort(s)
   105  			select {
   106  			case <-gctx.Done():
   107  				return nil, gctx.Err()
   108  			default:
   109  			}
   110  
   111  			if err != nil {
   112  				return nil, err
   113  			}
   114  			if len(rawPort) == 0 {
   115  				rawPort = "443"
   116  			}
   117  			port, err := net.PortFromString(rawPort)
   118  			if err != nil {
   119  				return nil, err
   120  			}
   121  			address := net.ParseAddress(rawHost)
   122  			c, err := internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt)
   123  			if err == nil && realityConfig != nil {
   124  				return reality.UClient(c, realityConfig, ctx, dest)
   125  			}
   126  			return c, err
   127  		}),
   128  	}
   129  
   130  	if tlsConfig != nil {
   131  		var transportCredential credentials.TransportCredentials
   132  		if fingerprint := tls.GetFingerprint(tlsConfig.Fingerprint); fingerprint != nil {
   133  			transportCredential = tls.NewGrpcUtls(tlsConfig.GetTLSConfig(), fingerprint)
   134  		} else { // Fallback to normal gRPC TLS
   135  			transportCredential = credentials.NewTLS(tlsConfig.GetTLSConfig())
   136  		}
   137  		dialOptions = append(dialOptions, grpc.WithTransportCredentials(transportCredential))
   138  	} else {
   139  		dialOptions = append(dialOptions, grpc.WithInsecure())
   140  	}
   141  
   142  	if grpcSettings.IdleTimeout > 0 || grpcSettings.HealthCheckTimeout > 0 || grpcSettings.PermitWithoutStream {
   143  		dialOptions = append(dialOptions, grpc.WithKeepaliveParams(keepalive.ClientParameters{
   144  			Time:                time.Second * time.Duration(grpcSettings.IdleTimeout),
   145  			Timeout:             time.Second * time.Duration(grpcSettings.HealthCheckTimeout),
   146  			PermitWithoutStream: grpcSettings.PermitWithoutStream,
   147  		}))
   148  	}
   149  
   150  	if grpcSettings.InitialWindowsSize > 0 {
   151  		dialOptions = append(dialOptions, grpc.WithInitialWindowSize(grpcSettings.InitialWindowsSize))
   152  	}
   153  
   154  	if grpcSettings.UserAgent != "" {
   155  		dialOptions = append(dialOptions, grpc.WithUserAgent(grpcSettings.UserAgent))
   156  	}
   157  
   158  	var grpcDestHost string
   159  	if dest.Address.Family().IsDomain() {
   160  		grpcDestHost = dest.Address.Domain()
   161  	} else {
   162  		grpcDestHost = dest.Address.IP().String()
   163  	}
   164  
   165  	conn, err := grpc.Dial(
   166  		gonet.JoinHostPort(grpcDestHost, dest.Port.String()),
   167  		dialOptions...,
   168  	)
   169  	globalDialerMap[dialerConf{dest, streamSettings}] = conn
   170  	return conn, err
   171  }