github.com/cawidtu/notwireguard-go/conn@v0.0.0-20230523131112-68e8e5ce9cdf/bindtest/bindtest.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package bindtest
     7  
     8  import (
     9  	"fmt"
    10  	"math/rand"
    11  	"net"
    12  	"net/netip"
    13  	"os"
    14  
    15  	"github.com/cawidtu/notwireguard-go/conn"
    16  )
    17  
    18  type ChannelBind struct {
    19  	rx4, tx4         *chan []byte
    20  	rx6, tx6         *chan []byte
    21  	closeSignal      chan bool
    22  	source4, source6 ChannelEndpoint
    23  	target4, target6 ChannelEndpoint
    24  }
    25  
    26  type ChannelEndpoint uint16
    27  
    28  var (
    29  	_ conn.Bind     = (*ChannelBind)(nil)
    30  	_ conn.Endpoint = (*ChannelEndpoint)(nil)
    31  )
    32  
    33  func NewChannelBinds() [2]conn.Bind {
    34  	arx4 := make(chan []byte, 8192)
    35  	brx4 := make(chan []byte, 8192)
    36  	arx6 := make(chan []byte, 8192)
    37  	brx6 := make(chan []byte, 8192)
    38  	var binds [2]ChannelBind
    39  	binds[0].rx4 = &arx4
    40  	binds[0].tx4 = &brx4
    41  	binds[1].rx4 = &brx4
    42  	binds[1].tx4 = &arx4
    43  	binds[0].rx6 = &arx6
    44  	binds[0].tx6 = &brx6
    45  	binds[1].rx6 = &brx6
    46  	binds[1].tx6 = &arx6
    47  	binds[0].target4 = ChannelEndpoint(1)
    48  	binds[1].target4 = ChannelEndpoint(2)
    49  	binds[0].target6 = ChannelEndpoint(3)
    50  	binds[1].target6 = ChannelEndpoint(4)
    51  	binds[0].source4 = binds[1].target4
    52  	binds[0].source6 = binds[1].target6
    53  	binds[1].source4 = binds[0].target4
    54  	binds[1].source6 = binds[0].target6
    55  	return [2]conn.Bind{&binds[0], &binds[1]}
    56  }
    57  
    58  func (c ChannelEndpoint) ClearSrc() {}
    59  
    60  func (c ChannelEndpoint) SrcToString() string { return "" }
    61  
    62  func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
    63  
    64  func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
    65  
    66  func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
    67  
    68  func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
    69  
    70  func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
    71  	c.closeSignal = make(chan bool)
    72  	fns = append(fns, c.makeReceiveFunc(*c.rx4))
    73  	fns = append(fns, c.makeReceiveFunc(*c.rx6))
    74  	if rand.Uint32()&1 == 0 {
    75  		return fns, uint16(c.source4), nil
    76  	} else {
    77  		return fns, uint16(c.source6), nil
    78  	}
    79  }
    80  
    81  func (c *ChannelBind) Close() error {
    82  	if c.closeSignal != nil {
    83  		select {
    84  		case <-c.closeSignal:
    85  		default:
    86  			close(c.closeSignal)
    87  		}
    88  	}
    89  	return nil
    90  }
    91  
    92  func (c *ChannelBind) SetMark(mark uint32) error { return nil }
    93  
    94  func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
    95  	return func(b []byte) (n int, ep conn.Endpoint, err error) {
    96  		select {
    97  		case <-c.closeSignal:
    98  			return 0, nil, net.ErrClosed
    99  		case rx := <-ch:
   100  			return copy(b, rx), c.target6, nil
   101  		}
   102  	}
   103  }
   104  
   105  func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
   106  	select {
   107  	case <-c.closeSignal:
   108  		return net.ErrClosed
   109  	default:
   110  		bc := make([]byte, len(b))
   111  		copy(bc, b)
   112  		if ep.(ChannelEndpoint) == c.target4 {
   113  			*c.tx4 <- bc
   114  		} else if ep.(ChannelEndpoint) == c.target6 {
   115  			*c.tx6 <- bc
   116  		} else {
   117  			return os.ErrInvalid
   118  		}
   119  	}
   120  	return nil
   121  }
   122  
   123  func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
   124  	addr, err := netip.ParseAddrPort(s)
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  	return ChannelEndpoint(addr.Port()), nil
   129  }