github.com/moqsien/xraycore@v1.8.5/proxy/dokodemo/dokodemo.go (about)

     1  package dokodemo
     2  
     3  //go:generate go run github.com/moqsien/xraycore/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/moqsien/xraycore/common"
    11  	"github.com/moqsien/xraycore/common/buf"
    12  	"github.com/moqsien/xraycore/common/log"
    13  	"github.com/moqsien/xraycore/common/net"
    14  	"github.com/moqsien/xraycore/common/protocol"
    15  	"github.com/moqsien/xraycore/common/session"
    16  	"github.com/moqsien/xraycore/common/signal"
    17  	"github.com/moqsien/xraycore/common/task"
    18  	"github.com/moqsien/xraycore/core"
    19  	"github.com/moqsien/xraycore/features/policy"
    20  	"github.com/moqsien/xraycore/features/routing"
    21  	"github.com/moqsien/xraycore/transport/internet/stat"
    22  )
    23  
    24  func init() {
    25  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    26  		d := new(DokodemoDoor)
    27  		err := core.RequireFeatures(ctx, func(pm policy.Manager) error {
    28  			return d.Init(config.(*Config), pm, session.SockoptFromContext(ctx))
    29  		})
    30  		return d, err
    31  	}))
    32  }
    33  
// DokodemoDoor is an inbound handler (see Process) that forwards each accepted
// connection to the configured address and port or, when FollowRedirect is
// enabled, to a destination recovered for that connection.
type DokodemoDoor struct {
	policyManager policy.Manager   // source of per-level session policies (timeouts, buffer policy)
	config        *Config          // configuration supplied to Init
	address       net.Address      // predefined destination address (config.GetPredefinedAddress())
	port          net.Port         // destination port (config.Port)
	sockopt       *session.Sockopt // socket options from the creating context; Mark is passed to FakeUDP
}
    41  
    42  // Init initializes the DokodemoDoor instance with necessary parameters.
    43  func (d *DokodemoDoor) Init(config *Config, pm policy.Manager, sockopt *session.Sockopt) error {
    44  	if (config.NetworkList == nil || len(config.NetworkList.Network) == 0) && len(config.Networks) == 0 {
    45  		return newError("no network specified")
    46  	}
    47  	d.config = config
    48  	d.address = config.GetPredefinedAddress()
    49  	d.port = net.Port(config.Port)
    50  	d.policyManager = pm
    51  	d.sockopt = sockopt
    52  
    53  	return nil
    54  }
    55  
    56  // Network implements proxy.Inbound.
    57  func (d *DokodemoDoor) Network() []net.Network {
    58  	if len(d.config.Networks) > 0 {
    59  		return d.config.Networks
    60  	}
    61  
    62  	return d.config.NetworkList.Network
    63  }
    64  
    65  func (d *DokodemoDoor) policy() policy.Session {
    66  	config := d.config
    67  	p := d.policyManager.ForLevel(config.UserLevel)
    68  	if config.Timeout > 0 && config.UserLevel == 0 {
    69  		p.Timeouts.ConnectionIdle = time.Duration(config.Timeout) * time.Second
    70  	}
    71  	return p
    72  }
    73  
// hasHandshakeAddress is implemented by connections that can report an
// address obtained during their handshake. With FollowRedirect enabled,
// Process uses it as the destination address when the context carries no
// valid outbound target.
type hasHandshakeAddress interface {
	// HandshakeAddress returns the handshake address, or nil if none is known.
	HandshakeAddress() net.Address
}
    77  
    78  // Process implements proxy.Inbound.
    79  func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
    80  	newError("processing connection from: ", conn.RemoteAddr()).AtDebug().WriteToLog(session.ExportIDToError(ctx))
    81  	dest := net.Destination{
    82  		Network: network,
    83  		Address: d.address,
    84  		Port:    d.port,
    85  	}
    86  
    87  	destinationOverridden := false
    88  	if d.config.FollowRedirect {
    89  		if outbound := session.OutboundFromContext(ctx); outbound != nil && outbound.Target.IsValid() {
    90  			dest = outbound.Target
    91  			destinationOverridden = true
    92  		} else if handshake, ok := conn.(hasHandshakeAddress); ok {
    93  			addr := handshake.HandshakeAddress()
    94  			if addr != nil {
    95  				dest.Address = addr
    96  				destinationOverridden = true
    97  			}
    98  		}
    99  	}
   100  	if !dest.IsValid() || dest.Address == nil {
   101  		return newError("unable to get destination")
   102  	}
   103  
   104  	inbound := session.InboundFromContext(ctx)
   105  	if inbound != nil {
   106  		inbound.Name = "dokodemo-door"
   107  		inbound.User = &protocol.MemoryUser{
   108  			Level: d.config.UserLevel,
   109  		}
   110  	}
   111  
   112  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   113  		From:   conn.RemoteAddr(),
   114  		To:     dest,
   115  		Status: log.AccessAccepted,
   116  		Reason: "",
   117  	})
   118  	newError("received request for ", conn.RemoteAddr()).WriteToLog(session.ExportIDToError(ctx))
   119  
   120  	plcy := d.policy()
   121  	ctx, cancel := context.WithCancel(ctx)
   122  	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
   123  
   124  	if inbound != nil {
   125  		inbound.Timer = timer
   126  	}
   127  
   128  	ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
   129  	link, err := dispatcher.Dispatch(ctx, dest)
   130  	if err != nil {
   131  		return newError("failed to dispatch request").Base(err)
   132  	}
   133  
   134  	requestCount := int32(1)
   135  	requestDone := func() error {
   136  		defer func() {
   137  			if atomic.AddInt32(&requestCount, -1) == 0 {
   138  				timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   139  			}
   140  		}()
   141  
   142  		var reader buf.Reader
   143  		if dest.Network == net.Network_UDP {
   144  			reader = buf.NewPacketReader(conn)
   145  		} else {
   146  			reader = buf.NewReader(conn)
   147  		}
   148  		if err := buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)); err != nil {
   149  			return newError("failed to transport request").Base(err)
   150  		}
   151  
   152  		return nil
   153  	}
   154  
   155  	tproxyRequest := func() error {
   156  		return nil
   157  	}
   158  
   159  	var writer buf.Writer
   160  	if network == net.Network_TCP {
   161  		writer = buf.NewWriter(conn)
   162  	} else {
   163  		// if we are in TPROXY mode, use linux's udp forging functionality
   164  		if !destinationOverridden {
   165  			writer = &buf.SequentialWriter{Writer: conn}
   166  		} else {
   167  			back := conn.RemoteAddr().(*net.UDPAddr)
   168  			if !dest.Address.Family().IsIP() {
   169  				if len(back.IP) == 4 {
   170  					dest.Address = net.AnyIP
   171  				} else {
   172  					dest.Address = net.AnyIPv6
   173  				}
   174  			}
   175  			addr := &net.UDPAddr{
   176  				IP:   dest.Address.IP(),
   177  				Port: int(dest.Port),
   178  			}
   179  			var mark int
   180  			if d.sockopt != nil {
   181  				mark = int(d.sockopt.Mark)
   182  			}
   183  			pConn, err := FakeUDP(addr, mark)
   184  			if err != nil {
   185  				return err
   186  			}
   187  			writer = NewPacketWriter(pConn, &dest, mark, back)
   188  			defer writer.(*PacketWriter).Close()
   189  			/*
   190  				sockopt := &internet.SocketConfig{
   191  					Tproxy: internet.SocketConfig_TProxy,
   192  				}
   193  				if dest.Address.Family().IsIP() {
   194  					sockopt.BindAddress = dest.Address.IP()
   195  					sockopt.BindPort = uint32(dest.Port)
   196  				}
   197  				if d.sockopt != nil {
   198  					sockopt.Mark = d.sockopt.Mark
   199  				}
   200  				tConn, err := internet.DialSystem(ctx, net.DestinationFromAddr(conn.RemoteAddr()), sockopt)
   201  				if err != nil {
   202  					return err
   203  				}
   204  				defer tConn.Close()
   205  
   206  				writer = &buf.SequentialWriter{Writer: tConn}
   207  				tReader := buf.NewPacketReader(tConn)
   208  				requestCount++
   209  				tproxyRequest = func() error {
   210  					defer func() {
   211  						if atomic.AddInt32(&requestCount, -1) == 0 {
   212  							timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   213  						}
   214  					}()
   215  					if err := buf.Copy(tReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
   216  						return newError("failed to transport request (TPROXY conn)").Base(err)
   217  					}
   218  					return nil
   219  				}
   220  			*/
   221  		}
   222  	}
   223  
   224  	responseDone := func() error {
   225  		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
   226  
   227  		if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil {
   228  			return newError("failed to transport response").Base(err)
   229  		}
   230  		return nil
   231  	}
   232  
   233  	if err := task.Run(ctx, task.OnSuccess(func() error {
   234  		return task.Run(ctx, requestDone, tproxyRequest)
   235  	}, task.Close(link.Writer)), responseDone); err != nil {
   236  		common.Interrupt(link.Reader)
   237  		common.Interrupt(link.Writer)
   238  		return newError("connection ends").Base(err)
   239  	}
   240  
   241  	return nil
   242  }
   243  
   244  func NewPacketWriter(conn net.PacketConn, d *net.Destination, mark int, back *net.UDPAddr) buf.Writer {
   245  	writer := &PacketWriter{
   246  		conn:  conn,
   247  		conns: make(map[net.Destination]net.PacketConn),
   248  		mark:  mark,
   249  		back:  back,
   250  	}
   251  	writer.conns[*d] = conn
   252  	return writer
   253  }
   254  
// PacketWriter is a buf.Writer that writes every packet to a fixed client
// address (back). It picks the sending socket per packet from the packet's
// UDP destination.
type PacketWriter struct {
	conn  net.PacketConn                     // default socket; used for packets without an IP destination
	conns map[net.Destination]net.PacketConn // sockets keyed by destination; an entry is set to nil after its socket fails a write
	mark  int                                // mark passed to FakeUDP when creating new sockets
	back  *net.UDPAddr                       // address every packet is written to
}
   261  
   262  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   263  	for {
   264  		mb2, b := buf.SplitFirst(mb)
   265  		mb = mb2
   266  		if b == nil {
   267  			break
   268  		}
   269  		var err error
   270  		if b.UDP != nil && b.UDP.Address.Family().IsIP() {
   271  			conn := w.conns[*b.UDP]
   272  			if conn == nil {
   273  				conn, err = FakeUDP(
   274  					&net.UDPAddr{
   275  						IP:   b.UDP.Address.IP(),
   276  						Port: int(b.UDP.Port),
   277  					},
   278  					w.mark,
   279  				)
   280  				if err != nil {
   281  					newError(err).WriteToLog()
   282  					b.Release()
   283  					continue
   284  				}
   285  				w.conns[*b.UDP] = conn
   286  			}
   287  			_, err = conn.WriteTo(b.Bytes(), w.back)
   288  			if err != nil {
   289  				newError(err).WriteToLog()
   290  				w.conns[*b.UDP] = nil
   291  				conn.Close()
   292  			}
   293  			b.Release()
   294  		} else {
   295  			_, err = w.conn.WriteTo(b.Bytes(), w.back)
   296  			b.Release()
   297  			if err != nil {
   298  				buf.ReleaseMulti(mb)
   299  				return err
   300  			}
   301  		}
   302  	}
   303  	return nil
   304  }
   305  
   306  func (w *PacketWriter) Close() error {
   307  	for _, conn := range w.conns {
   308  		if conn != nil {
   309  			conn.Close()
   310  		}
   311  	}
   312  	return nil
   313  }