github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/simple/simple.go (about) 1 package simple 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 "sync/atomic" 9 "time" 10 11 "github.com/Asutorufa/yuhaiin/pkg/net/dialer" 12 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 13 "github.com/Asutorufa/yuhaiin/pkg/net/proxy/direct" 14 "github.com/Asutorufa/yuhaiin/pkg/protos/node/point" 15 "github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol" 16 ) 17 18 type Simple struct { 19 netapi.EmptyDispatch 20 21 p netapi.Proxy 22 23 addrs []netapi.Address 24 index atomic.Uint32 25 updateTime time.Time 26 } 27 28 func init() { 29 point.RegisterProtocol(NewClient) 30 } 31 32 func NewClient(c *protocol.Protocol_Simple) point.WrapProxy { 33 return func(p netapi.Proxy) (netapi.Proxy, error) { 34 var addrs []netapi.Address 35 addrs = append(addrs, netapi.ParseAddressPort(0, c.Simple.GetHost(), netapi.ParsePort(c.Simple.GetPort()))) 36 for _, v := range c.Simple.GetAlternateHost() { 37 addrs = append(addrs, netapi.ParseAddressPort(0, v.GetHost(), netapi.ParsePort(v.GetPort()))) 38 } 39 40 simple := &Simple{ 41 addrs: addrs, 42 p: p, 43 } 44 45 return simple, nil 46 } 47 } 48 49 func (c *Simple) dial(ctx context.Context, addr netapi.Address, length int) (net.Conn, error) { 50 ctx, cancel, er := dialer.PartialDeadlineCtx(ctx, length) 51 if er != nil { 52 // Ran out of time. 53 return nil, er 54 } 55 defer cancel() 56 57 if c.p != nil && !point.IsBootstrap(c.p) { 58 return c.p.Conn(ctx, addr) 59 } 60 61 return netapi.DialHappyEyeballs(ctx, addr) 62 } 63 64 func (c *Simple) Conn(ctx context.Context, _ netapi.Address) (net.Conn, error) { 65 return c.dialGroup(ctx) 66 // tconn, ok := conn.(*net.TCPConn) 67 // if ok { 68 // _ = tconn.SetKeepAlive(true) 69 // https://github.com/golang/go/issues/48622 70 // _ = tconn.SetKeepAlivePeriod(time.Minute * 3) 71 // } 72 } 73 74 func (c *Simple) dialGroup(ctx context.Context) (net.Conn, error) { 75 var err error 76 var conn net.Conn 77 78 lastIndex := c.index.Load() 79 index := lastIndex 80 if lastIndex != 0 && time.Since(c.updateTime) > time.Minute*15 { 81 index = 0 82 } 83 84 length := len(c.addrs) 85 86 conn, err = c.dial(ctx, c.addrs[index], length) 87 if err == nil { 88 if lastIndex != 0 && index == 0 { 89 c.index.Store(0) 90 } 91 92 return conn, nil 93 } 94 95 for i, addr := range c.addrs { 96 if i == int(index) { 97 continue 98 } 99 100 length-- 101 102 con, er := c.dial(ctx, addr, length) 103 if er != nil { 104 err = errors.Join(err, er) 105 continue 106 } 107 108 conn = con 109 c.index.Store(uint32(i)) 110 111 if i != 0 { 112 c.updateTime = time.Now() 113 } 114 break 115 } 116 117 if conn == nil { 118 return nil, fmt.Errorf("simple dial failed: %w", err) 119 } 120 121 return conn, nil 122 } 123 124 type PacketDirectKey struct{} 125 126 func (c *Simple) PacketConn(ctx context.Context, addr netapi.Address) (net.PacketConn, error) { 127 if ctx.Value(PacketDirectKey{}) == true { 128 return direct.Default.PacketConn(ctx, addr) 129 } 130 131 if c.p != nil && !point.IsBootstrap(c.p) { 132 return c.p.PacketConn(ctx, addr) 133 } 134 135 conn, err := dialer.ListenPacket("udp", "") 136 if err != nil { 137 return nil, err 138 } 139 ur := c.addrs[c.index.Load()].UDPAddr(ctx) 140 141 if ur.Err != nil { 142 return nil, ur.Err 143 } 144 145 return &packetConn{conn, ur.V}, nil 146 } 147 148 type packetConn struct { 149 net.PacketConn 150 addr *net.UDPAddr 151 } 152 153 func (p *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { 154 return p.PacketConn.WriteTo(b, p.addr) 155 } 156 157 func (p *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { 158 z, _, err := p.PacketConn.ReadFrom(b) 159 return z, p.addr, err 160 }