github.com/cawidtu/notwireguard-go/conn@v0.0.0-20230523131112-68e8e5ce9cdf/bind_linux.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package conn 7 8 import ( 9 "errors" 10 "net" 11 "net/netip" 12 "strconv" 13 "sync" 14 "syscall" 15 "unsafe" 16 17 "golang.org/x/sys/unix" 18 ) 19 20 type ipv4Source struct { 21 Src [4]byte 22 Ifindex int32 23 } 24 25 type ipv6Source struct { 26 src [16]byte 27 // ifindex belongs in dst.ZoneId 28 } 29 30 type LinuxSocketEndpoint struct { 31 mu sync.Mutex 32 dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte 33 src [unsafe.Sizeof(ipv6Source{})]byte 34 isV6 bool 35 } 36 37 func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() } 38 func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } 39 func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 } 40 41 func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source { 42 return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0])) 43 } 44 45 func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source { 46 return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0])) 47 } 48 49 func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 { 50 return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) 51 } 52 53 func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 { 54 return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) 55 } 56 57 // LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux. 58 type LinuxSocketBind struct { 59 // mu guards sock4 and sock6 and the associated fds. 60 // As long as someone holds mu (read or write), the associated fds are valid. 61 mu sync.RWMutex 62 sock4 int 63 sock6 int 64 } 65 66 func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} } 67 func NewDefaultBind() Bind { return NewLinuxSocketBind() } 68 69 var ( 70 _ Endpoint = (*LinuxSocketEndpoint)(nil) 71 _ Bind = (*LinuxSocketBind)(nil) 72 ) 73 74 func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { 75 var end LinuxSocketEndpoint 76 e, err := netip.ParseAddrPort(s) 77 if err != nil { 78 return nil, err 79 } 80 81 if e.Addr().Is4() { 82 dst := end.dst4() 83 end.isV6 = false 84 dst.Port = int(e.Port()) 85 dst.Addr = e.Addr().As4() 86 end.ClearSrc() 87 return &end, nil 88 } 89 90 if e.Addr().Is6() { 91 zone, err := zoneToUint32(e.Addr().Zone()) 92 if err != nil { 93 return nil, err 94 } 95 dst := end.dst6() 96 end.isV6 = true 97 dst.Port = int(e.Port()) 98 dst.ZoneId = zone 99 dst.Addr = e.Addr().As16() 100 end.ClearSrc() 101 return &end, nil 102 } 103 104 return nil, errors.New("invalid IP address") 105 } 106 107 func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) { 108 bind.mu.Lock() 109 defer bind.mu.Unlock() 110 111 var err error 112 var newPort uint16 113 var tries int 114 115 if bind.sock4 != -1 || bind.sock6 != -1 { 116 return nil, 0, ErrBindAlreadyOpen 117 } 118 119 originalPort := port 120 121 again: 122 port = originalPort 123 var sock4, sock6 int 124 // Attempt ipv6 bind, update port if successful. 125 sock6, newPort, err = create6(port) 126 if err != nil { 127 if !errors.Is(err, syscall.EAFNOSUPPORT) { 128 return nil, 0, err 129 } 130 } else { 131 port = newPort 132 } 133 134 // Attempt ipv4 bind, update port if successful. 135 sock4, newPort, err = create4(port) 136 if err != nil { 137 if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { 138 unix.Close(sock6) 139 tries++ 140 goto again 141 } 142 if !errors.Is(err, syscall.EAFNOSUPPORT) { 143 unix.Close(sock6) 144 return nil, 0, err 145 } 146 } else { 147 port = newPort 148 } 149 150 var fns []ReceiveFunc 151 if sock4 != -1 { 152 bind.sock4 = sock4 153 fns = append(fns, bind.receiveIPv4) 154 } 155 if sock6 != -1 { 156 bind.sock6 = sock6 157 fns = append(fns, bind.receiveIPv6) 158 } 159 if len(fns) == 0 { 160 return nil, 0, syscall.EAFNOSUPPORT 161 } 162 return fns, port, nil 163 } 164 165 func (bind *LinuxSocketBind) SetMark(value uint32) error { 166 bind.mu.RLock() 167 defer bind.mu.RUnlock() 168 169 if bind.sock6 != -1 { 170 err := unix.SetsockoptInt( 171 bind.sock6, 172 unix.SOL_SOCKET, 173 unix.SO_MARK, 174 int(value), 175 ) 176 if err != nil { 177 return err 178 } 179 } 180 181 if bind.sock4 != -1 { 182 err := unix.SetsockoptInt( 183 bind.sock4, 184 unix.SOL_SOCKET, 185 unix.SO_MARK, 186 int(value), 187 ) 188 if err != nil { 189 return err 190 } 191 } 192 193 return nil 194 } 195 196 func (bind *LinuxSocketBind) Close() error { 197 // Take a readlock to shut down the sockets... 198 bind.mu.RLock() 199 if bind.sock6 != -1 { 200 unix.Shutdown(bind.sock6, unix.SHUT_RDWR) 201 } 202 if bind.sock4 != -1 { 203 unix.Shutdown(bind.sock4, unix.SHUT_RDWR) 204 } 205 bind.mu.RUnlock() 206 // ...and a write lock to close the fd. 207 // This ensures that no one else is using the fd. 208 bind.mu.Lock() 209 defer bind.mu.Unlock() 210 var err1, err2 error 211 if bind.sock6 != -1 { 212 err1 = unix.Close(bind.sock6) 213 bind.sock6 = -1 214 } 215 if bind.sock4 != -1 { 216 err2 = unix.Close(bind.sock4) 217 bind.sock4 = -1 218 } 219 220 if err1 != nil { 221 return err1 222 } 223 return err2 224 } 225 226 func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) { 227 bind.mu.RLock() 228 defer bind.mu.RUnlock() 229 if bind.sock4 == -1 { 230 return 0, nil, net.ErrClosed 231 } 232 var end LinuxSocketEndpoint 233 n, err := receive4(bind.sock4, buf, &end) 234 return n, &end, err 235 } 236 237 func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) { 238 bind.mu.RLock() 239 defer bind.mu.RUnlock() 240 if bind.sock6 == -1 { 241 return 0, nil, net.ErrClosed 242 } 243 var end LinuxSocketEndpoint 244 n, err := receive6(bind.sock6, buf, &end) 245 return n, &end, err 246 } 247 248 func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { 249 nend, ok := end.(*LinuxSocketEndpoint) 250 if !ok { 251 return ErrWrongEndpointType 252 } 253 bind.mu.RLock() 254 defer bind.mu.RUnlock() 255 if !nend.isV6 { 256 if bind.sock4 == -1 { 257 return net.ErrClosed 258 } 259 return send4(bind.sock4, nend, buff) 260 } else { 261 if bind.sock6 == -1 { 262 return net.ErrClosed 263 } 264 return send6(bind.sock6, nend, buff) 265 } 266 } 267 268 func (end *LinuxSocketEndpoint) SrcIP() netip.Addr { 269 if !end.isV6 { 270 return netip.AddrFrom4(end.src4().Src) 271 } else { 272 return netip.AddrFrom16(end.src6().src) 273 } 274 } 275 276 func (end *LinuxSocketEndpoint) DstIP() netip.Addr { 277 if !end.isV6 { 278 return netip.AddrFrom4(end.dst4().Addr) 279 } else { 280 return netip.AddrFrom16(end.dst6().Addr) 281 } 282 } 283 284 func (end *LinuxSocketEndpoint) DstToBytes() []byte { 285 if !end.isV6 { 286 return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] 287 } else { 288 return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] 289 } 290 } 291 292 func (end *LinuxSocketEndpoint) SrcToString() string { 293 return end.SrcIP().String() 294 } 295 296 func (end *LinuxSocketEndpoint) DstToString() string { 297 var port int 298 if !end.isV6 { 299 port = end.dst4().Port 300 } else { 301 port = end.dst6().Port 302 } 303 return netip.AddrPortFrom(end.DstIP(), uint16(port)).String() 304 } 305 306 func (end *LinuxSocketEndpoint) ClearDst() { 307 for i := range end.dst { 308 end.dst[i] = 0 309 } 310 } 311 312 func (end *LinuxSocketEndpoint) ClearSrc() { 313 for i := range end.src { 314 end.src[i] = 0 315 } 316 } 317 318 func zoneToUint32(zone string) (uint32, error) { 319 if zone == "" { 320 return 0, nil 321 } 322 if intr, err := net.InterfaceByName(zone); err == nil { 323 return uint32(intr.Index), nil 324 } 325 n, err := strconv.ParseUint(zone, 10, 32) 326 return uint32(n), err 327 } 328 329 func create4(port uint16) (int, uint16, error) { 330 // create socket 331 332 fd, err := unix.Socket( 333 unix.AF_INET, 334 unix.SOCK_DGRAM, 335 0, 336 ) 337 if err != nil { 338 return -1, 0, err 339 } 340 341 addr := unix.SockaddrInet4{ 342 Port: int(port), 343 } 344 345 // set sockopts and bind 346 347 if err := func() error { 348 if err := unix.SetsockoptInt( 349 fd, 350 unix.IPPROTO_IP, 351 unix.IP_PKTINFO, 352 1, 353 ); err != nil { 354 return err 355 } 356 357 return unix.Bind(fd, &addr) 358 }(); err != nil { 359 unix.Close(fd) 360 return -1, 0, err 361 } 362 363 sa, err := unix.Getsockname(fd) 364 if err == nil { 365 addr.Port = sa.(*unix.SockaddrInet4).Port 366 } 367 368 return fd, uint16(addr.Port), err 369 } 370 371 func create6(port uint16) (int, uint16, error) { 372 // create socket 373 374 fd, err := unix.Socket( 375 unix.AF_INET6, 376 unix.SOCK_DGRAM, 377 0, 378 ) 379 if err != nil { 380 return -1, 0, err 381 } 382 383 // set sockopts and bind 384 385 addr := unix.SockaddrInet6{ 386 Port: int(port), 387 } 388 389 if err := func() error { 390 if err := unix.SetsockoptInt( 391 fd, 392 unix.IPPROTO_IPV6, 393 unix.IPV6_RECVPKTINFO, 394 1, 395 ); err != nil { 396 return err 397 } 398 399 if err := unix.SetsockoptInt( 400 fd, 401 unix.IPPROTO_IPV6, 402 unix.IPV6_V6ONLY, 403 1, 404 ); err != nil { 405 return err 406 } 407 408 return unix.Bind(fd, &addr) 409 }(); err != nil { 410 unix.Close(fd) 411 return -1, 0, err 412 } 413 414 sa, err := unix.Getsockname(fd) 415 if err == nil { 416 addr.Port = sa.(*unix.SockaddrInet6).Port 417 } 418 419 return fd, uint16(addr.Port), err 420 } 421 422 func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error { 423 // construct message header 424 425 cmsg := struct { 426 cmsghdr unix.Cmsghdr 427 pktinfo unix.Inet4Pktinfo 428 }{ 429 unix.Cmsghdr{ 430 Level: unix.IPPROTO_IP, 431 Type: unix.IP_PKTINFO, 432 Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, 433 }, 434 unix.Inet4Pktinfo{ 435 Spec_dst: end.src4().Src, 436 Ifindex: end.src4().Ifindex, 437 }, 438 } 439 440 end.mu.Lock() 441 _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) 442 end.mu.Unlock() 443 444 if err == nil { 445 return nil 446 } 447 448 // clear src and retry 449 450 if err == unix.EINVAL { 451 end.ClearSrc() 452 cmsg.pktinfo = unix.Inet4Pktinfo{} 453 end.mu.Lock() 454 _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) 455 end.mu.Unlock() 456 } 457 458 return err 459 } 460 461 func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error { 462 // construct message header 463 464 cmsg := struct { 465 cmsghdr unix.Cmsghdr 466 pktinfo unix.Inet6Pktinfo 467 }{ 468 unix.Cmsghdr{ 469 Level: unix.IPPROTO_IPV6, 470 Type: unix.IPV6_PKTINFO, 471 Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, 472 }, 473 unix.Inet6Pktinfo{ 474 Addr: end.src6().src, 475 Ifindex: end.dst6().ZoneId, 476 }, 477 } 478 479 if cmsg.pktinfo.Addr == [16]byte{} { 480 cmsg.pktinfo.Ifindex = 0 481 } 482 483 end.mu.Lock() 484 _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) 485 end.mu.Unlock() 486 487 if err == nil { 488 return nil 489 } 490 491 // clear src and retry 492 493 if err == unix.EINVAL { 494 end.ClearSrc() 495 cmsg.pktinfo = unix.Inet6Pktinfo{} 496 end.mu.Lock() 497 _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) 498 end.mu.Unlock() 499 } 500 501 return err 502 } 503 504 func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) { 505 // construct message header 506 507 var cmsg struct { 508 cmsghdr unix.Cmsghdr 509 pktinfo unix.Inet4Pktinfo 510 } 511 512 size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) 513 if err != nil { 514 return 0, err 515 } 516 end.isV6 = false 517 518 if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { 519 *end.dst4() = *newDst4 520 } 521 522 // update source cache 523 524 if cmsg.cmsghdr.Level == unix.IPPROTO_IP && 525 cmsg.cmsghdr.Type == unix.IP_PKTINFO && 526 cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { 527 end.src4().Src = cmsg.pktinfo.Spec_dst 528 end.src4().Ifindex = cmsg.pktinfo.Ifindex 529 } 530 531 return size, nil 532 } 533 534 func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) { 535 // construct message header 536 537 var cmsg struct { 538 cmsghdr unix.Cmsghdr 539 pktinfo unix.Inet6Pktinfo 540 } 541 542 size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) 543 if err != nil { 544 return 0, err 545 } 546 end.isV6 = true 547 548 if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { 549 *end.dst6() = *newDst6 550 } 551 552 // update source cache 553 554 if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && 555 cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && 556 cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { 557 end.src6().src = cmsg.pktinfo.Addr 558 end.dst6().ZoneId = cmsg.pktinfo.Ifindex 559 } 560 561 return size, nil 562 }