github.com/eagleql/xray-core@v1.4.4/transport/internet/grpc/hub.go (about)

     1  package grpc
     2  
     3  import (
     4  	"context"
     5  
     6  	"google.golang.org/grpc"
     7  	"google.golang.org/grpc/credentials"
     8  
     9  	"github.com/eagleql/xray-core/common"
    10  	"github.com/eagleql/xray-core/common/net"
    11  	"github.com/eagleql/xray-core/common/session"
    12  	"github.com/eagleql/xray-core/transport/internet"
    13  	"github.com/eagleql/xray-core/transport/internet/grpc/encoding"
    14  	"github.com/eagleql/xray-core/transport/internet/tls"
    15  )
    16  
    17  type Listener struct {
    18  	encoding.UnimplementedGRPCServiceServer
    19  	ctx     context.Context
    20  	handler internet.ConnHandler
    21  	local   net.Addr
    22  	config  *Config
    23  	locker  *internet.FileLocker // for unix domain socket
    24  
    25  	s *grpc.Server
    26  }
    27  
    28  func (l Listener) Tun(server encoding.GRPCService_TunServer) error {
    29  	tunCtx, cancel := context.WithCancel(l.ctx)
    30  	l.handler(encoding.NewHunkConn(server, cancel))
    31  	<-tunCtx.Done()
    32  	return nil
    33  }
    34  
    35  func (l Listener) TunMulti(server encoding.GRPCService_TunMultiServer) error {
    36  	tunCtx, cancel := context.WithCancel(l.ctx)
    37  	l.handler(encoding.NewMultiHunkConn(server, cancel))
    38  	<-tunCtx.Done()
    39  	return nil
    40  }
    41  
    42  func (l Listener) Close() error {
    43  	l.s.Stop()
    44  	return nil
    45  }
    46  
    47  func (l Listener) Addr() net.Addr {
    48  	return l.local
    49  }
    50  
    51  func Listen(ctx context.Context, address net.Address, port net.Port, settings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
    52  	grpcSettings := settings.ProtocolSettings.(*Config)
    53  	var listener *Listener
    54  	if port == net.Port(0) { // unix
    55  		listener = &Listener{
    56  			handler: handler,
    57  			local: &net.UnixAddr{
    58  				Name: address.Domain(),
    59  				Net:  "unix",
    60  			},
    61  			config: grpcSettings,
    62  		}
    63  	} else { // tcp
    64  		listener = &Listener{
    65  			handler: handler,
    66  			local: &net.TCPAddr{
    67  				IP:   address.IP(),
    68  				Port: int(port),
    69  			},
    70  			config: grpcSettings,
    71  		}
    72  	}
    73  
    74  	listener.ctx = ctx
    75  
    76  	config := tls.ConfigFromStreamSettings(settings)
    77  
    78  	var s *grpc.Server
    79  	if config == nil {
    80  		s = grpc.NewServer()
    81  	} else {
    82  		s = grpc.NewServer(grpc.Creds(credentials.NewTLS(config.GetTLSConfig(tls.WithNextProto("h2")))))
    83  	}
    84  	listener.s = s
    85  
    86  	if settings.SocketSettings != nil && settings.SocketSettings.AcceptProxyProtocol {
    87  		newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx))
    88  	}
    89  
    90  	go func() {
    91  		var streamListener net.Listener
    92  		var err error
    93  		if port == net.Port(0) { // unix
    94  			streamListener, err = internet.ListenSystem(ctx, &net.UnixAddr{
    95  				Name: address.Domain(),
    96  				Net:  "unix",
    97  			}, settings.SocketSettings)
    98  			if err != nil {
    99  				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
   100  				return
   101  			}
   102  			locker := ctx.Value(address.Domain())
   103  			if locker != nil {
   104  				listener.locker = locker.(*internet.FileLocker)
   105  			}
   106  		} else { // tcp
   107  			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
   108  				IP:   address.IP(),
   109  				Port: int(port),
   110  			}, settings.SocketSettings)
   111  			if err != nil {
   112  				newError("failed to listen on ", address, ":", port).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
   113  				return
   114  			}
   115  		}
   116  
   117  		encoding.RegisterGRPCServiceServerX(s, listener, grpcSettings.ServiceName)
   118  
   119  		if err = s.Serve(streamListener); err != nil {
   120  			newError("Listener for gRPC ended").Base(err).WriteToLog()
   121  		}
   122  	}()
   123  
   124  	return listener, nil
   125  }
   126  
   127  func init() {
   128  	common.Must(internet.RegisterTransportListener(protocolName, Listen))
   129  }