github.com/diamondburned/arikawa/v2@v2.1.0/utils/wsutil/op.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"github.com/pkg/errors"
     9  
    10  	"github.com/diamondburned/arikawa/v2/internal/moreatomic"
    11  	"github.com/diamondburned/arikawa/v2/utils/json"
    12  )
    13  
    14  var ErrEmptyPayload = errors.New("empty payload")
    15  
// OPCode is a generic type for websocket OP codes.
type OPCode uint8

// OP is a single websocket payload as exchanged on the wire, serialized with
// the JSON keys "op", "d", "s" and "t".
type OP struct {
	// Code is the payload's OP code (JSON key "op").
	Code OPCode   `json:"op"`
	// Data holds the payload body still in encoded form (JSON key "d"); decode
	// it with UnmarshalData. The omitempty tag drops it when empty.
	Data json.Raw `json:"d,omitempty"`

	// Only for Gateway Dispatch (op 0): the sequence value (JSON key "s") and
	// the event name (JSON key "t"). Both are omitted when zero/empty.
	Sequence  int64  `json:"s,omitempty"`
	EventName string `json:"t,omitempty"`
}
    27  
    28  func (op *OP) UnmarshalData(v interface{}) error {
    29  	return json.Unmarshal(op.Data, v)
    30  }
    31  
    32  func DecodeOP(ev Event) (*OP, error) {
    33  	if ev.Error != nil {
    34  		return nil, ev.Error
    35  	}
    36  
    37  	if len(ev.Data) == 0 {
    38  		return nil, ErrEmptyPayload
    39  	}
    40  
    41  	var op *OP
    42  	if err := json.Unmarshal(ev.Data, &op); err != nil {
    43  		return nil, errors.Wrap(err, "OP error: "+string(ev.Data))
    44  	}
    45  
    46  	return op, nil
    47  }
    48  
    49  func AssertEvent(ev Event, code OPCode, v interface{}) (*OP, error) {
    50  	op, err := DecodeOP(ev)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	if op.Code != code {
    56  		return op, fmt.Errorf(
    57  			"unexpected OP Code: %d, expected %d (%s)",
    58  			op.Code, code, op.Data,
    59  		)
    60  	}
    61  
    62  	if err := json.Unmarshal(op.Data, v); err != nil {
    63  		return op, errors.Wrap(err, "failed to decode data")
    64  	}
    65  
    66  	return op, nil
    67  }
    68  
// EventHandler consumes decoded OPs. HandleEvent and WaitForEvent pass every
// OP they decode to it.
type EventHandler interface {
	// HandleOP processes one decoded OP. A non-nil error is returned as-is by
	// HandleEvent and WaitForEvent.
	HandleOP(op *OP) error
}
    72  
    73  func HandleEvent(h EventHandler, ev Event) error {
    74  	o, err := DecodeOP(ev)
    75  	if err != nil {
    76  		return err
    77  	}
    78  
    79  	return h.HandleOP(o)
    80  }
    81  
    82  // WaitForEvent blocks until fn() returns true. All incoming events are handled
    83  // regardless.
    84  func WaitForEvent(ctx context.Context, h EventHandler, ch <-chan Event, fn func(*OP) bool) error {
    85  	for {
    86  		select {
    87  		case e, ok := <-ch:
    88  			if !ok {
    89  				return errors.New("event not found and event channel is closed")
    90  			}
    91  
    92  			o, err := DecodeOP(e)
    93  			if err != nil {
    94  				return err
    95  			}
    96  
    97  			// Handle the *OP first, in case it's an Invalid Session. This should
    98  			// also prevent a race condition with things that need Ready after
    99  			// Open().
   100  			if err := h.HandleOP(o); err != nil {
   101  				return err
   102  			}
   103  
   104  			// Are these events what we're looking for? If we've found the event,
   105  			// return.
   106  			if fn(o) {
   107  				return nil
   108  			}
   109  
   110  		case <-ctx.Done():
   111  			return ctx.Err()
   112  		}
   113  	}
   114  }
   115  
// ExtraHandlers is a registry of one-shot OP handlers: handlers are registered
// with Add and fired by Check. The zero value is ready to use, because Add
// creates the handlers map on first use.
type ExtraHandlers struct {
	mutex    sync.Mutex               // guards handlers and serial
	handlers map[uint32]*ExtraHandler // registered handlers, keyed by serial
	serial   uint32                   // key assigned to the next Add call
}
   121  
// ExtraHandler is a single registration inside ExtraHandlers.
type ExtraHandler struct {
	// Check reports whether an OP is the one this handler is waiting for.
	Check func(*OP) bool
	// send carries the matching OP to the channel returned by Add.
	send  chan *OP

	// closed marks the handler as already unregistered. Check sets it when the
	// handler fires; the cancel func returned by Add reads it so it can skip
	// taking the mutex.
	closed moreatomic.Bool
}
   128  
   129  func (ex *ExtraHandlers) Add(check func(*OP) bool) (<-chan *OP, func()) {
   130  	handler := &ExtraHandler{
   131  		Check: check,
   132  		send:  make(chan *OP),
   133  	}
   134  
   135  	ex.mutex.Lock()
   136  	defer ex.mutex.Unlock()
   137  
   138  	if ex.handlers == nil {
   139  		ex.handlers = make(map[uint32]*ExtraHandler, 1)
   140  	}
   141  
   142  	i := ex.serial
   143  	ex.serial++
   144  
   145  	ex.handlers[i] = handler
   146  
   147  	return handler.send, func() {
   148  		// Check the atomic bool before acquiring the mutex. Might help a bit in
   149  		// performance.
   150  		if handler.closed.Get() {
   151  			return
   152  		}
   153  
   154  		ex.mutex.Lock()
   155  		defer ex.mutex.Unlock()
   156  
   157  		delete(ex.handlers, i)
   158  	}
   159  }
   160  
   161  // Check runs and sends OP data. It is not thread-safe.
   162  func (ex *ExtraHandlers) Check(op *OP) {
   163  	ex.mutex.Lock()
   164  	defer ex.mutex.Unlock()
   165  
   166  	for i, handler := range ex.handlers {
   167  		if handler.Check(op) {
   168  			// Attempt to send.
   169  			handler.send <- op
   170  
   171  			// Mark the handler as closed.
   172  			handler.closed.Set(true)
   173  
   174  			// Delete the handler.
   175  			delete(ex.handlers, i)
   176  		}
   177  	}
   178  }