github.com/iceber/iouring-go@v0.0.0-20230403020409-002cfd2e2a90/mmap.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package iouring
     5  
     6  import (
     7  	"fmt"
     8  	"os"
     9  	"reflect"
    10  	"syscall"
    11  	"unsafe"
    12  
    13  	iouring_syscall "github.com/iceber/iouring-go/syscall"
    14  )
    15  
// uint32Size is the size in bytes of a uint32. It is used to compute the
// length of the submission queue index array, which holds one uint32 per
// SQ entry.
const uint32Size = uint32(unsafe.Sizeof(uint32(0)))
    17  
    18  func mmapIOURing(iour *IOURing) (err error) {
    19  	defer func() {
    20  		if err != nil {
    21  			munmapIOURing(iour)
    22  		}
    23  	}()
    24  	iour.sq = new(SubmissionQueue)
    25  	iour.cq = new(CompletionQueue)
    26  
    27  	if err = mmapSQ(iour); err != nil {
    28  		return err
    29  	}
    30  
    31  	if (iour.params.Features & iouring_syscall.IORING_FEAT_SINGLE_MMAP) != 0 {
    32  		iour.cq.ptr = iour.sq.ptr
    33  	}
    34  
    35  	if err = mmapCQ(iour); err != nil {
    36  		return err
    37  	}
    38  
    39  	if err = mmapSQEs(iour); err != nil {
    40  		return err
    41  	}
    42  	return nil
    43  }
    44  
    45  func mmapSQ(iour *IOURing) (err error) {
    46  	sq := iour.sq
    47  	params := iour.params
    48  
    49  	sq.size = params.SQOffset.Array + params.SQEntries*uint32Size
    50  	sq.ptr, err = mmap(iour.fd, sq.size, iouring_syscall.IORING_OFF_SQ_RING)
    51  	if err != nil {
    52  		return fmt.Errorf("mmap sq ring: %w", err)
    53  	}
    54  
    55  	sq.head = (*uint32)(unsafe.Pointer(sq.ptr + uintptr(params.SQOffset.Head)))
    56  	sq.tail = (*uint32)(unsafe.Pointer(sq.ptr + uintptr(params.SQOffset.Tail)))
    57  	sq.mask = (*uint32)(unsafe.Pointer(sq.ptr + uintptr(params.SQOffset.RingMask)))
    58  	sq.entries = (*uint32)(unsafe.Pointer(sq.ptr + uintptr(params.SQOffset.RingEntries)))
    59  	sq.flags = (*uint32)(unsafe.Pointer(sq.ptr + uintptr(params.SQOffset.Flags)))
    60  	sq.dropped = (*uint32)(unsafe.Pointer(sq.ptr + uintptr(params.SQOffset.Dropped)))
    61  
    62  	sq.array = *(*[]uint32)(unsafe.Pointer(&reflect.SliceHeader{
    63  		Data: sq.ptr + uintptr(params.SQOffset.Array),
    64  		Len:  int(params.SQEntries),
    65  		Cap:  int(params.SQEntries),
    66  	}))
    67  
    68  	return nil
    69  }
    70  
    71  func mmapCQ(iour *IOURing) (err error) {
    72  	params := iour.params
    73  	cq := iour.cq
    74  
    75  	cqes := makeCompletionQueueRing(params.Flags)
    76  
    77  	cq.size = params.CQOffset.Cqes + params.CQEntries*cqes.entrySz()
    78  	if cq.ptr == 0 {
    79  		cq.ptr, err = mmap(iour.fd, cq.size, iouring_syscall.IORING_OFF_CQ_RING)
    80  		if err != nil {
    81  			return fmt.Errorf("mmap cq ring: %w", err)
    82  		}
    83  	}
    84  
    85  	cq.head = (*uint32)(unsafe.Pointer(cq.ptr + uintptr(params.CQOffset.Head)))
    86  	cq.tail = (*uint32)(unsafe.Pointer(cq.ptr + uintptr(params.CQOffset.Tail)))
    87  	cq.mask = (*uint32)(unsafe.Pointer(cq.ptr + uintptr(params.CQOffset.RingMask)))
    88  	cq.entries = (*uint32)(unsafe.Pointer(cq.ptr + uintptr(params.CQOffset.RingEntries)))
    89  	cq.flags = (*uint32)(unsafe.Pointer(cq.ptr + uintptr(params.CQOffset.Flags)))
    90  	cq.overflow = (*uint32)(unsafe.Pointer(cq.ptr + uintptr(params.CQOffset.Overflow)))
    91  
    92  	cqes.assignQueue(cq.ptr+uintptr(params.CQOffset.Cqes), int(params.CQEntries))
    93  	cq.cqes = cqes
    94  
    95  	return nil
    96  }
    97  
    98  func mmapSQEs(iour *IOURing) error {
    99  	params := iour.params
   100  
   101  	sqes := makeSubmissionQueueRing(params.Flags)
   102  
   103  	ptr, err := mmap(iour.fd, params.SQEntries*sqes.entrySz(), iouring_syscall.IORING_OFF_SQES)
   104  	if err != nil {
   105  		return fmt.Errorf("mmap sqe array: %w", err)
   106  	}
   107  
   108  	sqes.assignQueue(ptr, int(params.SQEntries))
   109  	iour.sq.sqes = sqes
   110  
   111  	return nil
   112  }
   113  
   114  func munmapIOURing(iour *IOURing) error {
   115  	if iour.sq != nil && iour.sq.ptr != 0 {
   116  		if iour.sq.sqes.isActive() {
   117  			err := munmap(iour.sq.sqes.mappedPtr(), iour.sq.sqes.ringSz())
   118  			if err != nil {
   119  				return fmt.Errorf("ummap sqe array: %w", err)
   120  			}
   121  			iour.sq.sqes = nil
   122  		}
   123  
   124  		if err := munmap(iour.sq.ptr, iour.sq.size); err != nil {
   125  			return fmt.Errorf("munmap sq: %w", err)
   126  		}
   127  		if iour.sq.ptr == iour.cq.ptr {
   128  			iour.cq = nil
   129  		}
   130  		iour.sq = nil
   131  	}
   132  
   133  	if iour.cq != nil && iour.cq.ptr != 0 {
   134  		if err := munmap(iour.cq.ptr, iour.cq.size); err != nil {
   135  			return fmt.Errorf("munmap cq: %w", err)
   136  		}
   137  		iour.cq = nil
   138  	}
   139  
   140  	return nil
   141  }
   142  
   143  func mmap(fd int, length uint32, offset uint64) (uintptr, error) {
   144  	ptr, _, errno := syscall.Syscall6(
   145  		syscall.SYS_MMAP,
   146  		0,
   147  		uintptr(length),
   148  		syscall.PROT_READ|syscall.PROT_WRITE,
   149  		syscall.MAP_SHARED|syscall.MAP_POPULATE,
   150  		uintptr(fd),
   151  		uintptr(offset),
   152  	)
   153  	if errno != 0 {
   154  		return 0, os.NewSyscallError("mmap", errno)
   155  	}
   156  	return uintptr(ptr), nil
   157  }
   158  
   159  func munmap(ptr uintptr, length uint32) error {
   160  	_, _, errno := syscall.Syscall(
   161  		syscall.SYS_MUNMAP,
   162  		ptr,
   163  		uintptr(length),
   164  		0,
   165  	)
   166  	if errno != 0 {
   167  		return os.NewSyscallError("munmap", errno)
   168  	}
   169  	return nil
   170  }