github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/simple/simple.go (about)

     1  package simple
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"sync/atomic"
     9  	"time"
    10  
    11  	"github.com/Asutorufa/yuhaiin/pkg/net/dialer"
    12  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/direct"
    14  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/point"
    15  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol"
    16  )
    17  
    18  type Simple struct {
    19  	netapi.EmptyDispatch
    20  
    21  	p netapi.Proxy
    22  
    23  	addrs      []netapi.Address
    24  	index      atomic.Uint32
    25  	updateTime time.Time
    26  }
    27  
    28  func init() {
    29  	point.RegisterProtocol(NewClient)
    30  }
    31  
    32  func NewClient(c *protocol.Protocol_Simple) point.WrapProxy {
    33  	return func(p netapi.Proxy) (netapi.Proxy, error) {
    34  		var addrs []netapi.Address
    35  		addrs = append(addrs, netapi.ParseAddressPort(0, c.Simple.GetHost(), netapi.ParsePort(c.Simple.GetPort())))
    36  		for _, v := range c.Simple.GetAlternateHost() {
    37  			addrs = append(addrs, netapi.ParseAddressPort(0, v.GetHost(), netapi.ParsePort(v.GetPort())))
    38  		}
    39  
    40  		simple := &Simple{
    41  			addrs: addrs,
    42  			p:     p,
    43  		}
    44  
    45  		return simple, nil
    46  	}
    47  }
    48  
    49  func (c *Simple) dial(ctx context.Context, addr netapi.Address, length int) (net.Conn, error) {
    50  	ctx, cancel, er := dialer.PartialDeadlineCtx(ctx, length)
    51  	if er != nil {
    52  		// Ran out of time.
    53  		return nil, er
    54  	}
    55  	defer cancel()
    56  
    57  	if c.p != nil && !point.IsBootstrap(c.p) {
    58  		return c.p.Conn(ctx, addr)
    59  	}
    60  
    61  	return netapi.DialHappyEyeballs(ctx, addr)
    62  }
    63  
    64  func (c *Simple) Conn(ctx context.Context, _ netapi.Address) (net.Conn, error) {
    65  	return c.dialGroup(ctx)
    66  	// tconn, ok := conn.(*net.TCPConn)
    67  	// if ok {
    68  	// _ = tconn.SetKeepAlive(true)
    69  	// https://github.com/golang/go/issues/48622
    70  	// _ = tconn.SetKeepAlivePeriod(time.Minute * 3)
    71  	// }
    72  }
    73  
    74  func (c *Simple) dialGroup(ctx context.Context) (net.Conn, error) {
    75  	var err error
    76  	var conn net.Conn
    77  
    78  	lastIndex := c.index.Load()
    79  	index := lastIndex
    80  	if lastIndex != 0 && time.Since(c.updateTime) > time.Minute*15 {
    81  		index = 0
    82  	}
    83  
    84  	length := len(c.addrs)
    85  
    86  	conn, err = c.dial(ctx, c.addrs[index], length)
    87  	if err == nil {
    88  		if lastIndex != 0 && index == 0 {
    89  			c.index.Store(0)
    90  		}
    91  
    92  		return conn, nil
    93  	}
    94  
    95  	for i, addr := range c.addrs {
    96  		if i == int(index) {
    97  			continue
    98  		}
    99  
   100  		length--
   101  
   102  		con, er := c.dial(ctx, addr, length)
   103  		if er != nil {
   104  			err = errors.Join(err, er)
   105  			continue
   106  		}
   107  
   108  		conn = con
   109  		c.index.Store(uint32(i))
   110  
   111  		if i != 0 {
   112  			c.updateTime = time.Now()
   113  		}
   114  		break
   115  	}
   116  
   117  	if conn == nil {
   118  		return nil, fmt.Errorf("simple dial failed: %w", err)
   119  	}
   120  
   121  	return conn, nil
   122  }
   123  
   124  type PacketDirectKey struct{}
   125  
   126  func (c *Simple) PacketConn(ctx context.Context, addr netapi.Address) (net.PacketConn, error) {
   127  	if ctx.Value(PacketDirectKey{}) == true {
   128  		return direct.Default.PacketConn(ctx, addr)
   129  	}
   130  
   131  	if c.p != nil && !point.IsBootstrap(c.p) {
   132  		return c.p.PacketConn(ctx, addr)
   133  	}
   134  
   135  	conn, err := dialer.ListenPacket("udp", "")
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	ur := c.addrs[c.index.Load()].UDPAddr(ctx)
   140  
   141  	if ur.Err != nil {
   142  		return nil, ur.Err
   143  	}
   144  
   145  	return &packetConn{conn, ur.V}, nil
   146  }
   147  
   148  type packetConn struct {
   149  	net.PacketConn
   150  	addr *net.UDPAddr
   151  }
   152  
   153  func (p *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   154  	return p.PacketConn.WriteTo(b, p.addr)
   155  }
   156  
   157  func (p *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
   158  	z, _, err := p.PacketConn.ReadFrom(b)
   159  	return z, p.addr, err
   160  }