github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/conn/bindtest/bindtest.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package bindtest
     7  
     8  import (
     9  	"fmt"
    10  	"math/rand"
    11  	"net"
    12  	"net/netip"
    13  	"os"
    14  
    15  	"github.com/koomox/wireguard-go/conn"
    16  )
    17  
    18  type ChannelBind struct {
    19  	rx4, tx4         *chan []byte
    20  	rx6, tx6         *chan []byte
    21  	closeSignal      chan bool
    22  	source4, source6 ChannelEndpoint
    23  	target4, target6 ChannelEndpoint
    24  }
    25  
    26  type ChannelEndpoint uint16
    27  
    28  var (
    29  	_ conn.Bind     = (*ChannelBind)(nil)
    30  	_ conn.Endpoint = (*ChannelEndpoint)(nil)
    31  )
    32  
    33  func NewChannelBinds() [2]conn.Bind {
    34  	arx4 := make(chan []byte, 8192)
    35  	brx4 := make(chan []byte, 8192)
    36  	arx6 := make(chan []byte, 8192)
    37  	brx6 := make(chan []byte, 8192)
    38  	var binds [2]ChannelBind
    39  	binds[0].rx4 = &arx4
    40  	binds[0].tx4 = &brx4
    41  	binds[1].rx4 = &brx4
    42  	binds[1].tx4 = &arx4
    43  	binds[0].rx6 = &arx6
    44  	binds[0].tx6 = &brx6
    45  	binds[1].rx6 = &brx6
    46  	binds[1].tx6 = &arx6
    47  	binds[0].target4 = ChannelEndpoint(1)
    48  	binds[1].target4 = ChannelEndpoint(2)
    49  	binds[0].target6 = ChannelEndpoint(3)
    50  	binds[1].target6 = ChannelEndpoint(4)
    51  	binds[0].source4 = binds[1].target4
    52  	binds[0].source6 = binds[1].target6
    53  	binds[1].source4 = binds[0].target4
    54  	binds[1].source6 = binds[0].target6
    55  	return [2]conn.Bind{&binds[0], &binds[1]}
    56  }
    57  
    58  func (c ChannelEndpoint) ClearSrc() {}
    59  
    60  func (c ChannelEndpoint) SrcToString() string { return "" }
    61  
    62  func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
    63  
    64  func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
    65  
    66  func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
    67  
    68  func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
    69  
    70  func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
    71  	c.closeSignal = make(chan bool)
    72  	fns = append(fns, c.makeReceiveFunc(*c.rx4))
    73  	fns = append(fns, c.makeReceiveFunc(*c.rx6))
    74  	if rand.Uint32()&1 == 0 {
    75  		return fns, uint16(c.source4), nil
    76  	} else {
    77  		return fns, uint16(c.source6), nil
    78  	}
    79  }
    80  
    81  func (c *ChannelBind) Close() error {
    82  	if c.closeSignal != nil {
    83  		select {
    84  		case <-c.closeSignal:
    85  		default:
    86  			close(c.closeSignal)
    87  		}
    88  	}
    89  	return nil
    90  }
    91  
    92  func (c *ChannelBind) BatchSize() int { return 1 }
    93  
    94  func (c *ChannelBind) SetMark(mark uint32) error { return nil }
    95  
    96  func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
    97  	return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
    98  		select {
    99  		case <-c.closeSignal:
   100  			return 0, net.ErrClosed
   101  		case rx := <-ch:
   102  			copied := copy(bufs[0], rx)
   103  			sizes[0] = copied
   104  			eps[0] = c.target6
   105  			return 1, nil
   106  		}
   107  	}
   108  }
   109  
   110  func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
   111  	for _, b := range bufs {
   112  		select {
   113  		case <-c.closeSignal:
   114  			return net.ErrClosed
   115  		default:
   116  			bc := make([]byte, len(b))
   117  			copy(bc, b)
   118  			if ep.(ChannelEndpoint) == c.target4 {
   119  				*c.tx4 <- bc
   120  			} else if ep.(ChannelEndpoint) == c.target6 {
   121  				*c.tx6 <- bc
   122  			} else {
   123  				return os.ErrInvalid
   124  			}
   125  		}
   126  	}
   127  	return nil
   128  }
   129  
   130  func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
   131  	addr, err := netip.ParseAddrPort(s)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	return ChannelEndpoint(addr.Port()), nil
   136  }