golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/internal/socket/mmsghdr_unix.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build aix || linux || netbsd
     6  
     7  package socket
     8  
     9  import (
    10  	"net"
    11  	"os"
    12  	"sync"
    13  	"syscall"
    14  )
    15  
// mmsghdrs is a slice of mmsghdr headers, one per Message, in the form handed
// to the recvmmsg/sendmmsg wrappers.
type mmsghdrs []mmsghdr
    17  
    18  func (hs mmsghdrs) unpack(ms []Message, parseFn func([]byte, string) (net.Addr, error), hint string) error {
    19  	for i := range hs {
    20  		ms[i].N = int(hs[i].Len)
    21  		ms[i].NN = hs[i].Hdr.controllen()
    22  		ms[i].Flags = hs[i].Hdr.flags()
    23  		if parseFn != nil {
    24  			var err error
    25  			ms[i].Addr, err = parseFn(hs[i].Hdr.name(), hint)
    26  			if err != nil {
    27  				return err
    28  			}
    29  		}
    30  	}
    31  	return nil
    32  }
    33  
// mmsghdrsPacker packs Message slices into mmsghdrs, (re-)using pre-allocated
// buffers across calls.
type mmsghdrsPacker struct {
	// hs are the pre-allocated mmsghdrs, one per message.
	hs mmsghdrs
	// sockaddrs is the pre-allocated backing store for the Hdr.Name buffers.
	// One large buffer is shared by all messages and sliced up per message.
	sockaddrs []byte
	// vs are the pre-allocated iovecs.
	// One large buffer is shared by all messages and sliced up per message.
	// This allows the buffer to be reused even when the number of buffers per
	// message is distributed differently between calls.
	vs []iovec
}
    46  
    47  func (p *mmsghdrsPacker) prepare(ms []Message) {
    48  	n := len(ms)
    49  	if n <= cap(p.hs) {
    50  		p.hs = p.hs[:n]
    51  	} else {
    52  		p.hs = make(mmsghdrs, n)
    53  	}
    54  	if n*sizeofSockaddrInet6 <= cap(p.sockaddrs) {
    55  		p.sockaddrs = p.sockaddrs[:n*sizeofSockaddrInet6]
    56  	} else {
    57  		p.sockaddrs = make([]byte, n*sizeofSockaddrInet6)
    58  	}
    59  
    60  	nb := 0
    61  	for _, m := range ms {
    62  		nb += len(m.Buffers)
    63  	}
    64  	if nb <= cap(p.vs) {
    65  		p.vs = p.vs[:nb]
    66  	} else {
    67  		p.vs = make([]iovec, nb)
    68  	}
    69  }
    70  
    71  func (p *mmsghdrsPacker) pack(ms []Message, parseFn func([]byte, string) (net.Addr, error), marshalFn func(net.Addr, []byte) int) mmsghdrs {
    72  	p.prepare(ms)
    73  	hs := p.hs
    74  	vsRest := p.vs
    75  	saRest := p.sockaddrs
    76  	for i := range hs {
    77  		nvs := len(ms[i].Buffers)
    78  		vs := vsRest[:nvs]
    79  		vsRest = vsRest[nvs:]
    80  
    81  		var sa []byte
    82  		if parseFn != nil {
    83  			sa = saRest[:sizeofSockaddrInet6]
    84  			saRest = saRest[sizeofSockaddrInet6:]
    85  		} else if marshalFn != nil {
    86  			n := marshalFn(ms[i].Addr, saRest)
    87  			if n > 0 {
    88  				sa = saRest[:n]
    89  				saRest = saRest[n:]
    90  			}
    91  		}
    92  		hs[i].Hdr.pack(vs, ms[i].Buffers, ms[i].OOB, sa)
    93  	}
    94  	return hs
    95  }
    96  
// syscaller is a helper to invoke recvmmsg and sendmmsg via the
// RawConn.Read/Write interface. It is reusable, to amortize the overhead of
// allocating a closure for the function passed to RawConn.Read/Write.
type syscaller struct {
	// Per-call state. recvmmsg/sendmmsg set it before calling RawConn.Read/Write,
	// and the recvmmsgF/sendmmsgF callbacks write the results back into it.
	n     int      // first result of the last syscall attempt (presumably a message count)
	operr error    // error returned by the last syscall attempt
	hs    mmsghdrs // headers passed to the syscall
	flags int      // flags passed to the syscall

	// Method values bound once by init. Passing these cached values to
	// RawConn.Read/Write avoids creating a new closure on every call.
	boundRecvmmsgF func(uintptr) bool
	boundSendmmsgF func(uintptr) bool
}
   109  
   110  func (r *syscaller) init() {
   111  	r.boundRecvmmsgF = r.recvmmsgF
   112  	r.boundSendmmsgF = r.sendmmsgF
   113  }
   114  
   115  func (r *syscaller) recvmmsg(c syscall.RawConn, hs mmsghdrs, flags int) (int, error) {
   116  	r.n = 0
   117  	r.operr = nil
   118  	r.hs = hs
   119  	r.flags = flags
   120  	if err := c.Read(r.boundRecvmmsgF); err != nil {
   121  		return r.n, err
   122  	}
   123  	if r.operr != nil {
   124  		return r.n, os.NewSyscallError("recvmmsg", r.operr)
   125  	}
   126  	return r.n, nil
   127  }
   128  
   129  func (r *syscaller) recvmmsgF(s uintptr) bool {
   130  	r.n, r.operr = recvmmsg(s, r.hs, r.flags)
   131  	return ioComplete(r.flags, r.operr)
   132  }
   133  
   134  func (r *syscaller) sendmmsg(c syscall.RawConn, hs mmsghdrs, flags int) (int, error) {
   135  	r.n = 0
   136  	r.operr = nil
   137  	r.hs = hs
   138  	r.flags = flags
   139  	if err := c.Write(r.boundSendmmsgF); err != nil {
   140  		return r.n, err
   141  	}
   142  	if r.operr != nil {
   143  		return r.n, os.NewSyscallError("sendmmsg", r.operr)
   144  	}
   145  	return r.n, nil
   146  }
   147  
   148  func (r *syscaller) sendmmsgF(s uintptr) bool {
   149  	r.n, r.operr = sendmmsg(s, r.hs, r.flags)
   150  	return ioComplete(r.flags, r.operr)
   151  }
   152  
// mmsgTmps holds reusable temporary helpers for recvmmsg and sendmmsg. The
// packer's buffers and the syscaller's bound callbacks are recycled through an
// mmsgTmpsPool instead of being created per call.
type mmsgTmps struct {
	packer    mmsghdrsPacker
	syscaller syscaller
}
   158  
   159  var defaultMmsgTmpsPool = mmsgTmpsPool{
   160  	p: sync.Pool{
   161  		New: func() interface{} {
   162  			tmps := new(mmsgTmps)
   163  			tmps.syscaller.init()
   164  			return tmps
   165  		},
   166  	},
   167  }
   168  
// mmsgTmpsPool is a typed wrapper around sync.Pool that stores *mmsgTmps
// values. Its Get method resets an entry's packer buffers before returning it.
type mmsgTmpsPool struct {
	p sync.Pool
}
   172  
   173  func (p *mmsgTmpsPool) Get() *mmsgTmps {
   174  	m := p.p.Get().(*mmsgTmps)
   175  	// Clear fields up to the len (not the cap) of the slice,
   176  	// assuming that the previous caller only used that many elements.
   177  	for i := range m.packer.sockaddrs {
   178  		m.packer.sockaddrs[i] = 0
   179  	}
   180  	m.packer.sockaddrs = m.packer.sockaddrs[:0]
   181  	for i := range m.packer.vs {
   182  		m.packer.vs[i] = iovec{}
   183  	}
   184  	m.packer.vs = m.packer.vs[:0]
   185  	for i := range m.packer.hs {
   186  		m.packer.hs[i].Len = 0
   187  		m.packer.hs[i].Hdr = msghdr{}
   188  	}
   189  	m.packer.hs = m.packer.hs[:0]
   190  	return m
   191  }
   192  
   193  func (p *mmsgTmpsPool) Put(tmps *mmsgTmps) {
   194  	p.p.Put(tmps)
   195  }