github.com/cawidtu/notwireguard-go/conn@v0.0.0-20230523131112-68e8e5ce9cdf/bindtest/bindtest.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package bindtest 7 8 import ( 9 "fmt" 10 "math/rand" 11 "net" 12 "net/netip" 13 "os" 14 15 "github.com/cawidtu/notwireguard-go/conn" 16 ) 17 18 type ChannelBind struct { 19 rx4, tx4 *chan []byte 20 rx6, tx6 *chan []byte 21 closeSignal chan bool 22 source4, source6 ChannelEndpoint 23 target4, target6 ChannelEndpoint 24 } 25 26 type ChannelEndpoint uint16 27 28 var ( 29 _ conn.Bind = (*ChannelBind)(nil) 30 _ conn.Endpoint = (*ChannelEndpoint)(nil) 31 ) 32 33 func NewChannelBinds() [2]conn.Bind { 34 arx4 := make(chan []byte, 8192) 35 brx4 := make(chan []byte, 8192) 36 arx6 := make(chan []byte, 8192) 37 brx6 := make(chan []byte, 8192) 38 var binds [2]ChannelBind 39 binds[0].rx4 = &arx4 40 binds[0].tx4 = &brx4 41 binds[1].rx4 = &brx4 42 binds[1].tx4 = &arx4 43 binds[0].rx6 = &arx6 44 binds[0].tx6 = &brx6 45 binds[1].rx6 = &brx6 46 binds[1].tx6 = &arx6 47 binds[0].target4 = ChannelEndpoint(1) 48 binds[1].target4 = ChannelEndpoint(2) 49 binds[0].target6 = ChannelEndpoint(3) 50 binds[1].target6 = ChannelEndpoint(4) 51 binds[0].source4 = binds[1].target4 52 binds[0].source6 = binds[1].target6 53 binds[1].source4 = binds[0].target4 54 binds[1].source6 = binds[0].target6 55 return [2]conn.Bind{&binds[0], &binds[1]} 56 } 57 58 func (c ChannelEndpoint) ClearSrc() {} 59 60 func (c ChannelEndpoint) SrcToString() string { return "" } 61 62 func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } 63 64 func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } 65 66 func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } 67 68 func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } 69 70 func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { 71 c.closeSignal = make(chan bool) 72 fns = append(fns, c.makeReceiveFunc(*c.rx4)) 73 fns = append(fns, c.makeReceiveFunc(*c.rx6)) 74 if rand.Uint32()&1 == 0 { 75 return fns, uint16(c.source4), nil 76 } else { 77 return fns, uint16(c.source6), nil 78 } 79 } 80 81 func (c *ChannelBind) Close() error { 82 if c.closeSignal != nil { 83 select { 84 case <-c.closeSignal: 85 default: 86 close(c.closeSignal) 87 } 88 } 89 return nil 90 } 91 92 func (c *ChannelBind) SetMark(mark uint32) error { return nil } 93 94 func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { 95 return func(b []byte) (n int, ep conn.Endpoint, err error) { 96 select { 97 case <-c.closeSignal: 98 return 0, nil, net.ErrClosed 99 case rx := <-ch: 100 return copy(b, rx), c.target6, nil 101 } 102 } 103 } 104 105 func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { 106 select { 107 case <-c.closeSignal: 108 return net.ErrClosed 109 default: 110 bc := make([]byte, len(b)) 111 copy(bc, b) 112 if ep.(ChannelEndpoint) == c.target4 { 113 *c.tx4 <- bc 114 } else if ep.(ChannelEndpoint) == c.target6 { 115 *c.tx6 <- bc 116 } else { 117 return os.ErrInvalid 118 } 119 } 120 return nil 121 } 122 123 func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { 124 addr, err := netip.ParseAddrPort(s) 125 if err != nil { 126 return nil, err 127 } 128 return ChannelEndpoint(addr.Port()), nil 129 }