github.com/eagleql/xray-core@v1.4.4/transport/internet/grpc/dial.go (about) 1 package grpc 2 3 import ( 4 "context" 5 gonet "net" 6 "sync" 7 "time" 8 9 "google.golang.org/grpc" 10 "google.golang.org/grpc/backoff" 11 "google.golang.org/grpc/connectivity" 12 "google.golang.org/grpc/credentials" 13 14 "github.com/eagleql/xray-core/common" 15 "github.com/eagleql/xray-core/common/net" 16 "github.com/eagleql/xray-core/common/session" 17 "github.com/eagleql/xray-core/transport/internet" 18 "github.com/eagleql/xray-core/transport/internet/grpc/encoding" 19 "github.com/eagleql/xray-core/transport/internet/tls" 20 ) 21 22 func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { 23 newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx)) 24 25 conn, err := dialgRPC(ctx, dest, streamSettings) 26 if err != nil { 27 return nil, newError("failed to dial gRPC").Base(err) 28 } 29 return internet.Connection(conn), nil 30 } 31 32 func init() { 33 common.Must(internet.RegisterTransportDialer(protocolName, Dial)) 34 } 35 36 type dialerConf struct { 37 net.Destination 38 *internet.SocketConfig 39 *tls.Config 40 } 41 42 var ( 43 globalDialerMap map[dialerConf]*grpc.ClientConn 44 globalDialerAccess sync.Mutex 45 ) 46 47 func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) { 48 grpcSettings := streamSettings.ProtocolSettings.(*Config) 49 50 tlsConfig := tls.ConfigFromStreamSettings(streamSettings) 51 52 conn, err := getGrpcClient(ctx, dest, tlsConfig, streamSettings.SocketSettings) 53 54 if err != nil { 55 return nil, newError("Cannot dial gRPC").Base(err) 56 } 57 client := encoding.NewGRPCServiceClient(conn) 58 if grpcSettings.MultiMode { 59 newError("using gRPC multi mode").AtDebug().WriteToLog() 60 grpcService, err := client.(encoding.GRPCServiceClientX).TunMultiCustomName(ctx, grpcSettings.ServiceName) 61 if err != nil { 62 return nil, newError("Cannot dial gRPC").Base(err) 63 } 64 return encoding.NewMultiHunkConn(grpcService, nil), nil 65 } 66 67 grpcService, err := client.(encoding.GRPCServiceClientX).TunCustomName(ctx, grpcSettings.ServiceName) 68 if err != nil { 69 return nil, newError("Cannot dial gRPC").Base(err) 70 } 71 72 return encoding.NewHunkConn(grpcService, nil), nil 73 } 74 75 func getGrpcClient(ctx context.Context, dest net.Destination, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (*grpc.ClientConn, error) { 76 globalDialerAccess.Lock() 77 defer globalDialerAccess.Unlock() 78 79 if globalDialerMap == nil { 80 globalDialerMap = make(map[dialerConf]*grpc.ClientConn) 81 } 82 83 if client, found := globalDialerMap[dialerConf{dest, sockopt, tlsConfig}]; found && client.GetState() != connectivity.Shutdown { 84 return client, nil 85 } 86 87 dialOption := grpc.WithInsecure() 88 89 if tlsConfig != nil { 90 dialOption = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig.GetTLSConfig())) 91 } 92 93 var grpcDestHost string 94 if dest.Address.Family().IsDomain() { 95 grpcDestHost = dest.Address.Domain() 96 } else { 97 grpcDestHost = dest.Address.IP().String() 98 } 99 conn, err := grpc.Dial( 100 gonet.JoinHostPort(grpcDestHost, dest.Port.String()), 101 dialOption, 102 grpc.WithConnectParams(grpc.ConnectParams{ 103 Backoff: backoff.Config{ 104 BaseDelay: 500 * time.Millisecond, 105 Multiplier: 1.5, 106 Jitter: 0.2, 107 MaxDelay: 19 * time.Second, 108 }, 109 MinConnectTimeout: 5 * time.Second, 110 }), 111 grpc.WithContextDialer(func(gctx context.Context, s string) (gonet.Conn, error) { 112 gctx = session.ContextWithID(gctx, session.IDFromContext(ctx)) 113 gctx = session.ContextWithOutbound(gctx, session.OutboundFromContext(ctx)) 114 115 rawHost, rawPort, err := net.SplitHostPort(s) 116 select { 117 case <-gctx.Done(): 118 return nil, gctx.Err() 119 default: 120 } 121 122 if err != nil { 123 return nil, err 124 } 125 if len(rawPort) == 0 { 126 rawPort = "443" 127 } 128 port, err := net.PortFromString(rawPort) 129 if err != nil { 130 return nil, err 131 } 132 address := net.ParseAddress(rawHost) 133 return internet.DialSystem(gctx, net.TCPDestination(address, port), sockopt) 134 }), 135 ) 136 globalDialerMap[dialerConf{dest, sockopt, tlsConfig}] = conn 137 return conn, err 138 }