github.com/eagleql/xray-core@v1.4.4/proxy/dokodemo/dokodemo.go (about)

     1  package dokodemo
     2  
     3  //go:generate go run github.com/eagleql/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/eagleql/xray-core/common"
    11  	"github.com/eagleql/xray-core/common/buf"
    12  	"github.com/eagleql/xray-core/common/log"
    13  	"github.com/eagleql/xray-core/common/net"
    14  	"github.com/eagleql/xray-core/common/protocol"
    15  	"github.com/eagleql/xray-core/common/session"
    16  	"github.com/eagleql/xray-core/common/signal"
    17  	"github.com/eagleql/xray-core/common/task"
    18  	"github.com/eagleql/xray-core/core"
    19  	"github.com/eagleql/xray-core/features/policy"
    20  	"github.com/eagleql/xray-core/features/routing"
    21  	"github.com/eagleql/xray-core/transport/internet"
    22  )
    23  
    24  func init() {
    25  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    26  		d := new(DokodemoDoor)
    27  		err := core.RequireFeatures(ctx, func(pm policy.Manager) error {
    28  			return d.Init(config.(*Config), pm, session.SockoptFromContext(ctx))
    29  		})
    30  		return d, err
    31  	}))
    32  }
    33  
// DokodemoDoor is the inbound handler that init registers for *Config.
// All fields are populated by Init.
type DokodemoDoor struct {
	policyManager policy.Manager   // queried per user level in policy()
	config        *Config          // configuration passed to Init
	address       net.Address      // predefined destination address taken from config
	port          net.Port         // predefined destination port taken from config
	sockopt       *session.Sockopt // optional; its Mark is forwarded to FakeUDP in Process
}
    41  
    42  // Init initializes the DokodemoDoor instance with necessary parameters.
    43  func (d *DokodemoDoor) Init(config *Config, pm policy.Manager, sockopt *session.Sockopt) error {
    44  	if (config.NetworkList == nil || len(config.NetworkList.Network) == 0) && len(config.Networks) == 0 {
    45  		return newError("no network specified")
    46  	}
    47  	d.config = config
    48  	d.address = config.GetPredefinedAddress()
    49  	d.port = net.Port(config.Port)
    50  	d.policyManager = pm
    51  	d.sockopt = sockopt
    52  
    53  	return nil
    54  }
    55  
    56  // Network implements proxy.Inbound.
    57  func (d *DokodemoDoor) Network() []net.Network {
    58  	if len(d.config.Networks) > 0 {
    59  		return d.config.Networks
    60  	}
    61  
    62  	return d.config.NetworkList.Network
    63  }
    64  
    65  func (d *DokodemoDoor) policy() policy.Session {
    66  	config := d.config
    67  	p := d.policyManager.ForLevel(config.UserLevel)
    68  	if config.Timeout > 0 && config.UserLevel == 0 {
    69  		p.Timeouts.ConnectionIdle = time.Duration(config.Timeout) * time.Second
    70  	}
    71  	return p
    72  }
    73  
// hasHandshakeAddress is implemented by connections that can report a
// handshake address. Process type-asserts the inbound connection against it
// when FollowRedirect is enabled and no outbound target is present in the
// context.
type hasHandshakeAddress interface {
	HandshakeAddress() net.Address
}
    77  
    78  // Process implements proxy.Inbound.
    79  func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error {
    80  	newError("processing connection from: ", conn.RemoteAddr()).AtDebug().WriteToLog(session.ExportIDToError(ctx))
    81  	dest := net.Destination{
    82  		Network: network,
    83  		Address: d.address,
    84  		Port:    d.port,
    85  	}
    86  
    87  	destinationOverridden := false
    88  	if d.config.FollowRedirect {
    89  		if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() {
    90  			dest = outbound.Target
    91  			destinationOverridden = true
    92  		} else if handshake, ok := conn.(hasHandshakeAddress); ok {
    93  			addr := handshake.HandshakeAddress()
    94  			if addr != nil {
    95  				dest.Address = addr
    96  				destinationOverridden = true
    97  			}
    98  		}
    99  	}
   100  	if !dest.IsValid() || dest.Address == nil {
   101  		return newError("unable to get destination")
   102  	}
   103  
   104  	inbound := session.InboundFromContext(ctx)
   105  	if inbound != nil {
   106  		inbound.User = &protocol.MemoryUser{
   107  			Level: d.config.UserLevel,
   108  		}
   109  	}
   110  
   111  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   112  		From:   conn.RemoteAddr(),
   113  		To:     dest,
   114  		Status: log.AccessAccepted,
   115  		Reason: "",
   116  	})
   117  	newError("received request for ", conn.RemoteAddr()).WriteToLog(session.ExportIDToError(ctx))
   118  
   119  	plcy := d.policy()
   120  	ctx, cancel := context.WithCancel(ctx)
   121  	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
   122  
   123  	if inbound != nil {
   124  		inbound.Timer = timer
   125  	}
   126  
   127  	ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
   128  	link, err := dispatcher.Dispatch(ctx, dest)
   129  	if err != nil {
   130  		return newError("failed to dispatch request").Base(err)
   131  	}
   132  
   133  	requestCount := int32(1)
   134  	requestDone := func() error {
   135  		defer func() {
   136  			if atomic.AddInt32(&requestCount, -1) == 0 {
   137  				timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   138  			}
   139  		}()
   140  
   141  		var reader buf.Reader
   142  		if dest.Network == net.Network_UDP {
   143  			reader = buf.NewPacketReader(conn)
   144  		} else {
   145  			reader = buf.NewReader(conn)
   146  		}
   147  		if err := buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)); err != nil {
   148  			return newError("failed to transport request").Base(err)
   149  		}
   150  
   151  		return nil
   152  	}
   153  
   154  	tproxyRequest := func() error {
   155  		return nil
   156  	}
   157  
   158  	var writer buf.Writer
   159  	if network == net.Network_TCP {
   160  		writer = buf.NewWriter(conn)
   161  	} else {
   162  		// if we are in TPROXY mode, use linux's udp forging functionality
   163  		if !destinationOverridden {
   164  			writer = &buf.SequentialWriter{Writer: conn}
   165  		} else {
   166  			back := conn.RemoteAddr().(*net.UDPAddr)
   167  			if !dest.Address.Family().IsIP() {
   168  				if len(back.IP) == 4 {
   169  					dest.Address = net.AnyIP
   170  				} else {
   171  					dest.Address = net.AnyIPv6
   172  				}
   173  			}
   174  			addr := &net.UDPAddr{
   175  				IP:   dest.Address.IP(),
   176  				Port: int(dest.Port),
   177  			}
   178  			var mark int
   179  			if d.sockopt != nil {
   180  				mark = int(d.sockopt.Mark)
   181  			}
   182  			pConn, err := FakeUDP(addr, mark)
   183  			if err != nil {
   184  				return err
   185  			}
   186  			writer = NewPacketWriter(pConn, &dest, mark, back)
   187  			defer writer.(*PacketWriter).Close()
   188  			/*
   189  				sockopt := &internet.SocketConfig{
   190  					Tproxy: internet.SocketConfig_TProxy,
   191  				}
   192  				if dest.Address.Family().IsIP() {
   193  					sockopt.BindAddress = dest.Address.IP()
   194  					sockopt.BindPort = uint32(dest.Port)
   195  				}
   196  				if d.sockopt != nil {
   197  					sockopt.Mark = d.sockopt.Mark
   198  				}
   199  				tConn, err := internet.DialSystem(ctx, net.DestinationFromAddr(conn.RemoteAddr()), sockopt)
   200  				if err != nil {
   201  					return err
   202  				}
   203  				defer tConn.Close()
   204  
   205  				writer = &buf.SequentialWriter{Writer: tConn}
   206  				tReader := buf.NewPacketReader(tConn)
   207  				requestCount++
   208  				tproxyRequest = func() error {
   209  					defer func() {
   210  						if atomic.AddInt32(&requestCount, -1) == 0 {
   211  							timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   212  						}
   213  					}()
   214  					if err := buf.Copy(tReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
   215  						return newError("failed to transport request (TPROXY conn)").Base(err)
   216  					}
   217  					return nil
   218  				}
   219  			*/
   220  		}
   221  	}
   222  
   223  	responseDone := func() error {
   224  		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
   225  
   226  		if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil {
   227  			return newError("failed to transport response").Base(err)
   228  		}
   229  		return nil
   230  	}
   231  
   232  	if err := task.Run(ctx, task.OnSuccess(func() error {
   233  		return task.Run(ctx, requestDone, tproxyRequest)
   234  	}, task.Close(link.Writer)), responseDone); err != nil {
   235  		common.Interrupt(link.Reader)
   236  		common.Interrupt(link.Writer)
   237  		return newError("connection ends").Base(err)
   238  	}
   239  
   240  	return nil
   241  }
   242  
   243  func NewPacketWriter(conn net.PacketConn, d *net.Destination, mark int, back *net.UDPAddr) buf.Writer {
   244  	writer := &PacketWriter{
   245  		conn:  conn,
   246  		conns: make(map[net.Destination]net.PacketConn),
   247  		mark:  mark,
   248  		back:  back,
   249  	}
   250  	writer.conns[*d] = conn
   251  	return writer
   252  }
   253  
// PacketWriter writes buffers to a single peer address (back), choosing the
// sending connection per buffer from conns. It is constructed by
// NewPacketWriter.
type PacketWriter struct {
	conn  net.PacketConn                     // default connection, used for buffers without an IP UDP source
	conns map[net.Destination]net.PacketConn // connections keyed by a buffer's UDP destination; entries may be nil
	mark  int                                // passed to FakeUDP when a new connection is created
	back  *net.UDPAddr                       // address every packet is written to
}
   260  
   261  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   262  	for {
   263  		mb2, b := buf.SplitFirst(mb)
   264  		mb = mb2
   265  		if b == nil {
   266  			break
   267  		}
   268  		var err error
   269  		if b.UDP != nil && b.UDP.Address.Family().IsIP() {
   270  			conn := w.conns[*b.UDP]
   271  			if conn == nil {
   272  				conn, err = FakeUDP(
   273  					&net.UDPAddr{
   274  						IP:   b.UDP.Address.IP(),
   275  						Port: int(b.UDP.Port),
   276  					},
   277  					w.mark,
   278  				)
   279  				if err != nil {
   280  					newError(err).WriteToLog()
   281  					b.Release()
   282  					continue
   283  				}
   284  				w.conns[*b.UDP] = conn
   285  			}
   286  			_, err = conn.WriteTo(b.Bytes(), w.back)
   287  			if err != nil {
   288  				newError(err).WriteToLog()
   289  				w.conns[*b.UDP] = nil
   290  				conn.Close()
   291  			}
   292  			b.Release()
   293  		} else {
   294  			_, err = w.conn.WriteTo(b.Bytes(), w.back)
   295  			b.Release()
   296  			if err != nil {
   297  				buf.ReleaseMulti(mb)
   298  				return err
   299  			}
   300  		}
   301  	}
   302  	return nil
   303  }
   304  
   305  func (w *PacketWriter) Close() error {
   306  	for _, conn := range w.conns {
   307  		if conn != nil {
   308  			conn.Close()
   309  		}
   310  	}
   311  	return nil
   312  }