github.com/diamondburned/arikawa@v1.3.14/gateway/gateway.go (about)

     1  // Package gateway handles the Discord gateway (or Websocket) connection, its
     2  // events, and everything related to it. This includes logging into the
     3  // Websocket.
     4  //
     5  // This package does not abstract events and function handlers; instead, it
     6  // leaves that to the session package. This package exposes only a single Events
     7  // channel.
     8  package gateway
     9  
    10  import (
    11  	"context"
    12  	"net/http"
    13  	"net/url"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/diamondburned/arikawa/api"
    18  	"github.com/diamondburned/arikawa/discord"
    19  	"github.com/diamondburned/arikawa/utils/httputil"
    20  	"github.com/diamondburned/arikawa/utils/json"
    21  	"github.com/diamondburned/arikawa/utils/wsutil"
    22  	"github.com/pkg/errors"
    23  )
    24  
    25  var (
    26  	EndpointGateway    = api.Endpoint + "gateway"
    27  	EndpointGatewayBot = api.EndpointGateway + "/bot"
    28  
    29  	Version  = "6"
    30  	Encoding = "json"
    31  	// Compress = "zlib-stream"
    32  )
    33  
    34  var (
    35  	ErrMissingForResume = errors.New("missing session ID or sequence for resuming")
    36  	ErrWSMaxTries       = errors.New("max tries reached")
    37  )
    38  
    39  // BotData contains the GatewayURL as well as extra metadata on how to
    40  // shard bots.
    41  type BotData struct {
    42  	URL        string             `json:"url"`
    43  	Shards     int                `json:"shards,omitempty"`
    44  	StartLimit *SessionStartLimit `json:"session_start_limit"`
    45  }
    46  
    47  // SessionStartLimit is the information on the current session start limit. It's
    48  // used in BotData.
    49  type SessionStartLimit struct {
    50  	Total      int                  `json:"total"`
    51  	Remaining  int                  `json:"remaining"`
    52  	ResetAfter discord.Milliseconds `json:"reset_after"`
    53  }
    54  
    55  // URL asks Discord for a Websocket URL to the Gateway.
    56  func URL() (string, error) {
    57  	var g BotData
    58  
    59  	return g.URL, httputil.NewClient().RequestJSON(
    60  		&g, "GET",
    61  		EndpointGateway,
    62  	)
    63  }
    64  
    65  // BotURL fetches the Gateway URL along with extra metadata. The token
    66  // passed in will NOT be prefixed with Bot.
    67  func BotURL(token string) (*BotData, error) {
    68  	var g *BotData
    69  
    70  	return g, httputil.NewClient().RequestJSON(
    71  		&g, "GET",
    72  		EndpointGatewayBot,
    73  		httputil.WithHeaders(http.Header{
    74  			"Authorization": {token},
    75  		}),
    76  	)
    77  }
    78  
    79  type Gateway struct {
    80  	WS        *wsutil.Websocket
    81  	WSTimeout time.Duration
    82  
    83  	// All events sent over are pointers to Event structs (structs suffixed with
    84  	// "Event"). This shouldn't be accessed if the Gateway is created with a
    85  	// Session.
    86  	Events chan Event
    87  
    88  	// SessionID is used to store the session ID received after Ready. It is not
    89  	// thread-safe.
    90  	SessionID string
    91  
    92  	Identifier *Identifier
    93  	Sequence   *Sequence
    94  
    95  	PacerLoop wsutil.PacemakerLoop
    96  
    97  	ErrorLog func(err error) // default to log.Println
    98  
    99  	// AfterClose is called after each close. Error can be non-nil, as this is
   100  	// called even when the Gateway is gracefully closed. It's used mainly for
   101  	// reconnections or any type of connection interruptions.
   102  	AfterClose func(err error) // noop by default
   103  
   104  	waitGroup sync.WaitGroup
   105  }
   106  
   107  // NewGatewayWithIntents creates a new Gateway with the given intents and the
   108  // default stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents.
   109  func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) {
   110  	g, err := NewGateway(token)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	for _, intent := range intents {
   116  		g.AddIntent(intent)
   117  	}
   118  
   119  	return g, nil
   120  }
   121  
   122  // NewGateway creates a new Gateway with the default stdlib JSON driver. For
   123  // more information, refer to NewGatewayWithDriver.
   124  func NewGateway(token string) (*Gateway, error) {
   125  	URL, err := URL()
   126  	if err != nil {
   127  		return nil, errors.Wrap(err, "failed to get gateway endpoint")
   128  	}
   129  
   130  	// Parameters for the gateway
   131  	param := url.Values{
   132  		"v":        {Version},
   133  		"encoding": {Encoding},
   134  	}
   135  
   136  	// Append the form to the URL
   137  	URL += "?" + param.Encode()
   138  
   139  	return NewCustomGateway(URL, token), nil
   140  }
   141  
   142  func NewCustomGateway(gatewayURL, token string) *Gateway {
   143  	return &Gateway{
   144  		WS:        wsutil.NewCustom(wsutil.NewConn(), gatewayURL),
   145  		WSTimeout: wsutil.WSTimeout,
   146  
   147  		Events:     make(chan Event, wsutil.WSBuffer),
   148  		Identifier: DefaultIdentifier(token),
   149  		Sequence:   NewSequence(),
   150  
   151  		ErrorLog:   wsutil.WSError,
   152  		AfterClose: func(error) {},
   153  	}
   154  }
   155  
   156  // AddIntent adds a Gateway Intent before connecting to the Gateway. As
   157  // such, this function will only work before Open() is called.
   158  func (g *Gateway) AddIntent(i Intents) {
   159  	g.Identifier.Intents |= i
   160  }
   161  
   162  // Close closes the underlying Websocket connection.
   163  func (g *Gateway) Close() error {
   164  	wsutil.WSDebug("Trying to close. Pacemaker check skipped.")
   165  
   166  	wsutil.WSDebug("Closing the Websocket...")
   167  	err := g.WS.Close()
   168  
   169  	if errors.Is(err, wsutil.ErrWebsocketClosed) {
   170  		wsutil.WSDebug("Websocket already closed.")
   171  		return nil
   172  	}
   173  
   174  	wsutil.WSDebug("Websocket closed; error:", err)
   175  
   176  	wsutil.WSDebug("Waiting for the Pacemaker loop to exit.")
   177  	g.waitGroup.Wait()
   178  	wsutil.WSDebug("Pacemaker loop exited.")
   179  
   180  	g.AfterClose(err)
   181  	wsutil.WSDebug("AfterClose callback finished.")
   182  
   183  	return err
   184  }
   185  
   186  // Reconnect tries to reconnect forever. It will resume the connection if
   187  // possible. If an Invalid Session is received, it will start a fresh one.
   188  func (g *Gateway) Reconnect() {
   189  	for {
   190  		if err := g.ReconnectCtx(context.Background()); err != nil {
   191  			g.ErrorLog(err)
   192  		} else {
   193  			return
   194  		}
   195  	}
   196  }
   197  
   198  // ReconnectCtx attempts to reconnect until context expires. If context cannot
   199  // expire, then the gateway will try to reconnect forever.
   200  func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) {
   201  	wsutil.WSDebug("Reconnecting...")
   202  
   203  	// Guarantee the gateway is already closed. Ignore its error, as we're
   204  	// redialing anyway.
   205  	g.Close()
   206  
   207  	for i := 1; ; i++ {
   208  		select {
   209  		case <-ctx.Done():
   210  			return err
   211  		default:
   212  		}
   213  
   214  		wsutil.WSDebug("Trying to dial, attempt", i)
   215  
   216  		// Condition: err == ErrInvalidSession:
   217  		// If the connection is rate limited (documented behavior):
   218  		// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
   219  
   220  		// make sure we don't overwrite our last error
   221  		if err = g.OpenContext(ctx); err != nil {
   222  			g.ErrorLog(err)
   223  			continue
   224  		}
   225  
   226  		wsutil.WSDebug("Started after attempt:", i)
   227  
   228  		return
   229  	}
   230  }
   231  
   232  // Open connects to the Websocket and authenticate it. You should usually use
   233  // this function over Start().
   234  func (g *Gateway) Open() error {
   235  	ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
   236  	defer cancel()
   237  
   238  	return g.OpenContext(ctx)
   239  }
   240  
   241  // OpenContext connects to the Websocket and authenticates it. You should
   242  // usually use this function over Start(). The given context provides
   243  // cancellation and timeout.
   244  func (g *Gateway) OpenContext(ctx context.Context) error {
   245  	// Reconnect to the Gateway
   246  	if err := g.WS.Dial(ctx); err != nil {
   247  		return errors.Wrap(err, "failed to reconnect")
   248  	}
   249  
   250  	wsutil.WSDebug("Trying to start...")
   251  
   252  	// Try to resume the connection
   253  	if err := g.StartCtx(ctx); err != nil {
   254  		return err
   255  	}
   256  
   257  	// Started successfully, return
   258  	return nil
   259  }
   260  
   261  // Start calls StartCtx with a background context. You wouldn't usually use this
   262  // function, but Open() instead.
   263  func (g *Gateway) Start() error {
   264  	ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
   265  	defer cancel()
   266  
   267  	return g.StartCtx(ctx)
   268  }
   269  
   270  // StartCtx authenticates with the websocket, or resume from a dead Websocket
   271  // connection. You wouldn't usually use this function, but OpenCtx() instead.
   272  func (g *Gateway) StartCtx(ctx context.Context) error {
   273  	if err := g.start(ctx); err != nil {
   274  		wsutil.WSDebug("Start failed:", err)
   275  
   276  		// Close can be called with the mutex still acquired here, as the
   277  		// pacemaker hasn't started yet.
   278  		if err := g.Close(); err != nil {
   279  			wsutil.WSDebug("Failed to close after start fail:", err)
   280  		}
   281  		return err
   282  	}
   283  
   284  	return nil
   285  }
   286  
   287  func (g *Gateway) start(ctx context.Context) error {
   288  	// This is where we'll get our events
   289  	ch := g.WS.Listen()
   290  
   291  	// Create a new Hello event and wait for it.
   292  	var hello HelloEvent
   293  	// Wait for an OP 10 Hello.
   294  	select {
   295  	case e, ok := <-ch:
   296  		if !ok {
   297  			return errors.New("unexpected ws close while waiting for Hello")
   298  		}
   299  		if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
   300  			return errors.Wrap(err, "error at Hello")
   301  		}
   302  	case <-ctx.Done():
   303  		return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
   304  	}
   305  
   306  	wsutil.WSDebug("Hello received; duration:", hello.HeartbeatInterval)
   307  
   308  	// Send Discord either the Identify packet (if it's a fresh connection), or
   309  	// a Resume packet (if it's a dead connection).
   310  	if g.SessionID == "" {
   311  		// SessionID is empty, so this is a completely new session.
   312  		if err := g.IdentifyCtx(ctx); err != nil {
   313  			return errors.Wrap(err, "failed to identify")
   314  		}
   315  	} else {
   316  		if err := g.ResumeCtx(ctx); err != nil {
   317  			return errors.Wrap(err, "failed to resume")
   318  		}
   319  	}
   320  
   321  	// Expect either READY or RESUMED before continuing.
   322  	wsutil.WSDebug("Waiting for either READY or RESUMED.")
   323  
   324  	// WaitForEvent should
   325  	err := wsutil.WaitForEvent(ctx, g, ch, func(op *wsutil.OP) bool {
   326  		switch op.EventName {
   327  		case "READY":
   328  			wsutil.WSDebug("Found READY event.")
   329  			return true
   330  		case "RESUMED":
   331  			wsutil.WSDebug("Found RESUMED event.")
   332  			return true
   333  		}
   334  		return false
   335  	})
   336  
   337  	if err != nil {
   338  		return errors.Wrap(err, "first error")
   339  	}
   340  
   341  	// Start the event handler, which also handles the pacemaker death signal.
   342  	g.waitGroup.Add(1)
   343  
   344  	// Use the pacemaker loop.
   345  	g.PacerLoop.RunAsync(hello.HeartbeatInterval.Duration(), ch, g, func(err error) {
   346  		g.waitGroup.Done() // mark so Close() can exit.
   347  		wsutil.WSDebug("Event loop stopped with error:", err)
   348  
   349  		if err != nil {
   350  			g.ErrorLog(err)
   351  			g.Reconnect()
   352  		}
   353  	})
   354  
   355  	wsutil.WSDebug("Started successfully.")
   356  
   357  	return nil
   358  }
   359  
   360  // SendCtx is a low-level function to send an OP payload to the Gateway. Most
   361  // users shouldn't touch this, unless they know what they're doing.
   362  func (g *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
   363  	var op = wsutil.OP{
   364  		Code: code,
   365  	}
   366  
   367  	if v != nil {
   368  		b, err := json.Marshal(v)
   369  		if err != nil {
   370  			return errors.Wrap(err, "failed to encode v")
   371  		}
   372  
   373  		op.Data = b
   374  	}
   375  
   376  	b, err := json.Marshal(op)
   377  	if err != nil {
   378  		return errors.Wrap(err, "failed to encode payload")
   379  	}
   380  
   381  	// WS should already be thread-safe.
   382  	return g.WS.SendCtx(ctx, b)
   383  }