github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/conn/bindtest/bindtest.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package bindtest
     7  
     8  import (
     9  	"fmt"
    10  	"math/rand"
    11  	"net"
    12  	"os"
    13  	"strconv"
    14  
    15  	"github.com/bugfan/wireguard-go/conn"
    16  )
    17  
    18  type ChannelBind struct {
    19  	rx4, tx4         *chan []byte
    20  	rx6, tx6         *chan []byte
    21  	closeSignal      chan bool
    22  	source4, source6 ChannelEndpoint
    23  	target4, target6 ChannelEndpoint
    24  }
    25  
    26  type ChannelEndpoint uint16
    27  
    28  var _ conn.Bind = (*ChannelBind)(nil)
    29  var _ conn.Endpoint = (*ChannelEndpoint)(nil)
    30  
    31  func NewChannelBinds() [2]conn.Bind {
    32  	arx4 := make(chan []byte, 8192)
    33  	brx4 := make(chan []byte, 8192)
    34  	arx6 := make(chan []byte, 8192)
    35  	brx6 := make(chan []byte, 8192)
    36  	var binds [2]ChannelBind
    37  	binds[0].rx4 = &arx4
    38  	binds[0].tx4 = &brx4
    39  	binds[1].rx4 = &brx4
    40  	binds[1].tx4 = &arx4
    41  	binds[0].rx6 = &arx6
    42  	binds[0].tx6 = &brx6
    43  	binds[1].rx6 = &brx6
    44  	binds[1].tx6 = &arx6
    45  	binds[0].target4 = ChannelEndpoint(1)
    46  	binds[1].target4 = ChannelEndpoint(2)
    47  	binds[0].target6 = ChannelEndpoint(3)
    48  	binds[1].target6 = ChannelEndpoint(4)
    49  	binds[0].source4 = binds[1].target4
    50  	binds[0].source6 = binds[1].target6
    51  	binds[1].source4 = binds[0].target4
    52  	binds[1].source6 = binds[0].target6
    53  	return [2]conn.Bind{&binds[0], &binds[1]}
    54  }
    55  
    56  func (c ChannelEndpoint) ClearSrc() {}
    57  
    58  func (c ChannelEndpoint) SrcToString() string { return "" }
    59  
    60  func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
    61  
    62  func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
    63  
    64  func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
    65  
    66  func (c ChannelEndpoint) SrcIP() net.IP { return nil }
    67  
    68  func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
    69  	c.closeSignal = make(chan bool)
    70  	fns = append(fns, c.makeReceiveFunc(*c.rx4))
    71  	fns = append(fns, c.makeReceiveFunc(*c.rx6))
    72  	if rand.Uint32()&1 == 0 {
    73  		return fns, uint16(c.source4), nil
    74  	} else {
    75  		return fns, uint16(c.source6), nil
    76  	}
    77  }
    78  
    79  func (c *ChannelBind) Close() error {
    80  	if c.closeSignal != nil {
    81  		select {
    82  		case <-c.closeSignal:
    83  		default:
    84  			close(c.closeSignal)
    85  		}
    86  	}
    87  	return nil
    88  }
    89  
    90  func (c *ChannelBind) SetMark(mark uint32) error { return nil }
    91  
    92  func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
    93  	return func(b []byte) (n int, ep conn.Endpoint, err error) {
    94  		select {
    95  		case <-c.closeSignal:
    96  			return 0, nil, net.ErrClosed
    97  		case rx := <-ch:
    98  			return copy(b, rx), c.target6, nil
    99  		}
   100  	}
   101  }
   102  
   103  func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
   104  	select {
   105  	case <-c.closeSignal:
   106  		return net.ErrClosed
   107  	default:
   108  		bc := make([]byte, len(b))
   109  		copy(bc, b)
   110  		if ep.(ChannelEndpoint) == c.target4 {
   111  			*c.tx4 <- bc
   112  		} else if ep.(ChannelEndpoint) == c.target6 {
   113  			*c.tx6 <- bc
   114  		} else {
   115  			return os.ErrInvalid
   116  		}
   117  	}
   118  	return nil
   119  }
   120  
   121  func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
   122  	_, port, err := net.SplitHostPort(s)
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  	i, err := strconv.ParseUint(port, 10, 16)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  	return ChannelEndpoint(i), nil
   131  }