github.com/amnezia-vpn/amneziawg-go@v0.2.8/conn/bindtest/bindtest.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package bindtest
     7  
     8  import (
     9  	"fmt"
    10  	"math/rand"
    11  	"net"
    12  	"net/netip"
    13  	"os"
    14  
    15  	"github.com/amnezia-vpn/amneziawg-go/conn"
    16  )
    17  
    18  type ChannelBind struct {
    19  	rx4, tx4         *chan []byte
    20  	rx6, tx6         *chan []byte
    21  	closeSignal      chan bool
    22  	source4, source6 ChannelEndpoint
    23  	target4, target6 ChannelEndpoint
    24  }
    25  
    26  type ChannelEndpoint uint16
    27  
    28  var (
    29  	_ conn.Bind     = (*ChannelBind)(nil)
    30  	_ conn.Endpoint = (*ChannelEndpoint)(nil)
    31  )
    32  
    33  func NewChannelBinds() [2]conn.Bind {
    34  	arx4 := make(chan []byte, 8192)
    35  	brx4 := make(chan []byte, 8192)
    36  	arx6 := make(chan []byte, 8192)
    37  	brx6 := make(chan []byte, 8192)
    38  	var binds [2]ChannelBind
    39  	binds[0].rx4 = &arx4
    40  	binds[0].tx4 = &brx4
    41  	binds[1].rx4 = &brx4
    42  	binds[1].tx4 = &arx4
    43  	binds[0].rx6 = &arx6
    44  	binds[0].tx6 = &brx6
    45  	binds[1].rx6 = &brx6
    46  	binds[1].tx6 = &arx6
    47  	binds[0].target4 = ChannelEndpoint(1)
    48  	binds[1].target4 = ChannelEndpoint(2)
    49  	binds[0].target6 = ChannelEndpoint(3)
    50  	binds[1].target6 = ChannelEndpoint(4)
    51  	binds[0].source4 = binds[1].target4
    52  	binds[0].source6 = binds[1].target6
    53  	binds[1].source4 = binds[0].target4
    54  	binds[1].source6 = binds[0].target6
    55  	return [2]conn.Bind{&binds[0], &binds[1]}
    56  }
    57  
    58  func (c ChannelEndpoint) ClearSrc() {}
    59  
    60  func (c ChannelEndpoint) SrcToString() string { return "" }
    61  
    62  func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
    63  
    64  func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
    65  
    66  func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
    67  
    68  func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
    69  
    70  func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
    71  	c.closeSignal = make(chan bool)
    72  	fns = append(fns, c.makeReceiveFunc(*c.rx4))
    73  	fns = append(fns, c.makeReceiveFunc(*c.rx6))
    74  	if rand.Uint32()&1 == 0 {
    75  		return fns, uint16(c.source4), nil
    76  	} else {
    77  		return fns, uint16(c.source6), nil
    78  	}
    79  }
    80  
    81  func (c *ChannelBind) Close() error {
    82  	if c.closeSignal != nil {
    83  		select {
    84  		case <-c.closeSignal:
    85  		default:
    86  			close(c.closeSignal)
    87  		}
    88  	}
    89  	return nil
    90  }
    91  
    92  func (c *ChannelBind) BatchSize() int { return 1 }
    93  
    94  func (c *ChannelBind) GetOffloadInfo() string { return "" }
    95  
    96  func (c *ChannelBind) SetMark(mark uint32) error { return nil }
    97  
    98  func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
    99  	return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
   100  		select {
   101  		case <-c.closeSignal:
   102  			return 0, net.ErrClosed
   103  		case rx := <-ch:
   104  			copied := copy(bufs[0], rx)
   105  			sizes[0] = copied
   106  			eps[0] = c.target6
   107  			return 1, nil
   108  		}
   109  	}
   110  }
   111  
   112  func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
   113  	for _, b := range bufs {
   114  		select {
   115  		case <-c.closeSignal:
   116  			return net.ErrClosed
   117  		default:
   118  			bc := make([]byte, len(b))
   119  			copy(bc, b)
   120  			if ep.(ChannelEndpoint) == c.target4 {
   121  				*c.tx4 <- bc
   122  			} else if ep.(ChannelEndpoint) == c.target6 {
   123  				*c.tx6 <- bc
   124  			} else {
   125  				return os.ErrInvalid
   126  			}
   127  		}
   128  	}
   129  	return nil
   130  }
   131  
   132  func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
   133  	addr, err := netip.ParseAddrPort(s)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  	return ChannelEndpoint(addr.Port()), nil
   138  }