github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/proxy/wireguard/server.go (about)

     1  package wireguard
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  
     8  	"github.com/xtls/xray-core/common"
     9  	"github.com/xtls/xray-core/common/buf"
    10  	"github.com/xtls/xray-core/common/log"
    11  	"github.com/xtls/xray-core/common/net"
    12  	"github.com/xtls/xray-core/common/session"
    13  	"github.com/xtls/xray-core/common/signal"
    14  	"github.com/xtls/xray-core/common/task"
    15  	"github.com/xtls/xray-core/core"
    16  	"github.com/xtls/xray-core/features/dns"
    17  	"github.com/xtls/xray-core/features/policy"
    18  	"github.com/xtls/xray-core/features/routing"
    19  	"github.com/xtls/xray-core/transport/internet/stat"
    20  )
    21  
    22  var nullDestination = net.TCPDestination(net.AnyIP, 0)
    23  
    24  type Server struct {
    25  	bindServer *netBindServer
    26  
    27  	info          routingInfo
    28  	policyManager policy.Manager
    29  }
    30  
    31  type routingInfo struct {
    32  	ctx         context.Context
    33  	dispatcher  routing.Dispatcher
    34  	inboundTag  *session.Inbound
    35  	outboundTag *session.Outbound
    36  	contentTag  *session.Content
    37  }
    38  
    39  func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
    40  	v := core.MustFromContext(ctx)
    41  
    42  	endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	server := &Server{
    48  		bindServer: &netBindServer{
    49  			netBind: netBind{
    50  				dns: v.GetFeature(dns.ClientType()).(dns.Client),
    51  				dnsOption: dns.IPOption{
    52  					IPv4Enable: hasIPv4,
    53  					IPv6Enable: hasIPv6,
    54  				},
    55  			},
    56  		},
    57  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    58  	}
    59  
    60  	tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  
    65  	if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil {
    66  		_ = tun.Close()
    67  		return nil, err
    68  	}
    69  
    70  	return server, nil
    71  }
    72  
    73  // Network implements proxy.Inbound.
    74  func (*Server) Network() []net.Network {
    75  	return []net.Network{net.Network_UDP}
    76  }
    77  
    78  // Process implements proxy.Inbound.
    79  func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
    80  	inbound := session.InboundFromContext(ctx)
    81  	inbound.Name = "wireguard"
    82  	inbound.CanSpliceCopy = 3
    83  	outbounds := session.OutboundsFromContext(ctx)
    84  	ob := outbounds[len(outbounds) - 1]
    85  
    86  	s.info = routingInfo{
    87  		ctx:         core.ToBackgroundDetachedContext(ctx),
    88  		dispatcher:  dispatcher,
    89  		inboundTag:  session.InboundFromContext(ctx),
    90  		outboundTag: ob,
    91  		contentTag:  session.ContentFromContext(ctx),
    92  	}
    93  
    94  	ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
    95  	if err != nil {
    96  		return err
    97  	}
    98  
    99  	nep := ep.(*netEndpoint)
   100  	nep.conn = conn
   101  
   102  	reader := buf.NewPacketReader(conn)
   103  	for {
   104  		mpayload, err := reader.ReadMultiBuffer()
   105  		if err != nil {
   106  			return err
   107  		}
   108  
   109  		for _, payload := range mpayload {
   110  			v, ok := <-s.bindServer.readQueue
   111  			if !ok {
   112  				return nil
   113  			}
   114  			i, err := payload.Read(v.buff)
   115  
   116  			v.bytes = i
   117  			v.endpoint = nep
   118  			v.err = err
   119  			v.waiter.Done()
   120  			if err != nil && errors.Is(err, io.EOF) {
   121  				nep.conn = nil
   122  				return nil
   123  			}
   124  		}
   125  	}
   126  }
   127  
   128  func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
   129  	if s.info.dispatcher == nil {
   130  		newError("unexpected: dispatcher == nil").AtError().WriteToLog()
   131  		return
   132  	}
   133  	defer conn.Close()
   134  
   135  	ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
   136  	plcy := s.policyManager.ForLevel(0)
   137  	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
   138  
   139  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   140  		From:   nullDestination,
   141  		To:     dest,
   142  		Status: log.AccessAccepted,
   143  		Reason: "",
   144  	})
   145  
   146  	if s.info.inboundTag != nil {
   147  		ctx = session.ContextWithInbound(ctx, s.info.inboundTag)
   148  	}
   149  	if s.info.outboundTag != nil {
   150  		ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{s.info.outboundTag})
   151  	}
   152  	if s.info.contentTag != nil {
   153  		ctx = session.ContextWithContent(ctx, s.info.contentTag)
   154  	}
   155  
   156  	link, err := s.info.dispatcher.Dispatch(ctx, dest)
   157  	if err != nil {
   158  		newError("dispatch connection").Base(err).AtError().WriteToLog()
   159  	}
   160  	defer cancel()
   161  
   162  	requestDone := func() error {
   163  		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   164  		if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
   165  			return newError("failed to transport all TCP request").Base(err)
   166  		}
   167  
   168  		return nil
   169  	}
   170  
   171  	responseDone := func() error {
   172  		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
   173  		if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
   174  			return newError("failed to transport all TCP response").Base(err)
   175  		}
   176  
   177  		return nil
   178  	}
   179  
   180  	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
   181  	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
   182  		common.Interrupt(link.Reader)
   183  		common.Interrupt(link.Writer)
   184  		newError("connection ends").Base(err).AtDebug().WriteToLog()
   185  		return
   186  	}
   187  }