github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/conn/bind_linux.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package conn 7 8 import ( 9 "errors" 10 "net" 11 "strconv" 12 "sync" 13 "syscall" 14 "unsafe" 15 16 "golang.org/x/sys/unix" 17 ) 18 19 type ipv4Source struct { 20 Src [4]byte 21 Ifindex int32 22 } 23 24 type ipv6Source struct { 25 src [16]byte 26 // ifindex belongs in dst.ZoneId 27 } 28 29 type LinuxSocketEndpoint struct { 30 mu sync.Mutex 31 dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte 32 src [unsafe.Sizeof(ipv6Source{})]byte 33 isV6 bool 34 } 35 36 func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() } 37 func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } 38 func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 } 39 40 func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source { 41 return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0])) 42 } 43 44 func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source { 45 return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0])) 46 } 47 48 func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 { 49 return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) 50 } 51 52 func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 { 53 return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) 54 } 55 56 // LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux. 57 type LinuxSocketBind struct { 58 // mu guards sock4 and sock6 and the associated fds. 59 // As long as someone holds mu (read or write), the associated fds are valid. 60 mu sync.RWMutex 61 sock4 int 62 sock6 int 63 } 64 65 func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} } 66 func NewDefaultBind() Bind { return NewLinuxSocketBind() } 67 68 var _ Endpoint = (*LinuxSocketEndpoint)(nil) 69 var _ Bind = (*LinuxSocketBind)(nil) 70 71 func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { 72 var end LinuxSocketEndpoint 73 addr, err := parseEndpoint(s) 74 if err != nil { 75 return nil, err 76 } 77 78 ipv4 := addr.IP.To4() 79 if ipv4 != nil { 80 dst := end.dst4() 81 end.isV6 = false 82 dst.Port = addr.Port 83 copy(dst.Addr[:], ipv4) 84 end.ClearSrc() 85 return &end, nil 86 } 87 88 ipv6 := addr.IP.To16() 89 if ipv6 != nil { 90 zone, err := zoneToUint32(addr.Zone) 91 if err != nil { 92 return nil, err 93 } 94 dst := end.dst6() 95 end.isV6 = true 96 dst.Port = addr.Port 97 dst.ZoneId = zone 98 copy(dst.Addr[:], ipv6[:]) 99 end.ClearSrc() 100 return &end, nil 101 } 102 103 return nil, errors.New("invalid IP address") 104 } 105 106 func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) { 107 bind.mu.Lock() 108 defer bind.mu.Unlock() 109 110 var err error 111 var newPort uint16 112 var tries int 113 114 if bind.sock4 != -1 || bind.sock6 != -1 { 115 return nil, 0, ErrBindAlreadyOpen 116 } 117 118 originalPort := port 119 120 again: 121 port = originalPort 122 var sock4, sock6 int 123 // Attempt ipv6 bind, update port if successful. 124 sock6, newPort, err = create6(port) 125 if err != nil { 126 if !errors.Is(err, syscall.EAFNOSUPPORT) { 127 return nil, 0, err 128 } 129 } else { 130 port = newPort 131 } 132 133 // Attempt ipv4 bind, update port if successful. 134 sock4, newPort, err = create4(port) 135 if err != nil { 136 if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { 137 unix.Close(sock6) 138 tries++ 139 goto again 140 } 141 if !errors.Is(err, syscall.EAFNOSUPPORT) { 142 unix.Close(sock6) 143 return nil, 0, err 144 } 145 } else { 146 port = newPort 147 } 148 149 var fns []ReceiveFunc 150 if sock4 != -1 { 151 bind.sock4 = sock4 152 fns = append(fns, bind.receiveIPv4) 153 } 154 if sock6 != -1 { 155 bind.sock6 = sock6 156 fns = append(fns, bind.receiveIPv6) 157 } 158 if len(fns) == 0 { 159 return nil, 0, syscall.EAFNOSUPPORT 160 } 161 return fns, port, nil 162 } 163 164 func (bind *LinuxSocketBind) SetMark(value uint32) error { 165 bind.mu.RLock() 166 defer bind.mu.RUnlock() 167 168 if bind.sock6 != -1 { 169 err := unix.SetsockoptInt( 170 bind.sock6, 171 unix.SOL_SOCKET, 172 unix.SO_MARK, 173 int(value), 174 ) 175 176 if err != nil { 177 return err 178 } 179 } 180 181 if bind.sock4 != -1 { 182 err := unix.SetsockoptInt( 183 bind.sock4, 184 unix.SOL_SOCKET, 185 unix.SO_MARK, 186 int(value), 187 ) 188 189 if err != nil { 190 return err 191 } 192 } 193 194 return nil 195 } 196 197 func (bind *LinuxSocketBind) Close() error { 198 // Take a readlock to shut down the sockets... 199 bind.mu.RLock() 200 if bind.sock6 != -1 { 201 unix.Shutdown(bind.sock6, unix.SHUT_RDWR) 202 } 203 if bind.sock4 != -1 { 204 unix.Shutdown(bind.sock4, unix.SHUT_RDWR) 205 } 206 bind.mu.RUnlock() 207 // ...and a write lock to close the fd. 208 // This ensures that no one else is using the fd. 209 bind.mu.Lock() 210 defer bind.mu.Unlock() 211 var err1, err2 error 212 if bind.sock6 != -1 { 213 err1 = unix.Close(bind.sock6) 214 bind.sock6 = -1 215 } 216 if bind.sock4 != -1 { 217 err2 = unix.Close(bind.sock4) 218 bind.sock4 = -1 219 } 220 221 if err1 != nil { 222 return err1 223 } 224 return err2 225 } 226 227 func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) { 228 bind.mu.RLock() 229 defer bind.mu.RUnlock() 230 if bind.sock4 == -1 { 231 return 0, nil, net.ErrClosed 232 } 233 var end LinuxSocketEndpoint 234 n, err := receive4(bind.sock4, buf, &end) 235 return n, &end, err 236 } 237 238 func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) { 239 bind.mu.RLock() 240 defer bind.mu.RUnlock() 241 if bind.sock6 == -1 { 242 return 0, nil, net.ErrClosed 243 } 244 var end LinuxSocketEndpoint 245 n, err := receive6(bind.sock6, buf, &end) 246 return n, &end, err 247 } 248 249 func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { 250 nend, ok := end.(*LinuxSocketEndpoint) 251 if !ok { 252 return ErrWrongEndpointType 253 } 254 bind.mu.RLock() 255 defer bind.mu.RUnlock() 256 if !nend.isV6 { 257 if bind.sock4 == -1 { 258 return net.ErrClosed 259 } 260 return send4(bind.sock4, nend, buff) 261 } else { 262 if bind.sock6 == -1 { 263 return net.ErrClosed 264 } 265 return send6(bind.sock6, nend, buff) 266 } 267 } 268 269 func (end *LinuxSocketEndpoint) SrcIP() net.IP { 270 if !end.isV6 { 271 return net.IPv4( 272 end.src4().Src[0], 273 end.src4().Src[1], 274 end.src4().Src[2], 275 end.src4().Src[3], 276 ) 277 } else { 278 return end.src6().src[:] 279 } 280 } 281 282 func (end *LinuxSocketEndpoint) DstIP() net.IP { 283 if !end.isV6 { 284 return net.IPv4( 285 end.dst4().Addr[0], 286 end.dst4().Addr[1], 287 end.dst4().Addr[2], 288 end.dst4().Addr[3], 289 ) 290 } else { 291 return end.dst6().Addr[:] 292 } 293 } 294 295 func (end *LinuxSocketEndpoint) DstToBytes() []byte { 296 if !end.isV6 { 297 return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] 298 } else { 299 return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] 300 } 301 } 302 303 func (end *LinuxSocketEndpoint) SrcToString() string { 304 return end.SrcIP().String() 305 } 306 307 func (end *LinuxSocketEndpoint) DstToString() string { 308 var udpAddr net.UDPAddr 309 udpAddr.IP = end.DstIP() 310 if !end.isV6 { 311 udpAddr.Port = end.dst4().Port 312 } else { 313 udpAddr.Port = end.dst6().Port 314 } 315 return udpAddr.String() 316 } 317 318 func (end *LinuxSocketEndpoint) ClearDst() { 319 for i := range end.dst { 320 end.dst[i] = 0 321 } 322 } 323 324 func (end *LinuxSocketEndpoint) ClearSrc() { 325 for i := range end.src { 326 end.src[i] = 0 327 } 328 } 329 330 func zoneToUint32(zone string) (uint32, error) { 331 if zone == "" { 332 return 0, nil 333 } 334 if intr, err := net.InterfaceByName(zone); err == nil { 335 return uint32(intr.Index), nil 336 } 337 n, err := strconv.ParseUint(zone, 10, 32) 338 return uint32(n), err 339 } 340 341 func create4(port uint16) (int, uint16, error) { 342 343 // create socket 344 345 fd, err := unix.Socket( 346 unix.AF_INET, 347 unix.SOCK_DGRAM, 348 0, 349 ) 350 351 if err != nil { 352 return -1, 0, err 353 } 354 355 addr := unix.SockaddrInet4{ 356 Port: int(port), 357 } 358 359 // set sockopts and bind 360 361 if err := func() error { 362 if err := unix.SetsockoptInt( 363 fd, 364 unix.IPPROTO_IP, 365 unix.IP_PKTINFO, 366 1, 367 ); err != nil { 368 return err 369 } 370 371 return unix.Bind(fd, &addr) 372 }(); err != nil { 373 unix.Close(fd) 374 return -1, 0, err 375 } 376 377 sa, err := unix.Getsockname(fd) 378 if err == nil { 379 addr.Port = sa.(*unix.SockaddrInet4).Port 380 } 381 382 return fd, uint16(addr.Port), err 383 } 384 385 func create6(port uint16) (int, uint16, error) { 386 387 // create socket 388 389 fd, err := unix.Socket( 390 unix.AF_INET6, 391 unix.SOCK_DGRAM, 392 0, 393 ) 394 395 if err != nil { 396 return -1, 0, err 397 } 398 399 // set sockopts and bind 400 401 addr := unix.SockaddrInet6{ 402 Port: int(port), 403 } 404 405 if err := func() error { 406 if err := unix.SetsockoptInt( 407 fd, 408 unix.IPPROTO_IPV6, 409 unix.IPV6_RECVPKTINFO, 410 1, 411 ); err != nil { 412 return err 413 } 414 415 if err := unix.SetsockoptInt( 416 fd, 417 unix.IPPROTO_IPV6, 418 unix.IPV6_V6ONLY, 419 1, 420 ); err != nil { 421 return err 422 } 423 424 return unix.Bind(fd, &addr) 425 426 }(); err != nil { 427 unix.Close(fd) 428 return -1, 0, err 429 } 430 431 sa, err := unix.Getsockname(fd) 432 if err == nil { 433 addr.Port = sa.(*unix.SockaddrInet6).Port 434 } 435 436 return fd, uint16(addr.Port), err 437 } 438 439 func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error { 440 441 // construct message header 442 443 cmsg := struct { 444 cmsghdr unix.Cmsghdr 445 pktinfo unix.Inet4Pktinfo 446 }{ 447 unix.Cmsghdr{ 448 Level: unix.IPPROTO_IP, 449 Type: unix.IP_PKTINFO, 450 Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, 451 }, 452 unix.Inet4Pktinfo{ 453 Spec_dst: end.src4().Src, 454 Ifindex: end.src4().Ifindex, 455 }, 456 } 457 458 end.mu.Lock() 459 _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) 460 end.mu.Unlock() 461 462 if err == nil { 463 return nil 464 } 465 466 // clear src and retry 467 468 if err == unix.EINVAL { 469 end.ClearSrc() 470 cmsg.pktinfo = unix.Inet4Pktinfo{} 471 end.mu.Lock() 472 _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) 473 end.mu.Unlock() 474 } 475 476 return err 477 } 478 479 func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error { 480 481 // construct message header 482 483 cmsg := struct { 484 cmsghdr unix.Cmsghdr 485 pktinfo unix.Inet6Pktinfo 486 }{ 487 unix.Cmsghdr{ 488 Level: unix.IPPROTO_IPV6, 489 Type: unix.IPV6_PKTINFO, 490 Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, 491 }, 492 unix.Inet6Pktinfo{ 493 Addr: end.src6().src, 494 Ifindex: end.dst6().ZoneId, 495 }, 496 } 497 498 if cmsg.pktinfo.Addr == [16]byte{} { 499 cmsg.pktinfo.Ifindex = 0 500 } 501 502 end.mu.Lock() 503 _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) 504 end.mu.Unlock() 505 506 if err == nil { 507 return nil 508 } 509 510 // clear src and retry 511 512 if err == unix.EINVAL { 513 end.ClearSrc() 514 cmsg.pktinfo = unix.Inet6Pktinfo{} 515 end.mu.Lock() 516 _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) 517 end.mu.Unlock() 518 } 519 520 return err 521 } 522 523 func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) { 524 525 // construct message header 526 527 var cmsg struct { 528 cmsghdr unix.Cmsghdr 529 pktinfo unix.Inet4Pktinfo 530 } 531 532 size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) 533 534 if err != nil { 535 return 0, err 536 } 537 end.isV6 = false 538 539 if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { 540 *end.dst4() = *newDst4 541 } 542 543 // update source cache 544 545 if cmsg.cmsghdr.Level == unix.IPPROTO_IP && 546 cmsg.cmsghdr.Type == unix.IP_PKTINFO && 547 cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { 548 end.src4().Src = cmsg.pktinfo.Spec_dst 549 end.src4().Ifindex = cmsg.pktinfo.Ifindex 550 } 551 552 return size, nil 553 } 554 555 func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) { 556 557 // construct message header 558 559 var cmsg struct { 560 cmsghdr unix.Cmsghdr 561 pktinfo unix.Inet6Pktinfo 562 } 563 564 size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) 565 566 if err != nil { 567 return 0, err 568 } 569 end.isV6 = true 570 571 if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { 572 *end.dst6() = *newDst6 573 } 574 575 // update source cache 576 577 if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && 578 cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && 579 cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { 580 end.src6().src = cmsg.pktinfo.Addr 581 end.dst6().ZoneId = cmsg.pktinfo.Ifindex 582 } 583 584 return size, nil 585 }