github.com/diamondburned/arikawa/v2@v2.1.0/voice/session.go (about)

     1  package voice
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/diamondburned/arikawa/v2/state"
     9  	"github.com/diamondburned/arikawa/v2/utils/handler"
    10  
    11  	"github.com/pkg/errors"
    12  
    13  	"github.com/diamondburned/arikawa/v2/discord"
    14  	"github.com/diamondburned/arikawa/v2/gateway"
    15  	"github.com/diamondburned/arikawa/v2/internal/handleloop"
    16  	"github.com/diamondburned/arikawa/v2/internal/moreatomic"
    17  	"github.com/diamondburned/arikawa/v2/session"
    18  	"github.com/diamondburned/arikawa/v2/utils/wsutil"
    19  	"github.com/diamondburned/arikawa/v2/voice/udp"
    20  	"github.com/diamondburned/arikawa/v2/voice/voicegateway"
    21  )
    22  
    23  // Protocol is the encryption protocol that this library uses.
    24  const Protocol = "xsalsa20_poly1305"
    25  
    26  // ErrAlreadyConnecting is returned when the session is already connecting.
    27  var ErrAlreadyConnecting = errors.New("already connecting")
    28  
    29  // ErrCannotSend is an error when audio is sent to a closed channel.
    30  var ErrCannotSend = errors.New("cannot send audio to closed channel")
    31  
    32  // WSTimeout is the duration to wait for a gateway operation including Session
    33  // to complete before erroring out. This only applies to functions that don't
    34  // take in a context already.
    35  var WSTimeout = 10 * time.Second
    36  
    37  // Session is a single voice session that wraps around the voice gateway and UDP
    38  // connection.
    39  type Session struct {
    40  	*handler.Handler
    41  	ErrorLog func(err error)
    42  
    43  	session *session.Session
    44  	cancels []func()
    45  	looper  *handleloop.Loop
    46  
    47  	// joining determines the behavior of incoming event callbacks (Update).
    48  	// If this is true, incoming events will just send into Updated channels. If
    49  	// false, events will trigger a reconnection.
    50  	joining  moreatomic.Bool
    51  	incoming chan struct{} // used only when joining == true
    52  
    53  	mut sync.RWMutex
    54  
    55  	state voicegateway.State // guarded except UserID
    56  
    57  	// TODO: expose getters mutex-guarded.
    58  	gateway  *voicegateway.Gateway
    59  	voiceUDP *udp.Connection
    60  }
    61  
    62  // NewSession creates a new voice session for the current user.
    63  func NewSession(state *state.State) (*Session, error) {
    64  	u, err := state.Me()
    65  	if err != nil {
    66  		return nil, errors.Wrap(err, "failed to get me")
    67  	}
    68  
    69  	return NewSessionCustom(state.Session, u.ID), nil
    70  }
    71  
    72  // NewSessionCustom creates a new voice session from the given session and user
    73  // ID.
    74  func NewSessionCustom(ses *session.Session, userID discord.UserID) *Session {
    75  	handler := handler.New()
    76  	hlooper := handleloop.NewLoop(handler)
    77  	session := &Session{
    78  		Handler: handler,
    79  		looper:  hlooper,
    80  		session: ses,
    81  		state: voicegateway.State{
    82  			UserID: userID,
    83  		},
    84  		ErrorLog: func(err error) {},
    85  		incoming: make(chan struct{}, 2),
    86  	}
    87  	session.cancels = []func(){
    88  		ses.AddHandler(session.updateServer),
    89  		ses.AddHandler(session.updateState),
    90  	}
    91  
    92  	return session
    93  }
    94  
    95  func (s *Session) updateServer(ev *gateway.VoiceServerUpdateEvent) {
    96  	// If this is true, then mutex is acquired already.
    97  	if s.joining.Get() {
    98  		if s.state.GuildID != ev.GuildID {
    99  			return
   100  		}
   101  
   102  		s.state.Endpoint = ev.Endpoint
   103  		s.state.Token = ev.Token
   104  
   105  		s.incoming <- struct{}{}
   106  		return
   107  	}
   108  
   109  	s.mut.Lock()
   110  	defer s.mut.Unlock()
   111  
   112  	if s.state.GuildID != ev.GuildID {
   113  		return
   114  	}
   115  
   116  	// Reconnect.
   117  
   118  	s.state.Endpoint = ev.Endpoint
   119  	s.state.Token = ev.Token
   120  
   121  	ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
   122  	defer cancel()
   123  
   124  	if err := s.reconnectCtx(ctx); err != nil {
   125  		s.ErrorLog(errors.Wrap(err, "failed to reconnect after voice server update"))
   126  	}
   127  }
   128  
   129  func (s *Session) updateState(ev *gateway.VoiceStateUpdateEvent) {
   130  	if s.state.UserID != ev.UserID { // constant so no mutex
   131  		// Not our state.
   132  		return
   133  	}
   134  
   135  	// If this is true, then mutex is acquired already.
   136  	if s.joining.Get() {
   137  		if s.state.GuildID != ev.GuildID {
   138  			return
   139  		}
   140  
   141  		s.state.SessionID = ev.SessionID
   142  		s.state.ChannelID = ev.ChannelID
   143  
   144  		s.incoming <- struct{}{}
   145  		return
   146  	}
   147  }
   148  
   149  func (s *Session) JoinChannel(gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
   150  	ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
   151  	defer cancel()
   152  
   153  	return s.JoinChannelCtx(ctx, gID, cID, mute, deaf)
   154  }
   155  
   156  // JoinChannelCtx joins a voice channel. Callers shouldn't use this method
   157  // directly, but rather Voice's. This method shouldn't ever be called
   158  // concurrently.
   159  func (s *Session) JoinChannelCtx(
   160  	ctx context.Context, gID discord.GuildID, cID discord.ChannelID, mute, deaf bool) error {
   161  
   162  	if s.joining.Get() {
   163  		return ErrAlreadyConnecting
   164  	}
   165  
   166  	// Acquire the mutex during join, locking during IO as well.
   167  	s.mut.Lock()
   168  	defer s.mut.Unlock()
   169  
   170  	// Set that we're joining.
   171  	s.joining.Set(true)
   172  	defer s.joining.Set(false) // reset when done
   173  
   174  	// Ensure gateway and voiceUDP are already closed.
   175  	s.ensureClosed()
   176  
   177  	// Set the state.
   178  	s.state.ChannelID = cID
   179  	s.state.GuildID = gID
   180  
   181  	// Ensure that if `cID` is zero that it passes null to the update event.
   182  	channelID := discord.NullChannelID
   183  	if cID.IsValid() {
   184  		channelID = cID
   185  	}
   186  
   187  	// https://discord.com/developers/docs/topics/voice-connections#retrieving-voice-server-information
   188  	// Send a Voice State Update event to the gateway.
   189  	err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
   190  		GuildID:   gID,
   191  		ChannelID: channelID,
   192  		SelfMute:  mute,
   193  		SelfDeaf:  deaf,
   194  	})
   195  	if err != nil {
   196  		return errors.Wrap(err, "failed to send Voice State Update event")
   197  	}
   198  
   199  	// Wait for 2 replies. The above command should reply with these 2 events.
   200  	if err := s.waitForIncoming(ctx, 2); err != nil {
   201  		return errors.Wrap(err, "failed to wait for needed gateway events")
   202  	}
   203  
   204  	// These 2 methods should've updated s.state before sending into these
   205  	// channels. Since s.state is already filled, we can go ahead and connect.
   206  
   207  	return s.reconnectCtx(ctx)
   208  }
   209  
   210  func (s *Session) waitForIncoming(ctx context.Context, n int) error {
   211  	for i := 0; i < n; i++ {
   212  		select {
   213  		case <-s.incoming:
   214  			continue
   215  		case <-ctx.Done():
   216  			return ctx.Err()
   217  		}
   218  	}
   219  
   220  	return nil
   221  }
   222  
   223  // reconnect uses the current state to reconnect to a new gateway and UDP
   224  // connection.
   225  func (s *Session) reconnectCtx(ctx context.Context) (err error) {
   226  	wsutil.WSDebug("Sending stop handle.")
   227  	s.looper.Stop()
   228  
   229  	wsutil.WSDebug("Start gateway.")
   230  	s.gateway = voicegateway.New(s.state)
   231  
   232  	// Open the voice gateway. The function will block until Ready is received.
   233  	if err := s.gateway.OpenCtx(ctx); err != nil {
   234  		return errors.Wrap(err, "failed to open voice gateway")
   235  	}
   236  
   237  	// Start the handler dispatching
   238  	s.looper.Start(s.gateway.Events)
   239  
   240  	// Get the Ready event.
   241  	voiceReady := s.gateway.Ready()
   242  
   243  	// Prepare the UDP voice connection.
   244  	s.voiceUDP, err = udp.DialConnectionCtx(ctx, voiceReady.Addr(), voiceReady.SSRC)
   245  	if err != nil {
   246  		return errors.Wrap(err, "failed to open voice UDP connection")
   247  	}
   248  
   249  	// Get the session description from the voice gateway.
   250  	d, err := s.gateway.SessionDescriptionCtx(ctx, voicegateway.SelectProtocol{
   251  		Protocol: "udp",
   252  		Data: voicegateway.SelectProtocolData{
   253  			Address: s.voiceUDP.GatewayIP,
   254  			Port:    s.voiceUDP.GatewayPort,
   255  			Mode:    Protocol,
   256  		},
   257  	})
   258  	if err != nil {
   259  		return errors.Wrap(err, "failed to select protocol")
   260  	}
   261  
   262  	s.voiceUDP.UseSecret(d.SecretKey)
   263  
   264  	return nil
   265  }
   266  
   267  // Speaking tells Discord we're speaking. This method should not be called
   268  // concurrently.
   269  func (s *Session) Speaking(flag voicegateway.SpeakingFlag) error {
   270  	s.mut.RLock()
   271  	gateway := s.gateway
   272  	s.mut.RUnlock()
   273  
   274  	return gateway.Speaking(flag)
   275  }
   276  
   277  // UseContext tells the UDP voice connection to write with the given context.
   278  func (s *Session) UseContext(ctx context.Context) error {
   279  	s.mut.Lock()
   280  	defer s.mut.Unlock()
   281  
   282  	if s.voiceUDP == nil {
   283  		return ErrCannotSend
   284  	}
   285  
   286  	return s.voiceUDP.UseContext(ctx)
   287  }
   288  
   289  // VoiceUDPConn gets a voice UDP connection. The caller could use this method to
   290  // circumvent the rapid mutex-read-lock acquire inside Write.
   291  func (s *Session) VoiceUDPConn() *udp.Connection {
   292  	s.mut.RLock()
   293  	defer s.mut.RUnlock()
   294  
   295  	return s.voiceUDP
   296  }
   297  
   298  // Write writes into the UDP voice connection WITHOUT a timeout. Refer to
   299  // WriteCtx for more information.
   300  func (s *Session) Write(b []byte) (int, error) {
   301  	return s.WriteCtx(context.Background(), b)
   302  }
   303  
   304  // WriteCtx writes into the UDP voice connection with a context for timeout.
   305  // This method is thread safe as far as calling other methods of Session goes;
   306  // HOWEVER it is not thread safe to call Write itself concurrently.
   307  func (s *Session) WriteCtx(ctx context.Context, b []byte) (int, error) {
   308  	voiceUDP := s.VoiceUDPConn()
   309  
   310  	if voiceUDP == nil {
   311  		return 0, ErrCannotSend
   312  	}
   313  
   314  	return voiceUDP.WriteCtx(ctx, b)
   315  }
   316  
   317  // Leave disconnects the current voice session from the currently connected
   318  // channel.
   319  func (s *Session) Leave() error {
   320  	ctx, cancel := context.WithTimeout(context.Background(), WSTimeout)
   321  	defer cancel()
   322  
   323  	return s.LeaveCtx(ctx)
   324  }
   325  
   326  // LeaveCtx disconencts with a context. Refer to Leave for more information.
   327  func (s *Session) LeaveCtx(ctx context.Context) error {
   328  	s.mut.Lock()
   329  	defer s.mut.Unlock()
   330  
   331  	// If we're already closed.
   332  	if s.gateway == nil && s.voiceUDP == nil {
   333  		return nil
   334  	}
   335  
   336  	s.looper.Stop()
   337  
   338  	// Notify Discord that we're leaving. This will send a
   339  	// VoiceStateUpdateEvent, in which our handler will promptly remove the
   340  	// session from the map.
   341  
   342  	err := s.session.Gateway.UpdateVoiceStateCtx(ctx, gateway.UpdateVoiceStateData{
   343  		GuildID:   s.state.GuildID,
   344  		ChannelID: discord.ChannelID(discord.NullSnowflake),
   345  		SelfMute:  true,
   346  		SelfDeaf:  true,
   347  	})
   348  
   349  	s.ensureClosed()
   350  	// wrap returns nil if err is nil
   351  	return errors.Wrap(err, "failed to update voice state")
   352  }
   353  
   354  // close ensures everything is closed. It does not acquire the mutex.
   355  func (s *Session) ensureClosed() {
   356  	s.looper.Stop()
   357  
   358  	// Disconnect the UDP connection.
   359  	if s.voiceUDP != nil {
   360  		s.voiceUDP.Close()
   361  		s.voiceUDP = nil
   362  	}
   363  
   364  	// Disconnect the voice gateway, ignoring the error.
   365  	if s.gateway != nil {
   366  		if err := s.gateway.Close(); err != nil {
   367  			wsutil.WSDebug("Uncaught voice gateway close error:", err)
   368  		}
   369  		s.gateway = nil
   370  	}
   371  }
   372  
   373  // ReadPacket reads a single packet from the UDP connection. This is NOT at all
   374  // thread safe, and must be used very carefully. The backing buffer is always
   375  // reused.
   376  func (s *Session) ReadPacket() (*udp.Packet, error) {
   377  	return s.voiceUDP.ReadPacket()
   378  }