gitlab.com/go-extension/tls@v0.0.0-20240304171319-e6745021905e/kernel_linux.go (about) 1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "errors" 9 "fmt" 10 "io" 11 "net" 12 "os" 13 "strings" 14 "syscall" 15 "unsafe" 16 17 "github.com/blang/semver/v4" 18 "github.com/pmorjan/kmod" 19 "golang.org/x/sys/unix" 20 ) 21 22 const ( 23 TLS_TX = 1 24 TLS_RX = 2 25 TLS_TX_ZEROCOPY_RO = 3 // TX zerocopy (only sendfile now) 26 TLS_RX_EXPECT_NO_PAD = 4 // Attempt opportunistic zero-copy, TLS 1.3 only 27 28 TLS_SET_RECORD_TYPE = 1 29 TLS_GET_RECORD_TYPE = 2 30 31 kernelOverhead = 16 32 ) 33 34 type kernelInfo struct { 35 TLS_TX, TLS_RX bool 36 37 TLS_Version13 bool 38 TLS_TX_ZEROCOPY bool 39 TLS_RX_NOPAD bool 40 41 TLS_AESGCM256, TLS_AESCCM128, TLS_CHACHA20 bool 42 TLS_ARIAGCM bool 43 } 44 45 var kernel kernelInfo 46 47 func init() { 48 func() { 49 defer recover() 50 kmod, err := kmod.New() 51 if err != nil { 52 return 53 } 54 55 kmod.Load("tls", "", 0) 56 }() 57 58 // when kernel tls module enabled, /sys/module/tls is available 59 if _, err := os.Stat("/sys/module/tls"); err != nil { 60 return 61 } 62 63 var uname unix.Utsname 64 if err := unix.Uname(&uname); err != nil { 65 return 66 } 67 68 kernelVersion, err := semver.Parse(strings.Trim(string(uname.Release[:]), "\x00")) 69 if err != nil { 70 return 71 } 72 kernelVersion.Pre = nil 73 kernelVersion.Build = nil 74 75 if kernelVersion.GTE(semver.MustParse("4.13.0")) { 76 kernel.TLS_TX = true 77 } 78 79 if kernelVersion.GTE(semver.MustParse("4.17.0")) { 80 kernel.TLS_RX = true 81 } 82 83 if kernelVersion.GTE(semver.MustParse("5.1.0")) { 84 kernel.TLS_AESGCM256 = true 85 kernel.TLS_Version13 = true 86 } 87 88 if kernelVersion.GTE(semver.MustParse("5.2.0")) { 89 kernel.TLS_AESCCM128 = true 90 } 91 92 if kernelVersion.GTE(semver.MustParse("5.11.0")) { 93 kernel.TLS_CHACHA20 = true 94 } 95 96 if kernelVersion.GTE(semver.MustParse("5.16.0")) { 97 // SM4_GCM, SM4_CCM (not supported) 98 } 99 100 if kernelVersion.GTE(semver.MustParse("5.19.0")) { 101 kernel.TLS_CHACHA20 = true 102 } 103 104 if kernelVersion.Major > 5 { 105 kernel.TLS_RX_NOPAD = true 106 } 107 108 if kernelVersion.GTE(semver.MustParse("6.1.0")) { 109 kernel.TLS_ARIAGCM = true 110 } 111 } 112 113 func (c *Conn) setup() error { 114 if !kernel.TLS_TX || (!c.config.KernelTX && !c.config.KernelRX) { 115 return nil 116 } 117 118 if c.quic != nil { 119 return nil 120 } 121 122 var rwc syscall.RawConn 123 { 124 conn, ok := c.conn.(*net.TCPConn) 125 if !ok { 126 return nil 127 } 128 129 var err error 130 rwc, err = conn.SyscallConn() 131 if err != nil { 132 return nil 133 } 134 } 135 136 var in, out kernelCrypto 137 switch c.vers { 138 case VersionTLS12: 139 if kernel.TLS_RX && c.config.KernelRX { 140 in = c.in.kernelCipher(c.cipherSuite) 141 } 142 143 if c.config.KernelTX { 144 out = c.out.kernelCipher(c.cipherSuite) 145 } 146 case VersionTLS13: 147 if !kernel.TLS_Version13 { 148 return nil 149 } 150 151 if kernel.TLS_RX && c.config.KernelRX { 152 in = c.in.kernelCipher(c.cipherSuite) 153 } 154 155 if c.config.KernelTX { 156 out = c.out.kernelCipher(c.cipherSuite) 157 } 158 default: 159 return nil 160 } 161 162 if in == nil && out == nil { 163 return nil 164 } 165 166 var err, er error 167 err = rwc.Control(func(fd uintptr) { 168 er = syscall.SetsockoptString(int(fd), unix.SOL_TCP, unix.TCP_ULP, "tls") 169 if er != nil { 170 return 171 } 172 173 if in != nil { 174 er = syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_RX, in.String()) 175 if er != nil { 176 return 177 } 178 179 if c.vers >= VersionTLS13 && kernel.TLS_RX_NOPAD { 180 er = syscall.SetsockoptInt(int(fd), unix.SOL_TLS, TLS_RX_EXPECT_NO_PAD, 1) 181 if er != nil { 182 return 183 } 184 } 185 } 186 187 if out != nil { 188 er = syscall.SetsockoptString(int(fd), unix.SOL_TLS, TLS_TX, out.String()) 189 if er != nil { 190 return 191 } 192 193 if kernel.TLS_TX_ZEROCOPY { 194 er = syscall.SetsockoptInt(int(fd), unix.SOL_TLS, TLS_TX_ZEROCOPY_RO, 1) 195 if er != nil { 196 return 197 } 198 } 199 } 200 }) 201 if er != nil { 202 return er 203 } 204 if err != nil { 205 return err 206 } 207 return nil 208 } 209 210 func (c *Conn) readKernelRecord(handshakeState uint32) (typ recordType, data []byte, err error) { 211 if handshakeState != handshakeCompleted { 212 c.sendAlertLocked(alertInternalError) 213 err = c.in.setErrorLocked(errors.New("tls: internal error: set kTLSCipher before handshake completed")) 214 return 215 } 216 217 var rwc syscall.RawConn 218 { 219 conn, ok := c.conn.(*net.TCPConn) 220 if !ok { 221 err = errors.New("unsupported conn types") 222 return 223 } 224 225 rwc, err = conn.SyscallConn() 226 if err != nil { 227 return 228 } 229 } 230 231 if c.rawInput.Len() < maxPlaintext { 232 c.rawInput.Grow(maxPlaintext - c.rawInput.Len()) 233 } 234 235 var n int 236 data = c.rawInput.Bytes()[:maxPlaintext] 237 238 // cmsg for record type 239 buffer := make([]byte, unix.CmsgSpace(1)) 240 cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0])) 241 cmsg.SetLen(unix.CmsgLen(1)) 242 243 var iov unix.Iovec 244 iov.Base = &data[0] 245 iov.SetLen(len(data)) 246 247 var msg unix.Msghdr 248 msg.Control = &buffer[0] 249 msg.Controllen = cmsg.Len 250 msg.Iov = &iov 251 msg.Iovlen = 1 252 253 er := rwc.Read(func(fd uintptr) bool { 254 flags := 0 255 256 n, err = recvmsg(fd, &msg, flags) 257 if err == unix.EAGAIN { 258 // data is not ready, goroutine will be parked 259 return false 260 } 261 262 // n should not be zero when err == nil 263 if err == nil && n == 0 { 264 err = io.EOF 265 } 266 return true 267 }) 268 if er != nil { 269 err = er 270 } 271 272 if err != nil { 273 return 274 } 275 276 if n <= 0 { 277 data = nil 278 return 279 } 280 281 if cmsg.Level != unix.SOL_TLS { 282 err = fmt.Errorf("unsupported cmsg level: %d", cmsg.Level) 283 return 284 } 285 286 if cmsg.Type != TLS_GET_RECORD_TYPE { 287 err = fmt.Errorf("unsupported cmsg type: %d", cmsg.Type) 288 return 289 } 290 291 typ = recordType(buffer[unix.SizeofCmsghdr]) 292 data = data[:n] 293 return 294 } 295 296 func (c *Conn) writeKernelRecord(typ recordType, data []byte) (n int, err error) { 297 if typ == recordTypeApplicationData { 298 return c.write(data) 299 } 300 301 var rwc syscall.RawConn 302 { 303 conn, ok := c.conn.(*net.TCPConn) 304 if !ok { 305 err = errors.New("unsupported conn types") 306 return 307 } 308 309 rwc, err = conn.SyscallConn() 310 if err != nil { 311 return 312 } 313 } 314 315 // cmsg for record type 316 buffer := make([]byte, unix.CmsgSpace(1)) 317 cmsg := (*unix.Cmsghdr)(unsafe.Pointer(&buffer[0])) 318 cmsg.SetLen(unix.CmsgLen(1)) 319 buffer[unix.SizeofCmsghdr] = byte(typ) 320 cmsg.Level = unix.SOL_TLS 321 cmsg.Type = TLS_SET_RECORD_TYPE 322 323 var iov unix.Iovec 324 iov.Base = &data[0] 325 iov.SetLen(len(data)) 326 327 var msg unix.Msghdr 328 msg.Control = &buffer[0] 329 msg.Controllen = cmsg.Len 330 msg.Iov = &iov 331 msg.Iovlen = 1 332 333 ew := rwc.Write(func(fd uintptr) bool { 334 flags := 0 335 336 n, err = sendmsg(fd, &msg, flags) 337 return err != unix.EAGAIN 338 }) 339 if ew != nil { 340 err = ew 341 } 342 return 343 } 344 345 func (c *Conn) ReadFrom(r io.Reader) (n int64, err error) { 346 if err := c.Handshake(); err != nil { 347 return 0, err 348 } 349 350 if !c.out.usingKernel() { 351 return io.Copy(&wrappedConn{c}, r) 352 } 353 354 return io.Copy(c.conn, r) 355 } 356 357 func (c *Conn) WriteTo(w io.Writer) (n int64, err error) { 358 if err := c.Handshake(); err != nil { 359 return 0, err 360 } 361 362 if !c.in.usingKernel() { 363 return io.Copy(w, &wrappedConn{c}) 364 } 365 366 return io.Copy(w, c.conn) 367 } 368 369 func recvmsg(fd uintptr, msg *unix.Msghdr, flags int) (n int, err error) { 370 r0, _, e1 := unix.Syscall(unix.SYS_RECVMSG, fd, uintptr(unsafe.Pointer(msg)), uintptr(flags)) 371 n = int(r0) 372 err = errnoErr(e1) 373 return 374 } 375 376 func sendmsg(fd uintptr, msg *unix.Msghdr, flags int) (n int, err error) { 377 r0, _, e1 := unix.Syscall(unix.SYS_SENDMSG, fd, uintptr(unsafe.Pointer(msg)), uintptr(flags)) 378 n = int(r0) 379 err = errnoErr(e1) 380 return 381 } 382 383 // errnoErr returns common boxed Errno values, to prevent 384 // allocations at runtime. 385 func errnoErr(e unix.Errno) error { 386 switch e { 387 case 0: 388 return nil 389 case unix.EAGAIN: 390 return unix.EAGAIN 391 case unix.EINVAL: 392 return unix.EINVAL 393 case unix.ENOENT: 394 return unix.ENOENT 395 } 396 return e 397 }