github.com/iceber/iouring-go@v0.0.0-20230403020409-002cfd2e2a90/request.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package iouring
     5  
     6  import (
     7  	"errors"
     8  	"sync"
     9  	"sync/atomic"
    10  	"syscall"
    11  
    12  	iouring_syscall "github.com/iceber/iouring-go/syscall"
    13  )
    14  
    15  type ResultResolver func(req Request)
    16  
    17  type RequestCallback func(result Result) error
    18  
    19  type Request interface {
    20  	Result
    21  
    22  	Cancel() (Request, error)
    23  	Done() <-chan struct{}
    24  
    25  	GetRes() (int, error)
    26  	// Can Only be used in ResultResolver
    27  	SetResult(r0, r1 interface{}, err error) error
    28  }
    29  
    30  type Result interface {
    31  	Fd() int
    32  	Opcode() uint8
    33  	GetRequestBuffer() (b0, b1 []byte)
    34  	GetRequestBuffers() [][]byte
    35  	GetRequestInfo() interface{}
    36  	FreeRequestBuffer()
    37  
    38  	Err() error
    39  	ReturnValue0() interface{}
    40  	ReturnValue1() interface{}
    41  	ReturnExtra1() uint64
    42  	ReturnExtra2() uint64
    43  	ReturnFd() (int, error)
    44  	ReturnInt() (int, error)
    45  
    46  	Callback() error
    47  }
    48  
    49  var _ Request = &request{}
    50  
    51  type request struct {
    52  	iour *IOURing
    53  
    54  	id     uint64
    55  	opcode uint8
    56  	res    int32
    57  
    58  	once      sync.Once
    59  	resolving bool
    60  	resolver  ResultResolver
    61  
    62  	callback RequestCallback
    63  
    64  	fd int
    65  	b0 []byte
    66  	b1 []byte
    67  	bs [][]byte
    68  
    69  	err  error
    70  	r0   interface{}
    71  	r1   interface{}
    72  	ext1 uint64
    73  	ext2 uint64
    74  
    75  	requestInfo interface{}
    76  
    77  	set  *requestSet
    78  	done chan struct{}
    79  }
    80  
    81  func (req *request) resolve() {
    82  	if req.resolver == nil {
    83  		return
    84  	}
    85  
    86  	select {
    87  	case <-req.done:
    88  	default:
    89  		return
    90  	}
    91  
    92  	req.once.Do(func() {
    93  		req.resolving = true
    94  		req.resolver(req)
    95  		req.resolving = false
    96  
    97  		req.resolver = nil
    98  	})
    99  }
   100  
   101  func (req *request) complate(cqe iouring_syscall.CompletionQueueEvent) {
   102  	req.res = cqe.Result()
   103  	req.ext1 = cqe.Extra1()
   104  	req.ext2 = cqe.Extra2()
   105  	req.iour = nil
   106  	close(req.done)
   107  
   108  	if req.set != nil {
   109  		req.set.complateOne()
   110  		req.set = nil
   111  	}
   112  }
   113  
   114  func (req *request) isDone() bool {
   115  	select {
   116  	case <-req.done:
   117  		return true
   118  	default:
   119  	}
   120  	return false
   121  }
   122  
   123  func (req *request) Callback() error {
   124  	if !req.isDone() {
   125  		return ErrRequestCompleted
   126  	}
   127  
   128  	if req.callback == nil {
   129  		return ErrNoRequestCallback
   130  	}
   131  
   132  	return req.callback(req)
   133  }
   134  
   135  // Cancel request if request is not completed
   136  func (req *request) Cancel() (Request, error) {
   137  	if req.isDone() {
   138  		return nil, ErrRequestCompleted
   139  	}
   140  
   141  	return req.iour.submitCancel(req.id)
   142  }
   143  
   144  func (req *request) Done() <-chan struct{} {
   145  	return req.done
   146  }
   147  
   148  func (req *request) Opcode() uint8 {
   149  	return req.opcode
   150  }
   151  
   152  func (req *request) Fd() int {
   153  	return req.fd
   154  }
   155  
   156  func (req *request) GetRequestBuffer() (b0, b1 []byte) {
   157  	return req.b0, req.b1
   158  }
   159  
   160  func (req *request) GetRequestBuffers() [][]byte {
   161  	return req.bs
   162  }
   163  
   164  func (req *request) GetRequestInfo() interface{} {
   165  	return req.requestInfo
   166  }
   167  
   168  func (req *request) Err() error {
   169  	req.resolve()
   170  	return req.err
   171  }
   172  
   173  func (req *request) ReturnValue0() interface{} {
   174  	req.resolve()
   175  	return req.r0
   176  }
   177  
   178  func (req *request) ReturnValue1() interface{} {
   179  	req.resolve()
   180  	return req.r1
   181  }
   182  
   183  func (req *request) ReturnExtra1() uint64 {
   184  	return req.ext1
   185  }
   186  
   187  func (req *request) ReturnExtra2() uint64 {
   188  	return req.ext2
   189  }
   190  
   191  func (req *request) ReturnFd() (int, error) {
   192  	return req.ReturnInt()
   193  }
   194  
   195  func (req *request) ReturnInt() (int, error) {
   196  	req.resolve()
   197  
   198  	if req.err != nil {
   199  		return -1, req.err
   200  	}
   201  
   202  	fd, ok := req.r0.(int)
   203  	if !ok {
   204  		return -1, errors.New("req value is not int")
   205  	}
   206  
   207  	return fd, nil
   208  }
   209  
   210  func (req *request) FreeRequestBuffer() {
   211  	req.b0 = nil
   212  	req.b1 = nil
   213  	req.bs = nil
   214  }
   215  
   216  func (req *request) GetRes() (int, error) {
   217  	if !req.isDone() {
   218  		return 0, ErrRequestNotCompleted
   219  	}
   220  
   221  	return int(req.res), nil
   222  }
   223  
   224  func (req *request) SetResult(r0, r1 interface{}, err error) error {
   225  	if !req.isDone() {
   226  		return ErrRequestNotCompleted
   227  	}
   228  	if !req.resolving {
   229  		return errors.New("request is not resolving")
   230  	}
   231  
   232  	req.r0, req.r1, req.err = r0, r1, err
   233  	return nil
   234  }
   235  
   236  type RequestSet interface {
   237  	Len() int
   238  	Done() <-chan struct{}
   239  	Requests() []Request
   240  	ErrResults() []Result
   241  }
   242  
   243  var _ RequestSet = &requestSet{}
   244  
   245  type requestSet struct {
   246  	requests []Request
   247  
   248  	complates int32
   249  	done      chan struct{}
   250  }
   251  
   252  func newRequestSet(userData []*UserData) *requestSet {
   253  	set := &requestSet{
   254  		requests: make([]Request, len(userData)),
   255  		done:     make(chan struct{}),
   256  	}
   257  
   258  	for i, data := range userData {
   259  		set.requests[i] = data.request
   260  		data.request.set = set
   261  	}
   262  	return set
   263  }
   264  
   265  func (set *requestSet) complateOne() {
   266  	if atomic.AddInt32(&set.complates, 1) == int32(len(set.requests)) {
   267  		close(set.done)
   268  	}
   269  }
   270  
   271  func (set *requestSet) Len() int {
   272  	return len(set.requests)
   273  }
   274  
   275  func (set *requestSet) Done() <-chan struct{} {
   276  	return set.done
   277  }
   278  
   279  func (set *requestSet) Requests() []Request {
   280  	return set.requests
   281  }
   282  
   283  func (set *requestSet) ErrResults() (results []Result) {
   284  	for _, req := range set.requests {
   285  		if req.Err() != nil {
   286  			results = append(results, req)
   287  		}
   288  	}
   289  	return
   290  }
   291  
   292  func errResolver(req Request) {
   293  	result := req.(*request)
   294  	if result.res < 0 {
   295  		result.err = syscall.Errno(-result.res)
   296  		if result.err == syscall.ECANCELED {
   297  			// request is canceled
   298  			result.err = ErrRequestCanceled
   299  		}
   300  	}
   301  }
   302  
   303  func fdResolver(req Request) {
   304  	result := req.(*request)
   305  	if errResolver(result); result.err != nil {
   306  		return
   307  	}
   308  	result.r0 = int(result.res)
   309  }
   310  
   311  func timeoutResolver(req Request) {
   312  	result := req.(*request)
   313  	if errResolver(result); result.err != nil {
   314  		// if timeout got completed through expiration of the timer
   315  		// result.res is -ETIME and result.err is syscall.ETIME
   316  		if result.err == syscall.ETIME {
   317  			result.err = nil
   318  			result.r0 = TimeoutExpiration
   319  		}
   320  		return
   321  	}
   322  
   323  	// if timeout got completed through requests completing
   324  	// result.res is 0
   325  	if result.res == 0 {
   326  		result.r0 = CountCompletion
   327  	}
   328  }
   329  
   330  func removeTimeoutResolver(req Request) {
   331  	result := req.(*request)
   332  	if errResolver(result); result.err != nil {
   333  		switch result.err {
   334  		case syscall.EBUSY:
   335  			// timeout request was found bu expiration was already in progress
   336  			result.err = ErrRequestCompleted
   337  		case syscall.ENOENT:
   338  			// timeout request not found
   339  			result.err = ErrRequestNotFound
   340  		}
   341  		return
   342  	}
   343  
   344  	// timeout request is found and cacelled successfully
   345  	// result.res value is 0
   346  }
   347  
   348  func cancelResolver(req Request) {
   349  	result := req.(*request)
   350  	if errResolver(result); result.err != nil {
   351  		switch result.err {
   352  		case syscall.ENOENT:
   353  			result.err = ErrRequestNotFound
   354  		case syscall.EALREADY:
   355  			result.err = nil
   356  			result.r0 = RequestMaybeCanceled
   357  		}
   358  		return
   359  	}
   360  
   361  	if result.res == 0 {
   362  		result.r0 = RequestCanceledSuccessfully
   363  	}
   364  }