github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/conn/bindtest/bindtest.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package bindtest 7 8 import ( 9 "fmt" 10 "math/rand" 11 "net" 12 "net/netip" 13 "os" 14 15 "github.com/koomox/wireguard-go/conn" 16 ) 17 18 type ChannelBind struct { 19 rx4, tx4 *chan []byte 20 rx6, tx6 *chan []byte 21 closeSignal chan bool 22 source4, source6 ChannelEndpoint 23 target4, target6 ChannelEndpoint 24 } 25 26 type ChannelEndpoint uint16 27 28 var ( 29 _ conn.Bind = (*ChannelBind)(nil) 30 _ conn.Endpoint = (*ChannelEndpoint)(nil) 31 ) 32 33 func NewChannelBinds() [2]conn.Bind { 34 arx4 := make(chan []byte, 8192) 35 brx4 := make(chan []byte, 8192) 36 arx6 := make(chan []byte, 8192) 37 brx6 := make(chan []byte, 8192) 38 var binds [2]ChannelBind 39 binds[0].rx4 = &arx4 40 binds[0].tx4 = &brx4 41 binds[1].rx4 = &brx4 42 binds[1].tx4 = &arx4 43 binds[0].rx6 = &arx6 44 binds[0].tx6 = &brx6 45 binds[1].rx6 = &brx6 46 binds[1].tx6 = &arx6 47 binds[0].target4 = ChannelEndpoint(1) 48 binds[1].target4 = ChannelEndpoint(2) 49 binds[0].target6 = ChannelEndpoint(3) 50 binds[1].target6 = ChannelEndpoint(4) 51 binds[0].source4 = binds[1].target4 52 binds[0].source6 = binds[1].target6 53 binds[1].source4 = binds[0].target4 54 binds[1].source6 = binds[0].target6 55 return [2]conn.Bind{&binds[0], &binds[1]} 56 } 57 58 func (c ChannelEndpoint) ClearSrc() {} 59 60 func (c ChannelEndpoint) SrcToString() string { return "" } 61 62 func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } 63 64 func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } 65 66 func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } 67 68 func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } 69 70 func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { 71 c.closeSignal = make(chan bool) 72 fns = append(fns, c.makeReceiveFunc(*c.rx4)) 73 fns = append(fns, c.makeReceiveFunc(*c.rx6)) 74 if rand.Uint32()&1 == 0 { 75 return fns, uint16(c.source4), nil 76 } else { 77 return fns, uint16(c.source6), nil 78 } 79 } 80 81 func (c *ChannelBind) Close() error { 82 if c.closeSignal != nil { 83 select { 84 case <-c.closeSignal: 85 default: 86 close(c.closeSignal) 87 } 88 } 89 return nil 90 } 91 92 func (c *ChannelBind) BatchSize() int { return 1 } 93 94 func (c *ChannelBind) SetMark(mark uint32) error { return nil } 95 96 func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { 97 return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { 98 select { 99 case <-c.closeSignal: 100 return 0, net.ErrClosed 101 case rx := <-ch: 102 copied := copy(bufs[0], rx) 103 sizes[0] = copied 104 eps[0] = c.target6 105 return 1, nil 106 } 107 } 108 } 109 110 func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error { 111 for _, b := range bufs { 112 select { 113 case <-c.closeSignal: 114 return net.ErrClosed 115 default: 116 bc := make([]byte, len(b)) 117 copy(bc, b) 118 if ep.(ChannelEndpoint) == c.target4 { 119 *c.tx4 <- bc 120 } else if ep.(ChannelEndpoint) == c.target6 { 121 *c.tx6 <- bc 122 } else { 123 return os.ErrInvalid 124 } 125 } 126 } 127 return nil 128 } 129 130 func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { 131 addr, err := netip.ParseAddrPort(s) 132 if err != nil { 133 return nil, err 134 } 135 return ChannelEndpoint(addr.Port()), nil 136 }