github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/tun/tuntest/tuntest.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tuntest
     7  
     8  import (
     9  	"encoding/binary"
    10  	"io"
    11  	"net/netip"
    12  	"os"
    13  
    14  	"github.com/koomox/wireguard-go/tun"
    15  )
    16  
    17  func Ping(dst, src netip.Addr) []byte {
    18  	localPort := uint16(1337)
    19  	seq := uint16(0)
    20  
    21  	payload := make([]byte, 4)
    22  	binary.BigEndian.PutUint16(payload[0:], localPort)
    23  	binary.BigEndian.PutUint16(payload[2:], seq)
    24  
    25  	return genICMPv4(payload, dst, src)
    26  }
    27  
    28  // Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
    29  func checksum(buf []byte, initial uint16) uint16 {
    30  	v := uint32(initial)
    31  	for i := 0; i < len(buf)-1; i += 2 {
    32  		v += uint32(binary.BigEndian.Uint16(buf[i:]))
    33  	}
    34  	if len(buf)%2 == 1 {
    35  		v += uint32(buf[len(buf)-1]) << 8
    36  	}
    37  	for v > 0xffff {
    38  		v = (v >> 16) + (v & 0xffff)
    39  	}
    40  	return ^uint16(v)
    41  }
    42  
    43  func genICMPv4(payload []byte, dst, src netip.Addr) []byte {
    44  	const (
    45  		icmpv4ProtocolNumber = 1
    46  		icmpv4Echo           = 8
    47  		icmpv4ChecksumOffset = 2
    48  		icmpv4Size           = 8
    49  		ipv4Size             = 20
    50  		ipv4TotalLenOffset   = 2
    51  		ipv4ChecksumOffset   = 10
    52  		ttl                  = 65
    53  		headerSize           = ipv4Size + icmpv4Size
    54  	)
    55  
    56  	pkt := make([]byte, headerSize+len(payload))
    57  
    58  	ip := pkt[0:ipv4Size]
    59  	icmpv4 := pkt[ipv4Size : ipv4Size+icmpv4Size]
    60  
    61  	// https://tools.ietf.org/html/rfc792
    62  	icmpv4[0] = icmpv4Echo // type
    63  	icmpv4[1] = 0          // code
    64  	chksum := ^checksum(icmpv4, checksum(payload, 0))
    65  	binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
    66  
    67  	// https://tools.ietf.org/html/rfc760 section 3.1
    68  	length := uint16(len(pkt))
    69  	ip[0] = (4 << 4) | (ipv4Size / 4)
    70  	binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
    71  	ip[8] = ttl
    72  	ip[9] = icmpv4ProtocolNumber
    73  	copy(ip[12:], src.AsSlice())
    74  	copy(ip[16:], dst.AsSlice())
    75  	chksum = ^checksum(ip[:], 0)
    76  	binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
    77  
    78  	copy(pkt[headerSize:], payload)
    79  	return pkt
    80  }
    81  
// ChannelTUN exposes a tun.Device (see TUN) whose packet traffic is carried
// over Go channels instead of an operating-system interface: the device's
// Write sends copies of packets on Inbound, and its Read returns packets
// received from Outbound.
type ChannelTUN struct {
	Inbound  chan []byte // packets delivered by the device's Write; NOTE(review): not closed by any code visible here
	Outbound chan []byte // packets returned by the device's Read; once the device is closed Read stops receiving, so sends block

	// closed is closed when the device is shut down (Write with offset -1);
	// Read and Write select on it to return os.ErrClosed.
	closed chan struct{}
	// events is returned by the device's Events method; it is closed on shutdown.
	events chan tun.Event
	// tun is the device value handed out by TUN.
	tun chTun
}
    90  
    91  func NewChannelTUN() *ChannelTUN {
    92  	c := &ChannelTUN{
    93  		Inbound:  make(chan []byte),
    94  		Outbound: make(chan []byte),
    95  		closed:   make(chan struct{}),
    96  		events:   make(chan tun.Event, 1),
    97  	}
    98  	c.tun.c = c
    99  	c.events <- tun.EventUp
   100  	return c
   101  }
   102  
   103  func (c *ChannelTUN) TUN() tun.Device {
   104  	return &c.tun
   105  }
   106  
// chTun is the device value that ChannelTUN.TUN returns as a tun.Device.
// Its methods operate on the channels of the ChannelTUN it points back to.
type chTun struct {
	c *ChannelTUN // owning ChannelTUN; set by NewChannelTUN
}
   110  
   111  func (t *chTun) File() *os.File { return nil }
   112  
   113  func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) {
   114  	select {
   115  	case <-t.c.closed:
   116  		return 0, os.ErrClosed
   117  	case msg := <-t.c.Outbound:
   118  		n := copy(packets[0][offset:], msg)
   119  		sizes[0] = n
   120  		return 1, nil
   121  	}
   122  }
   123  
   124  // Write is called by the wireguard device to deliver a packet for routing.
   125  func (t *chTun) Write(packets [][]byte, offset int) (int, error) {
   126  	if offset == -1 {
   127  		close(t.c.closed)
   128  		close(t.c.events)
   129  		return 0, io.EOF
   130  	}
   131  	for i, data := range packets {
   132  		msg := make([]byte, len(data)-offset)
   133  		copy(msg, data[offset:])
   134  		select {
   135  		case <-t.c.closed:
   136  			return i, os.ErrClosed
   137  		case t.c.Inbound <- msg:
   138  		}
   139  	}
   140  	return len(packets), nil
   141  }
   142  
   143  func (t *chTun) BatchSize() int {
   144  	return 1
   145  }
   146  
// DefaultMTU is the MTU value, 1420, that the channel-backed device's MTU
// method reports.
const DefaultMTU = 1420
   148  
   149  func (t *chTun) MTU() (int, error)        { return DefaultMTU, nil }
   150  func (t *chTun) Name() (string, error)    { return "loopbackTun1", nil }
   151  func (t *chTun) Events() <-chan tun.Event { return t.c.events }
   152  func (t *chTun) Close() error {
   153  	t.Write(nil, -1)
   154  	return nil
   155  }