github.com/diamondburned/arikawa/v2@v2.1.0/gateway/gateway.go (about)

     1  // Package gateway handles the Discord gateway (or Websocket) connection, its
     2  // events, and everything related to it. This includes logging into the
     3  // Websocket.
     4  //
     5  // This package does not abstract events and function handlers; instead, it
     6  // leaves that to the session package. This package exposes only a single Events
     7  // channel.
     8  package gateway
     9  
    10  import (
    11  	"context"
    12  	"net/http"
    13  	"net/url"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/diamondburned/arikawa/v2/api"
    19  	"github.com/diamondburned/arikawa/v2/discord"
    20  	"github.com/diamondburned/arikawa/v2/internal/moreatomic"
    21  	"github.com/diamondburned/arikawa/v2/utils/httputil"
    22  	"github.com/diamondburned/arikawa/v2/utils/json"
    23  	"github.com/diamondburned/arikawa/v2/utils/wsutil"
    24  	"github.com/gorilla/websocket"
    25  	"github.com/pkg/errors"
    26  )
    27  
// Endpoints and connection parameters used when locating and dialing the
// gateway.
var (
	// EndpointGateway is the REST endpoint requested by URL to obtain the
	// gateway address.
	EndpointGateway    = api.Endpoint + "gateway"
	// EndpointGatewayBot is the REST endpoint requested by BotURL. Note that
	// it is derived from api.EndpointGateway, not from EndpointGateway above.
	EndpointGatewayBot = api.EndpointGateway + "/bot"

	// Version is sent as the "v" query parameter of the gateway URL built by
	// NewIdentifiedGateway.
	Version  = api.Version
	// Encoding is sent as the "encoding" query parameter of the gateway URL
	// built by NewIdentifiedGateway.
	Encoding = "json"
)
    35  
// Sentinel errors of the gateway package.
var (
	// ErrMissingForResume presumably signals that a resume was attempted
	// without a session ID or sequence number; it is not used in this file.
	ErrMissingForResume = errors.New("missing session ID or sequence for resuming")
	// ErrWSMaxTries is passed to FatalErrorCallback when the context given to
	// ReconnectCtx expires before a connection could be established.
	ErrWSMaxTries       = errors.New(
		"could not connect to the Discord gateway before reaching the timeout")
	// ErrClosed is logged and returned by ReconnectCtx when the gateway was
	// closed gracefully while a reconnect was in progress.
	ErrClosed = errors.New("the gateway is closed and cannot reconnect")
)
    42  
// errCodeShardingRequired is the Websocket close code that makes the
// pacemaker callback clear the session ID and call OnScalingRequired instead
// of reconnecting.
//
// See
// https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes
const errCodeShardingRequired = 4011
    46  
// BotData contains the GatewayURL as well as extra metadata on how to
// shard bots. It is the JSON body decoded by both URL and BotURL.
type BotData struct {
	// URL is the Websocket URL of the gateway.
	URL        string             `json:"url"`
	// Shards is decoded from the "shards" field; presumably the shard count
	// suggested by Discord (not read anywhere in this file).
	Shards     int                `json:"shards,omitempty"`
	// StartLimit holds the identify limits. It may be nil; when present,
	// NewIdentifiedGateway uses it to configure the Identifier's limiters.
	StartLimit *SessionStartLimit `json:"session_start_limit"`
}
    54  
// SessionStartLimit is the information on the current session start limit. It's
// used in BotData.
type SessionStartLimit struct {
	// Total is the burst restored on the global identify limiter once
	// ResetAfter has elapsed.
	Total          int                  `json:"total"`
	// Remaining is applied as the current burst of the global identify
	// limiter.
	Remaining      int                  `json:"remaining"`
	// ResetAfter is the duration, counted from the time the data was fetched,
	// after which the limit goes back to Total.
	ResetAfter     discord.Milliseconds `json:"reset_after"`
	// MaxConcurrency is applied as the burst of the short identify limiter
	// (the number of identify requests allowed per 5s).
	MaxConcurrency int                  `json:"max_concurrency"`
}
    63  
    64  // URL asks Discord for a Websocket URL to the Gateway.
    65  func URL() (string, error) {
    66  	var g BotData
    67  
    68  	c := httputil.NewClient()
    69  	if err := c.RequestJSON(&g, "GET", EndpointGateway); err != nil {
    70  		return "", err
    71  	}
    72  
    73  	return g.URL, nil
    74  }
    75  
    76  // BotURL fetches the Gateway URL along with extra metadata. The token
    77  // passed in will NOT be prefixed with Bot.
    78  func BotURL(token string) (*BotData, error) {
    79  	var g *BotData
    80  
    81  	return g, httputil.NewClient().RequestJSON(
    82  		&g, "GET",
    83  		EndpointGatewayBot,
    84  		httputil.WithHeaders(http.Header{
    85  			"Authorization": {token},
    86  		}),
    87  	)
    88  }
    89  
// Gateway is a connection to the Discord gateway. It owns the Websocket, the
// pacemaker loop, and the session state (session ID and sequence) needed to
// identify or resume.
type Gateway struct {
	// WS is the underlying Websocket used to dial, listen, and send.
	WS *wsutil.Websocket

	// WSTimeout is a timeout for an arbitrary action. An example of this is the
	// timeout for Start and the timeout for sending each Gateway command
	// independently.
	WSTimeout time.Duration

	// ReconnectTimeout is the timeout used during reconnection.
	// If a connection to the gateway can't be established before the
	// duration passes, the Gateway will be closed and FatalErrorCallback will
	// be called.
	//
	// Setting this to 0 is equivalent to no timeout.
	//
	// Deprecated: It is recommended to use ReconnectAttempts instead.
	ReconnectTimeout time.Duration
	// ReconnectAttempts is the number of attempts made to Reconnect before
	// aborting. If this is set to 0, unlimited attempts will be made.
	ReconnectAttempts uint

	// All events sent over are pointers to Event structs (structs suffixed with
	// "Event"). This shouldn't be accessed if the Gateway is created with a
	// Session.
	Events chan Event

	// sessionMu guards sessionID.
	sessionMu sync.RWMutex
	// sessionID is the current session ID; it is empty when there is no
	// session to resume. Read it through SessionID.
	sessionID string

	// Identifier holds the token, the intents, and the identify limiters.
	Identifier *Identifier
	// Sequence is initialized to 0 by NewCustomIdentifiedGateway; presumably
	// it tracks the last received sequence number used for resuming.
	Sequence   *moreatomic.Int64

	// PacerLoop is the pacemaker (heartbeat) loop started after Hello is
	// received and stopped on close.
	PacerLoop wsutil.PacemakerLoop

	// ErrorLog is called with non-fatal errors. NewCustomIdentifiedGateway
	// sets it to wsutil.WSError.
	ErrorLog func(err error)
	// FatalErrorCallback is called, if the Gateway exits fatally. At the point
	// of calling, the gateway will be already closed.
	//
	// Currently this will only be called, if the ReconnectTimeout was changed
	// to a definite timeout, and connection could not be established during
	// that time.
	// err will be ErrWSMaxTries in that case.
	//
	// Assign a non-nil function if this callback may fire; callers should not
	// depend on a default being installed.
	FatalErrorCallback func(err error)

	// OnScalingRequired is the function called, if Discord closes with error
	// code 4011 aka Scaling Required. At the point of calling, the Gateway
	// will be closed, and can, after increasing the number of shards, be
	// reopened using Open. Reconnect or ReconnectCtx, however, will not be
	// available as the session is invalidated.
	//
	// Assign a non-nil function if this callback may fire; callers should not
	// depend on a default being installed.
	OnScalingRequired func()

	// AfterClose is called after each close or pause. It is used mainly for
	// reconnections or any type of connection interruptions.
	//
	// Constructors will use a no-op function by default.
	AfterClose func(err error)

	// waitGroup tracks the running pacemaker loop so that closing can wait for
	// it to exit.
	waitGroup sync.WaitGroup

	// closed is recreated by every StartCtx and closed by a graceful close.
	// ReconnectCtx selects on it to abort an ongoing reconnection.
	closed chan struct{}
}
   153  
   154  // NewGatewayWithIntents creates a new Gateway with the given intents and the
   155  // default stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents.
   156  func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) {
   157  	g, err := NewGateway(token)
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  
   162  	for _, intent := range intents {
   163  		g.AddIntents(intent)
   164  	}
   165  
   166  	return g, nil
   167  }
   168  
   169  // NewGateway creates a new Gateway to the default Discord server.
   170  func NewGateway(token string) (*Gateway, error) {
   171  	return NewIdentifiedGateway(DefaultIdentifier(token))
   172  }
   173  
   174  // NewIdentifiedGateway creates a new Gateway with the given gateway identifier
   175  // and the default everything. Sharded bots should prefer this function for the
   176  // shared identifier.
   177  func NewIdentifiedGateway(id *Identifier) (*Gateway, error) {
   178  	var gatewayURL string
   179  	var botData *BotData
   180  	var err error
   181  
   182  	if strings.HasPrefix(id.Token, "Bot ") {
   183  		botData, err = BotURL(id.Token)
   184  		if err != nil {
   185  			return nil, errors.Wrap(err, "failed to get bot data")
   186  		}
   187  		gatewayURL = botData.URL
   188  
   189  	} else {
   190  		gatewayURL, err = URL()
   191  		if err != nil {
   192  			return nil, errors.Wrap(err, "failed to get gateway endpoint")
   193  		}
   194  	}
   195  
   196  	// Parameters for the gateway
   197  	param := url.Values{
   198  		"v":        {Version},
   199  		"encoding": {Encoding},
   200  	}
   201  
   202  	// Append the form to the URL
   203  	gatewayURL += "?" + param.Encode()
   204  	gateway := NewCustomIdentifiedGateway(gatewayURL, id)
   205  
   206  	// Use the supplied connect rate limit, if any.
   207  	if botData != nil && botData.StartLimit != nil {
   208  		resetAt := time.Now().Add(botData.StartLimit.ResetAfter.Duration())
   209  		limiter := gateway.Identifier.IdentifyGlobalLimit
   210  
   211  		// Update the burst to be the current given time and reset it back to
   212  		// the default when the given time is reached.
   213  		limiter.SetBurst(botData.StartLimit.Remaining)
   214  		limiter.SetBurstAt(resetAt, botData.StartLimit.Total)
   215  
   216  		// Update the maximum number of identify requests allowed per 5s.
   217  		gateway.Identifier.IdentifyShortLimit.SetBurst(botData.StartLimit.MaxConcurrency)
   218  	}
   219  
   220  	return gateway, nil
   221  }
   222  
   223  // NewCustomGateway creates a new Gateway with a custom gateway URL and a new
   224  // Identifier. Most bots connecting to the official server should not use these
   225  // custom functions.
   226  func NewCustomGateway(gatewayURL, token string) *Gateway {
   227  	return NewCustomIdentifiedGateway(gatewayURL, DefaultIdentifier(token))
   228  }
   229  
   230  // NewCustomIdentifiedGateway creates a new Gateway with a custom gateway URL
   231  // and a pre-existing Identifier. Refer to NewCustomGateway.
   232  func NewCustomIdentifiedGateway(gatewayURL string, id *Identifier) *Gateway {
   233  	return &Gateway{
   234  		WS:        wsutil.NewCustom(wsutil.NewConn(), gatewayURL),
   235  		WSTimeout: wsutil.WSTimeout,
   236  
   237  		Events:     make(chan Event, wsutil.WSBuffer),
   238  		Identifier: id,
   239  		Sequence:   moreatomic.NewInt64(0),
   240  
   241  		ErrorLog:   wsutil.WSError,
   242  		AfterClose: func(error) {},
   243  	}
   244  }
   245  
   246  // AddIntents adds a Gateway Intent before connecting to the Gateway. As such,
   247  // this function will only work before Open() is called.
   248  func (g *Gateway) AddIntents(i Intents) {
   249  	g.Identifier.Intents |= i
   250  }
   251  
   252  // HasIntents reports if the Gateway has the passed Intents.
   253  //
   254  // If no intents are set, i.e. if using a user account HasIntents will always
   255  // return true.
   256  func (g *Gateway) HasIntents(intents Intents) bool {
   257  	if g.Identifier.Intents == 0 {
   258  		return true
   259  	}
   260  
   261  	return g.Identifier.Intents.Has(intents)
   262  }
   263  
   264  // Close closes the underlying Websocket connection, invalidating the session
   265  // ID. A new gateway connection can be established, by calling Open again.
   266  //
   267  // If the wsutil.Connection of the Gateway's WS implements
   268  // wsutil.GracefulCloser, such as the default one, Close will send a closing
   269  // frame before ending the connection, closing it gracefully. This will cause
   270  // the bot to appear as offline instantly.
   271  func (g *Gateway) Close() error {
   272  	return g.close(true)
   273  }
   274  
   275  // CloseGracefully attempts to close the gateway connection gracefully, by
   276  // sending a closing frame before ending the connection. This will cause the
   277  // gateway's session id to be rendered invalid.
   278  //
   279  // Note that a graceful closure is only possible, if the wsutil.Connection of
   280  // the Gateway's Websocket implements wsutil.GracefulCloser.
   281  //
   282  // Deprecated: Close behaves identically to CloseGracefully, and should be used
   283  // instead.
   284  func (g *Gateway) CloseGracefully() error {
   285  	return g.Close()
   286  }
   287  
   288  // Pause pauses the Gateway connection, by ending the connection without
   289  // sending a closing frame. This allows the connection to be resumed at a later
   290  // point, by calling Reconnect or ReconnectCtx.
   291  func (g *Gateway) Pause() error {
   292  	return g.close(false)
   293  }
   294  
// close shuts down the Websocket connection and waits for the pacemaker loop
// to exit. It is the shared implementation of Close (graceful) and Pause (not
// graceful).
//
// If graceful is true, the Websocket is closed using WS.CloseGracefully, the
// closed channel is closed to abort a reconnection in progress, and the
// session ID is cleared. Otherwise WS.Close is used and the session ID is
// kept, so that the session can be resumed later.
//
// If the Websocket reports wsutil.ErrWebsocketClosed, close returns nil right
// away and none of the remaining steps, AfterClose included, are run. In all
// other cases the error from closing the Websocket is handed to AfterClose and
// then returned.
func (g *Gateway) close(graceful bool) (err error) {
	wsutil.WSDebug("Trying to close. Pacemaker check skipped.")
	wsutil.WSDebug("Closing the Websocket...")

	if graceful {
		err = g.WS.CloseGracefully()
	} else {
		err = g.WS.Close()
	}

	// Nothing left to tear down if the Websocket was closed already.
	if errors.Is(err, wsutil.ErrWebsocketClosed) {
		wsutil.WSDebug("Websocket already closed.")
		return nil
	}

	// Explicitly signal the pacemaker loop to stop. We should do this in case
	// the Start function exited before it could bind the event channel into the
	// loop.
	g.PacerLoop.Stop()
	wsutil.WSDebug("Websocket closed; error:", err)

	// The pacemaker callback registered in start marks the wait group as done.
	wsutil.WSDebug("Waiting for the Pacemaker loop to exit.")
	g.waitGroup.Wait()
	wsutil.WSDebug("Pacemaker loop exited.")

	g.AfterClose(err)
	wsutil.WSDebug("AfterClose callback finished.")

	if graceful {
		// If a Reconnect is in progress, signal to cancel.
		close(g.closed)

		// Delete our session id, as we just invalidated it.
		g.sessionMu.Lock()
		g.sessionID = ""
		g.sessionMu.Unlock()
	}

	return err
}
   335  
   336  // SessionID returns the session ID received after Ready. This function is
   337  // concurrently safe.
   338  func (g *Gateway) SessionID() string {
   339  	g.sessionMu.RLock()
   340  	defer g.sessionMu.RUnlock()
   341  
   342  	return g.sessionID
   343  }
   344  
   345  // Reconnect tries to reconnect to the Gateway until the ReconnectAttempts or
   346  // ReconnectTimeout is reached.
   347  func (g *Gateway) Reconnect() {
   348  	ctx := context.Background()
   349  
   350  	if g.ReconnectTimeout > 0 {
   351  		var cancel func()
   352  		ctx, cancel = context.WithTimeout(context.Background(), g.ReconnectTimeout)
   353  
   354  		defer cancel()
   355  	}
   356  
   357  	g.ReconnectCtx(ctx)
   358  }
   359  
   360  // ReconnectCtx attempts to Reconnect until context expires.
   361  // If the context expires FatalErrorCallback will be called with ErrWSMaxTries,
   362  // and the last error returned by Open will be returned.
   363  func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) {
   364  	wsutil.WSDebug("Reconnecting...")
   365  
   366  	// Guarantee the gateway is already closed. Ignore its error, as we're
   367  	// redialing anyway.
   368  	g.Pause()
   369  
   370  	for try := uint(1); g.ReconnectAttempts == 0 || g.ReconnectAttempts >= try; try++ {
   371  		select {
   372  		case <-g.closed:
   373  			g.ErrorLog(ErrClosed)
   374  			return ErrClosed
   375  		case <-ctx.Done():
   376  			wsutil.WSDebug("Unable to Reconnect after", try, "attempts, aborting")
   377  			g.FatalErrorCallback(ErrWSMaxTries)
   378  			return err
   379  		default:
   380  		}
   381  
   382  		wsutil.WSDebug("Trying to dial, attempt", try)
   383  
   384  		// if we encounter an error, make sure we return it, and not nil
   385  		if oerr := g.OpenContext(ctx); oerr != nil {
   386  			err = oerr
   387  			g.ErrorLog(oerr)
   388  
   389  			wait := time.Duration(4+2*try) * time.Second
   390  			if wait > 60*time.Second {
   391  				wait = 60 * time.Second
   392  			}
   393  
   394  			time.Sleep(wait)
   395  			continue
   396  		}
   397  
   398  		wsutil.WSDebug("Started after attempt:", try)
   399  		return nil
   400  	}
   401  
   402  	wsutil.WSDebug("Unable to Reconnect after", g.ReconnectAttempts, "attempts, aborting")
   403  	return err
   404  }
   405  
   406  // Open connects to the Websocket and authenticate it. You should usually use
   407  // this function over Start().
   408  func (g *Gateway) Open() error {
   409  	ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
   410  	defer cancel()
   411  
   412  	return g.OpenContext(ctx)
   413  }
   414  
   415  // OpenContext connects to the Websocket and authenticates it. You should
   416  // usually use this function over Start(). The given context provides
   417  // cancellation and timeout.
   418  func (g *Gateway) OpenContext(ctx context.Context) error {
   419  	// Reconnect to the Gateway
   420  	if err := g.WS.Dial(ctx); err != nil {
   421  		return errors.Wrap(err, "failed to Reconnect")
   422  	}
   423  
   424  	wsutil.WSDebug("Trying to start...")
   425  
   426  	// Try to resume the connection
   427  	if err := g.StartCtx(ctx); err != nil {
   428  		return err
   429  	}
   430  
   431  	// Started successfully, return
   432  	return nil
   433  }
   434  
   435  // Start calls StartCtx with a background context. You wouldn't usually use this
   436  // function, but Open() instead.
   437  func (g *Gateway) Start() error {
   438  	ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
   439  	defer cancel()
   440  
   441  	return g.StartCtx(ctx)
   442  }
   443  
   444  // StartCtx authenticates with the websocket, or resume from a dead Websocket
   445  // connection. You wouldn't usually use this function, but OpenCtx() instead.
   446  func (g *Gateway) StartCtx(ctx context.Context) error {
   447  	g.closed = make(chan struct{})
   448  
   449  	if err := g.start(ctx); err != nil {
   450  		wsutil.WSDebug("Start failed:", err)
   451  
   452  		// Close can be called with the mutex still acquired here, as the
   453  		// pacemaker hasn't started yet.
   454  		if err := g.Close(); err != nil {
   455  			wsutil.WSDebug("Failed to close after start fail:", err)
   456  		}
   457  		return err
   458  	}
   459  
   460  	return nil
   461  }
   462  
// start performs the gateway handshake on an already dialed Websocket: it
// waits for Hello, starts the pacemaker loop, sends Identify or Resume
// depending on whether a session ID is stored, and waits for READY or RESUMED
// before binding the event channel to the pacemaker loop.
//
// It does not close the connection on failure; StartCtx does that.
func (g *Gateway) start(ctx context.Context) error {
	// This is where we'll get our events
	ch := g.WS.Listen()

	// Create a new Hello event and wait for it.
	var hello HelloEvent
	// Wait for an OP 10 Hello.
	select {
	case e, ok := <-ch:
		if !ok {
			return errors.New("unexpected ws close while waiting for Hello")
		}
		if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
			return errors.Wrap(err, "error at Hello")
		}
	case <-ctx.Done():
		return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
	}

	wsutil.WSDebug("Hello received; duration:", hello.HeartbeatInterval)

	// Start the event handler, which also handles the pacemaker death signal.
	// The wait group is what close waits on; the callback below marks it done.
	g.waitGroup.Add(1)

	// Use the pacemaker loop. The callback receives the error the loop
	// stopped with and decides whether to reconnect.
	g.PacerLoop.StartBeating(hello.HeartbeatInterval.Duration(), g, func(err error) {
		g.waitGroup.Done() // mark so Close() can exit.
		wsutil.WSDebug("Event loop stopped with error:", err)

		// If Discord signals us sharding is required, do not attempt to
		// Reconnect. Instead invalidate our session id, as we cannot resume,
		// call OnScalingRequired, and exit.
		var cerr *websocket.CloseError
		if errors.As(err, &cerr) && cerr != nil && cerr.Code == errCodeShardingRequired {
			g.ErrorLog(cerr)

			g.sessionMu.Lock()
			g.sessionID = ""
			g.sessionMu.Unlock()

			g.OnScalingRequired()
			return
		}

		// Bail if there is no error or if the error is an explicit close, as
		// there might be an ongoing reconnection.
		if err == nil || errors.Is(err, wsutil.ErrWebsocketClosed) {
			return
		}

		// Only attempt to Reconnect if we have a session ID at all. We may not
		// have one if we haven't even connected successfully once.
		if g.SessionID() != "" {
			g.ErrorLog(err)
			g.Reconnect()
		}
	})

	// Send Discord either the Identify packet (if it's a fresh connection), or
	// a Resume packet (if it's a dead connection).
	if g.SessionID() == "" {
		// SessionID is empty, so this is a completely new session.
		if err := g.IdentifyCtx(ctx); err != nil {
			return errors.Wrap(err, "failed to identify")
		}
	} else {
		if err := g.ResumeCtx(ctx); err != nil {
			return errors.Wrap(err, "failed to resume")
		}
	}

	// Expect either READY or RESUMED before continuing.
	wsutil.WSDebug("Waiting for either READY or RESUMED.")

	// The callback reports true for the event that ends the wait. Presumably
	// WaitForEvent consumes ch until then or until ctx is done; confirm
	// against wsutil.
	err := wsutil.WaitForEvent(ctx, g, ch, func(op *wsutil.OP) bool {
		switch op.EventName {
		case "READY":
			wsutil.WSDebug("Found READY event.")
			return true
		case "RESUMED":
			wsutil.WSDebug("Found RESUMED event.")
			return true
		}
		return false
	})

	if err != nil {
		return errors.Wrap(err, "first error")
	}

	// Bind the event channel to the pacemaker loop.
	g.PacerLoop.SetEventChannel(ch)

	wsutil.WSDebug("Started successfully.")

	return nil
}
   561  
   562  // SendCtx is a low-level function to send an OP payload to the Gateway. Most
   563  // users shouldn't touch this, unless they know what they're doing.
   564  func (g *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
   565  	var op = wsutil.OP{
   566  		Code: code,
   567  	}
   568  
   569  	if v != nil {
   570  		b, err := json.Marshal(v)
   571  		if err != nil {
   572  			return errors.Wrap(err, "failed to encode v")
   573  		}
   574  
   575  		op.Data = b
   576  	}
   577  
   578  	b, err := json.Marshal(op)
   579  	if err != nil {
   580  		return errors.Wrap(err, "failed to encode payload")
   581  	}
   582  
   583  	// WS should already be thread-safe.
   584  	return g.WS.SendCtx(ctx, b)
   585  }