github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/transport/internet/grpc/dial.go (about)

     1  package grpc
     2  
     3  import (
     4  	"context"
     5  	gonet "net"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/xtls/xray-core/common"
    10  	"github.com/xtls/xray-core/common/net"
    11  	"github.com/xtls/xray-core/common/session"
    12  	"github.com/xtls/xray-core/transport/internet"
    13  	"github.com/xtls/xray-core/transport/internet/grpc/encoding"
    14  	"github.com/xtls/xray-core/transport/internet/reality"
    15  	"github.com/xtls/xray-core/transport/internet/stat"
    16  	"github.com/xtls/xray-core/transport/internet/tls"
    17  	"google.golang.org/grpc"
    18  	"google.golang.org/grpc/backoff"
    19  	"google.golang.org/grpc/connectivity"
    20  	"google.golang.org/grpc/credentials/insecure"
    21  	"google.golang.org/grpc/keepalive"
    22  )
    23  
    24  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
    25  	newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
    26  
    27  	conn, err := dialgRPC(ctx, dest, streamSettings)
    28  	if err != nil {
    29  		return nil, newError("failed to dial gRPC").Base(err)
    30  	}
    31  	return stat.Connection(conn), nil
    32  }
    33  
    34  func init() {
    35  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
    36  }
    37  
    38  type dialerConf struct {
    39  	net.Destination
    40  	*internet.MemoryStreamConfig
    41  }
    42  
    43  var (
    44  	globalDialerMap    map[dialerConf]*grpc.ClientConn
    45  	globalDialerAccess sync.Mutex
    46  )
    47  
    48  func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
    49  	grpcSettings := streamSettings.ProtocolSettings.(*Config)
    50  
    51  	conn, err := getGrpcClient(ctx, dest, streamSettings)
    52  	if err != nil {
    53  		return nil, newError("Cannot dial gRPC").Base(err)
    54  	}
    55  	client := encoding.NewGRPCServiceClient(conn)
    56  	if grpcSettings.MultiMode {
    57  		newError("using gRPC multi mode service name: `" + grpcSettings.getServiceName() + "` stream name: `" + grpcSettings.getTunMultiStreamName() + "`").AtDebug().WriteToLog()
    58  		grpcService, err := client.(encoding.GRPCServiceClientX).TunMultiCustomName(ctx, grpcSettings.getServiceName(), grpcSettings.getTunMultiStreamName())
    59  		if err != nil {
    60  			return nil, newError("Cannot dial gRPC").Base(err)
    61  		}
    62  		return encoding.NewMultiHunkConn(grpcService, nil), nil
    63  	}
    64  
    65  	newError("using gRPC tun mode service name: `" + grpcSettings.getServiceName() + "` stream name: `" + grpcSettings.getTunStreamName() + "`").AtDebug().WriteToLog()
    66  	grpcService, err := client.(encoding.GRPCServiceClientX).TunCustomName(ctx, grpcSettings.getServiceName(), grpcSettings.getTunStreamName())
    67  	if err != nil {
    68  		return nil, newError("Cannot dial gRPC").Base(err)
    69  	}
    70  
    71  	return encoding.NewHunkConn(grpcService, nil), nil
    72  }
    73  
    74  func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*grpc.ClientConn, error) {
    75  	globalDialerAccess.Lock()
    76  	defer globalDialerAccess.Unlock()
    77  
    78  	if globalDialerMap == nil {
    79  		globalDialerMap = make(map[dialerConf]*grpc.ClientConn)
    80  	}
    81  	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
    82  	realityConfig := reality.ConfigFromStreamSettings(streamSettings)
    83  	sockopt := streamSettings.SocketSettings
    84  	grpcSettings := streamSettings.ProtocolSettings.(*Config)
    85  
    86  	if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found && client.GetState() != connectivity.Shutdown {
    87  		return client, nil
    88  	}
    89  
    90  	dialOptions := []grpc.DialOption{
    91  		grpc.WithConnectParams(grpc.ConnectParams{
    92  			Backoff: backoff.Config{
    93  				BaseDelay:  500 * time.Millisecond,
    94  				Multiplier: 1.5,
    95  				Jitter:     0.2,
    96  				MaxDelay:   19 * time.Second,
    97  			},
    98  			MinConnectTimeout: 5 * time.Second,
    99  		}),
   100  		grpc.WithContextDialer(func(gctx context.Context, s string) (gonet.Conn, error) {
   101  			select {
   102  			case <-gctx.Done():
   103  				return nil, gctx.Err()
   104  			default:
   105  			}
   106  
   107  			rawHost, rawPort, err := net.SplitHostPort(s)
   108  			if err != nil {
   109  				return nil, err
   110  			}
   111  			if len(rawPort) == 0 {
   112  				rawPort = "443"
   113  			}
   114  			port, err := net.PortFromString(rawPort)
   115  			if err != nil {
   116  				return nil, err
   117  			}
   118  			address := net.ParseAddress(rawHost)
   119  
   120  			gctx = session.ContextWithID(gctx, session.IDFromContext(ctx))
   121  			gctx = session.ContextWithOutbounds(gctx, session.OutboundsFromContext(ctx))
   122  			gctx = session.ContextWithTimeoutOnly(gctx, true)
   123  
   124  			c, err := internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt)
   125  			if err == nil {
   126  				if tlsConfig != nil {
   127  					config := tlsConfig.GetTLSConfig()
   128  					if config.ServerName == "" && address.Family().IsDomain() {
   129  						config.ServerName = address.Domain()
   130  					}
   131  					if fingerprint := tls.GetFingerprint(tlsConfig.Fingerprint); fingerprint != nil {
   132  						return tls.UClient(c, config, fingerprint), nil
   133  					} else { // Fallback to normal gRPC TLS
   134  						return tls.Client(c, config), nil
   135  					}
   136  				}
   137  				if realityConfig != nil {
   138  					return reality.UClient(c, realityConfig, gctx, dest)
   139  				}
   140  			}
   141  			return c, err
   142  		}),
   143  	}
   144  
   145  	dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))
   146  
   147  	authority := ""
   148  	if grpcSettings.Authority != "" {
   149  		authority = grpcSettings.Authority
   150  	} else if tlsConfig != nil && tlsConfig.ServerName != "" {
   151  		authority = tlsConfig.ServerName
   152  	} else if realityConfig == nil && dest.Address.Family().IsDomain() {
   153  		authority = dest.Address.Domain()
   154  	}
   155  	dialOptions = append(dialOptions, grpc.WithAuthority(authority))
   156  
   157  	if grpcSettings.IdleTimeout > 0 || grpcSettings.HealthCheckTimeout > 0 || grpcSettings.PermitWithoutStream {
   158  		dialOptions = append(dialOptions, grpc.WithKeepaliveParams(keepalive.ClientParameters{
   159  			Time:                time.Second * time.Duration(grpcSettings.IdleTimeout),
   160  			Timeout:             time.Second * time.Duration(grpcSettings.HealthCheckTimeout),
   161  			PermitWithoutStream: grpcSettings.PermitWithoutStream,
   162  		}))
   163  	}
   164  
   165  	if grpcSettings.InitialWindowsSize > 0 {
   166  		dialOptions = append(dialOptions, grpc.WithInitialWindowSize(grpcSettings.InitialWindowsSize))
   167  	}
   168  
   169  	if grpcSettings.UserAgent != "" {
   170  		dialOptions = append(dialOptions, grpc.WithUserAgent(grpcSettings.UserAgent))
   171  	}
   172  
   173  	var grpcDestHost string
   174  	if dest.Address.Family().IsDomain() {
   175  		grpcDestHost = dest.Address.Domain()
   176  	} else {
   177  		grpcDestHost = dest.Address.IP().String()
   178  	}
   179  
   180  	conn, err := grpc.Dial(
   181  		gonet.JoinHostPort(grpcDestHost, dest.Port.String()),
   182  		dialOptions...,
   183  	)
   184  	globalDialerMap[dialerConf{dest, streamSettings}] = conn
   185  	return conn, err
   186  }