github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/wrappers.go (about) 1 // SPDX-License-Identifier: MPL-2.0 2 /* 3 * Copyright (C) 2024 The Noisy Sockets Authors. 4 * 5 * This Source Code Form is subject to the terms of the Mozilla Public 6 * License, v. 2.0. If a copy of the MPL was not distributed with this 7 * file, You can obtain one at http://mozilla.org/MPL/2.0/. 8 */ 9 10 package noisysockets 11 12 import ( 13 stdnet "net" 14 "net/netip" 15 "strings" 16 17 "github.com/noisysockets/netstack/pkg/tcpip" 18 "github.com/noisysockets/noisysockets/types" 19 ) 20 21 // Addr is a wrapper around net.Addr that includes the source NoisePublicKey. 22 type Addr struct { 23 stdnet.Addr 24 pk types.NoisePublicKey 25 } 26 27 // PublicKey returns the NoisePublicKey of the peer. 28 func (a *Addr) PublicKey() types.NoisePublicKey { 29 return a.pk 30 } 31 32 // Conn is a wrapper around net.Conn that includes the source NoisePublicKey. 33 type Conn struct { 34 stdnet.Conn 35 peers *peerList 36 } 37 38 func (c *Conn) RemoteAddr() stdnet.Addr { 39 remoteAddr := c.Conn.RemoteAddr() 40 if remoteAddr == nil { 41 return nil 42 } 43 44 peer, ok := c.peers.getByAddress(netip.MustParseAddrPort(remoteAddr.String()).Addr()) 45 if !ok { 46 // Just return the standard address if we can't find the peer. 47 return remoteAddr 48 } 49 50 return &Addr{Addr: c.Conn.RemoteAddr(), pk: peer.PublicKey()} 51 } 52 53 type listener struct { 54 stdnet.Listener 55 peers *peerList 56 } 57 58 func (l *listener) Accept() (stdnet.Conn, error) { 59 conn, err := l.Listener.Accept() 60 if err != nil { 61 // The network stack was closed. 62 if strings.Contains(err.Error(), (&tcpip.ErrInvalidEndpointState{}).String()) { 63 return nil, stdnet.ErrClosed 64 } 65 66 return nil, err 67 } 68 69 return &Conn{Conn: conn, peers: l.peers}, nil 70 } 71 72 type packetConn struct { 73 stdnet.PacketConn 74 peers *peerList 75 } 76 77 func (pc *packetConn) ReadFrom(b []byte) (int, stdnet.Addr, error) { 78 n, addr, err := pc.PacketConn.ReadFrom(b) 79 if addr == nil { 80 return n, nil, err 81 } 82 83 peer, ok := pc.peers.getByAddress(netip.MustParseAddrPort(addr.String()).Addr()) 84 if !ok { 85 // Just return the standard address if we can't find the peer. 86 return n, addr, err 87 } 88 89 return n, &Addr{Addr: addr, pk: peer.PublicKey()}, err 90 } 91 92 func (pc *packetConn) WriteTo(b []byte, addr stdnet.Addr) (int, error) { 93 addrPort, err := netip.ParseAddrPort(addr.String()) 94 if err != nil { 95 return 0, err 96 } 97 98 return pc.PacketConn.WriteTo(b, &stdnet.UDPAddr{ 99 IP: addrPort.Addr().AsSlice(), 100 Port: int(addrPort.Port()), 101 }) 102 }