github.com/amnezia-vpn/amneziawg-go@v0.2.8/conn/bindtest/bindtest.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package bindtest 7 8 import ( 9 "fmt" 10 "math/rand" 11 "net" 12 "net/netip" 13 "os" 14 15 "github.com/amnezia-vpn/amneziawg-go/conn" 16 ) 17 18 type ChannelBind struct { 19 rx4, tx4 *chan []byte 20 rx6, tx6 *chan []byte 21 closeSignal chan bool 22 source4, source6 ChannelEndpoint 23 target4, target6 ChannelEndpoint 24 } 25 26 type ChannelEndpoint uint16 27 28 var ( 29 _ conn.Bind = (*ChannelBind)(nil) 30 _ conn.Endpoint = (*ChannelEndpoint)(nil) 31 ) 32 33 func NewChannelBinds() [2]conn.Bind { 34 arx4 := make(chan []byte, 8192) 35 brx4 := make(chan []byte, 8192) 36 arx6 := make(chan []byte, 8192) 37 brx6 := make(chan []byte, 8192) 38 var binds [2]ChannelBind 39 binds[0].rx4 = &arx4 40 binds[0].tx4 = &brx4 41 binds[1].rx4 = &brx4 42 binds[1].tx4 = &arx4 43 binds[0].rx6 = &arx6 44 binds[0].tx6 = &brx6 45 binds[1].rx6 = &brx6 46 binds[1].tx6 = &arx6 47 binds[0].target4 = ChannelEndpoint(1) 48 binds[1].target4 = ChannelEndpoint(2) 49 binds[0].target6 = ChannelEndpoint(3) 50 binds[1].target6 = ChannelEndpoint(4) 51 binds[0].source4 = binds[1].target4 52 binds[0].source6 = binds[1].target6 53 binds[1].source4 = binds[0].target4 54 binds[1].source6 = binds[0].target6 55 return [2]conn.Bind{&binds[0], &binds[1]} 56 } 57 58 func (c ChannelEndpoint) ClearSrc() {} 59 60 func (c ChannelEndpoint) SrcToString() string { return "" } 61 62 func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } 63 64 func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } 65 66 func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } 67 68 func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } 69 70 func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { 71 c.closeSignal = make(chan bool) 72 fns = append(fns, c.makeReceiveFunc(*c.rx4)) 73 fns = append(fns, c.makeReceiveFunc(*c.rx6)) 74 if rand.Uint32()&1 == 0 { 75 return fns, uint16(c.source4), nil 76 } else { 77 return fns, uint16(c.source6), nil 78 } 79 } 80 81 func (c *ChannelBind) Close() error { 82 if c.closeSignal != nil { 83 select { 84 case <-c.closeSignal: 85 default: 86 close(c.closeSignal) 87 } 88 } 89 return nil 90 } 91 92 func (c *ChannelBind) BatchSize() int { return 1 } 93 94 func (c *ChannelBind) GetOffloadInfo() string { return "" } 95 96 func (c *ChannelBind) SetMark(mark uint32) error { return nil } 97 98 func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { 99 return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { 100 select { 101 case <-c.closeSignal: 102 return 0, net.ErrClosed 103 case rx := <-ch: 104 copied := copy(bufs[0], rx) 105 sizes[0] = copied 106 eps[0] = c.target6 107 return 1, nil 108 } 109 } 110 } 111 112 func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error { 113 for _, b := range bufs { 114 select { 115 case <-c.closeSignal: 116 return net.ErrClosed 117 default: 118 bc := make([]byte, len(b)) 119 copy(bc, b) 120 if ep.(ChannelEndpoint) == c.target4 { 121 *c.tx4 <- bc 122 } else if ep.(ChannelEndpoint) == c.target6 { 123 *c.tx6 <- bc 124 } else { 125 return os.ErrInvalid 126 } 127 } 128 } 129 return nil 130 } 131 132 func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { 133 addr, err := netip.ParseAddrPort(s) 134 if err != nil { 135 return nil, err 136 } 137 return ChannelEndpoint(addr.Port()), nil 138 }