github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/conn/bindtest/bindtest.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package bindtest 7 8 import ( 9 "fmt" 10 "math/rand" 11 "net" 12 "os" 13 "strconv" 14 15 "github.com/tailscale/wireguard-go/conn" 16 ) 17 18 type ChannelBind struct { 19 rx4, tx4 *chan []byte 20 rx6, tx6 *chan []byte 21 closeSignal chan bool 22 source4, source6 ChannelEndpoint 23 target4, target6 ChannelEndpoint 24 } 25 26 type ChannelEndpoint uint16 27 28 var _ conn.Bind = (*ChannelBind)(nil) 29 var _ conn.Endpoint = (*ChannelEndpoint)(nil) 30 31 func NewChannelBinds() [2]conn.Bind { 32 arx4 := make(chan []byte, 8192) 33 brx4 := make(chan []byte, 8192) 34 arx6 := make(chan []byte, 8192) 35 brx6 := make(chan []byte, 8192) 36 var binds [2]ChannelBind 37 binds[0].rx4 = &arx4 38 binds[0].tx4 = &brx4 39 binds[1].rx4 = &brx4 40 binds[1].tx4 = &arx4 41 binds[0].rx6 = &arx6 42 binds[0].tx6 = &brx6 43 binds[1].rx6 = &brx6 44 binds[1].tx6 = &arx6 45 binds[0].target4 = ChannelEndpoint(1) 46 binds[1].target4 = ChannelEndpoint(2) 47 binds[0].target6 = ChannelEndpoint(3) 48 binds[1].target6 = ChannelEndpoint(4) 49 binds[0].source4 = binds[1].target4 50 binds[0].source6 = binds[1].target6 51 binds[1].source4 = binds[0].target4 52 binds[1].source6 = binds[0].target6 53 return [2]conn.Bind{&binds[0], &binds[1]} 54 } 55 56 func (c ChannelEndpoint) ClearSrc() {} 57 58 func (c ChannelEndpoint) SrcToString() string { return "" } 59 60 func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } 61 62 func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } 63 64 func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) } 65 66 func (c ChannelEndpoint) SrcIP() net.IP { return nil } 67 68 func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { 69 c.closeSignal = make(chan bool) 70 fns = append(fns, c.makeReceiveFunc(*c.rx4)) 71 fns = append(fns, c.makeReceiveFunc(*c.rx6)) 72 if rand.Uint32()&1 == 0 { 73 return fns, uint16(c.source4), nil 74 } else { 75 return fns, uint16(c.source6), nil 76 } 77 } 78 79 func (c *ChannelBind) Close() error { 80 if c.closeSignal != nil { 81 select { 82 case <-c.closeSignal: 83 default: 84 close(c.closeSignal) 85 } 86 } 87 return nil 88 } 89 90 func (c *ChannelBind) SetMark(mark uint32) error { return nil } 91 92 func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { 93 return func(b []byte) (n int, ep conn.Endpoint, err error) { 94 select { 95 case <-c.closeSignal: 96 return 0, nil, net.ErrClosed 97 case rx := <-ch: 98 return copy(b, rx), c.target6, nil 99 } 100 } 101 } 102 103 func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { 104 select { 105 case <-c.closeSignal: 106 return net.ErrClosed 107 default: 108 bc := make([]byte, len(b)) 109 copy(bc, b) 110 if ep.(ChannelEndpoint) == c.target4 { 111 *c.tx4 <- bc 112 } else if ep.(ChannelEndpoint) == c.target6 { 113 *c.tx6 <- bc 114 } else { 115 return os.ErrInvalid 116 } 117 } 118 return nil 119 } 120 121 func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { 122 _, port, err := net.SplitHostPort(s) 123 if err != nil { 124 return nil, err 125 } 126 i, err := strconv.ParseUint(port, 10, 16) 127 if err != nil { 128 return nil, err 129 } 130 return ChannelEndpoint(i), nil 131 }