github.com/diamondburned/arikawa/v2@v2.1.0/utils/wsutil/op.go (about) 1 package wsutil 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 8 "github.com/pkg/errors" 9 10 "github.com/diamondburned/arikawa/v2/internal/moreatomic" 11 "github.com/diamondburned/arikawa/v2/utils/json" 12 ) 13 14 var ErrEmptyPayload = errors.New("empty payload") 15 16 // OPCode is a generic type for websocket OP codes. 17 type OPCode uint8 18 19 type OP struct { 20 Code OPCode `json:"op"` 21 Data json.Raw `json:"d,omitempty"` 22 23 // Only for Gateway Dispatch (op 0) 24 Sequence int64 `json:"s,omitempty"` 25 EventName string `json:"t,omitempty"` 26 } 27 28 func (op *OP) UnmarshalData(v interface{}) error { 29 return json.Unmarshal(op.Data, v) 30 } 31 32 func DecodeOP(ev Event) (*OP, error) { 33 if ev.Error != nil { 34 return nil, ev.Error 35 } 36 37 if len(ev.Data) == 0 { 38 return nil, ErrEmptyPayload 39 } 40 41 var op *OP 42 if err := json.Unmarshal(ev.Data, &op); err != nil { 43 return nil, errors.Wrap(err, "OP error: "+string(ev.Data)) 44 } 45 46 return op, nil 47 } 48 49 func AssertEvent(ev Event, code OPCode, v interface{}) (*OP, error) { 50 op, err := DecodeOP(ev) 51 if err != nil { 52 return nil, err 53 } 54 55 if op.Code != code { 56 return op, fmt.Errorf( 57 "unexpected OP Code: %d, expected %d (%s)", 58 op.Code, code, op.Data, 59 ) 60 } 61 62 if err := json.Unmarshal(op.Data, v); err != nil { 63 return op, errors.Wrap(err, "failed to decode data") 64 } 65 66 return op, nil 67 } 68 69 type EventHandler interface { 70 HandleOP(op *OP) error 71 } 72 73 func HandleEvent(h EventHandler, ev Event) error { 74 o, err := DecodeOP(ev) 75 if err != nil { 76 return err 77 } 78 79 return h.HandleOP(o) 80 } 81 82 // WaitForEvent blocks until fn() returns true. All incoming events are handled 83 // regardless. 84 func WaitForEvent(ctx context.Context, h EventHandler, ch <-chan Event, fn func(*OP) bool) error { 85 for { 86 select { 87 case e, ok := <-ch: 88 if !ok { 89 return errors.New("event not found and event channel is closed") 90 } 91 92 o, err := DecodeOP(e) 93 if err != nil { 94 return err 95 } 96 97 // Handle the *OP first, in case it's an Invalid Session. This should 98 // also prevent a race condition with things that need Ready after 99 // Open(). 100 if err := h.HandleOP(o); err != nil { 101 return err 102 } 103 104 // Are these events what we're looking for? If we've found the event, 105 // return. 106 if fn(o) { 107 return nil 108 } 109 110 case <-ctx.Done(): 111 return ctx.Err() 112 } 113 } 114 } 115 116 type ExtraHandlers struct { 117 mutex sync.Mutex 118 handlers map[uint32]*ExtraHandler 119 serial uint32 120 } 121 122 type ExtraHandler struct { 123 Check func(*OP) bool 124 send chan *OP 125 126 closed moreatomic.Bool 127 } 128 129 func (ex *ExtraHandlers) Add(check func(*OP) bool) (<-chan *OP, func()) { 130 handler := &ExtraHandler{ 131 Check: check, 132 send: make(chan *OP), 133 } 134 135 ex.mutex.Lock() 136 defer ex.mutex.Unlock() 137 138 if ex.handlers == nil { 139 ex.handlers = make(map[uint32]*ExtraHandler, 1) 140 } 141 142 i := ex.serial 143 ex.serial++ 144 145 ex.handlers[i] = handler 146 147 return handler.send, func() { 148 // Check the atomic bool before acquiring the mutex. Might help a bit in 149 // performance. 150 if handler.closed.Get() { 151 return 152 } 153 154 ex.mutex.Lock() 155 defer ex.mutex.Unlock() 156 157 delete(ex.handlers, i) 158 } 159 } 160 161 // Check runs and sends OP data. It is not thread-safe. 162 func (ex *ExtraHandlers) Check(op *OP) { 163 ex.mutex.Lock() 164 defer ex.mutex.Unlock() 165 166 for i, handler := range ex.handlers { 167 if handler.Check(op) { 168 // Attempt to send. 169 handler.send <- op 170 171 // Mark the handler as closed. 172 handler.closed.Set(true) 173 174 // Delete the handler. 175 delete(ex.handlers, i) 176 } 177 } 178 }