github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/wrappers.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   */
     9  
    10  package noisysockets
    11  
    12  import (
    13  	stdnet "net"
    14  	"net/netip"
    15  	"strings"
    16  
    17  	"github.com/noisysockets/netstack/pkg/tcpip"
    18  	"github.com/noisysockets/noisysockets/types"
    19  )
    20  
    21  // Addr is a wrapper around net.Addr that includes the source NoisePublicKey.
    22  type Addr struct {
    23  	stdnet.Addr
    24  	pk types.NoisePublicKey
    25  }
    26  
    27  // PublicKey returns the NoisePublicKey of the peer.
    28  func (a *Addr) PublicKey() types.NoisePublicKey {
    29  	return a.pk
    30  }
    31  
    32  // Conn is a wrapper around net.Conn that includes the source NoisePublicKey.
    33  type Conn struct {
    34  	stdnet.Conn
    35  	peers *peerList
    36  }
    37  
    38  func (c *Conn) RemoteAddr() stdnet.Addr {
    39  	remoteAddr := c.Conn.RemoteAddr()
    40  	if remoteAddr == nil {
    41  		return nil
    42  	}
    43  
    44  	peer, ok := c.peers.getByAddress(netip.MustParseAddrPort(remoteAddr.String()).Addr())
    45  	if !ok {
    46  		// Just return the standard address if we can't find the peer.
    47  		return remoteAddr
    48  	}
    49  
    50  	return &Addr{Addr: c.Conn.RemoteAddr(), pk: peer.PublicKey()}
    51  }
    52  
    53  type listener struct {
    54  	stdnet.Listener
    55  	peers *peerList
    56  }
    57  
    58  func (l *listener) Accept() (stdnet.Conn, error) {
    59  	conn, err := l.Listener.Accept()
    60  	if err != nil {
    61  		// The network stack was closed.
    62  		if strings.Contains(err.Error(), (&tcpip.ErrInvalidEndpointState{}).String()) {
    63  			return nil, stdnet.ErrClosed
    64  		}
    65  
    66  		return nil, err
    67  	}
    68  
    69  	return &Conn{Conn: conn, peers: l.peers}, nil
    70  }
    71  
    72  type packetConn struct {
    73  	stdnet.PacketConn
    74  	peers *peerList
    75  }
    76  
    77  func (pc *packetConn) ReadFrom(b []byte) (int, stdnet.Addr, error) {
    78  	n, addr, err := pc.PacketConn.ReadFrom(b)
    79  	if addr == nil {
    80  		return n, nil, err
    81  	}
    82  
    83  	peer, ok := pc.peers.getByAddress(netip.MustParseAddrPort(addr.String()).Addr())
    84  	if !ok {
    85  		// Just return the standard address if we can't find the peer.
    86  		return n, addr, err
    87  	}
    88  
    89  	return n, &Addr{Addr: addr, pk: peer.PublicKey()}, err
    90  }
    91  
    92  func (pc *packetConn) WriteTo(b []byte, addr stdnet.Addr) (int, error) {
    93  	addrPort, err := netip.ParseAddrPort(addr.String())
    94  	if err != nil {
    95  		return 0, err
    96  	}
    97  
    98  	return pc.PacketConn.WriteTo(b, &stdnet.UDPAddr{
    99  		IP:   addrPort.Addr().AsSlice(),
   100  		Port: int(addrPort.Port()),
   101  	})
   102  }