github.com/diamondburned/arikawa@v1.3.14/voice/voice.go (about)

     1  // Package voice handles the Discord voice gateway and UDP connections, as well
     2  // as managing and keeping track of multiple voice sessions.
     3  //
     4  // This package abstracts the subpackages voice/voicesession and voice/udp.
     5  package voice
     6  
     7  import (
     8  	"context"
     9  	"log"
    10  	"strconv"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/diamondburned/arikawa/discord"
    15  	"github.com/diamondburned/arikawa/gateway"
    16  	"github.com/diamondburned/arikawa/state"
    17  	"github.com/pkg/errors"
    18  )
    19  
    20  var (
    21  	// defaultErrorHandler is the default error handler
    22  	defaultErrorHandler = func(err error) { log.Println("voice gateway error:", err) }
    23  
    24  	// ErrCannotSend is an error when audio is sent to a closed channel.
    25  	ErrCannotSend = errors.New("cannot send audio to closed channel")
    26  )
    27  
// Voice represents a Voice Repository used for managing voice sessions.
//
// It embeds the wrapped *state.State and tracks at most one *Session per
// guild. The zero value is not usable because sessions must be an initialized
// map (as NewVoice does) before JoinChannel can store into it.
type Voice struct {
	*state.State

	// mapmutex guards sessions.
	mapmutex sync.Mutex
	// sessions holds all of the active voice sessions, keyed by guild ID.
	sessions map[discord.GuildID]*Session

	// Callbacks returned by AddHandler in NewVoice; calling them removes the
	// handlers again (Close does this).
	closers []func()

	// ErrorLog will be called when an error occurs. NewVoice sets it to a
	// handler that logs through log.Println.
	ErrorLog func(err error)
}
    42  
    43  // NewVoiceFromToken creates a new voice session from the given token.
    44  func NewVoiceFromToken(token string) (*Voice, error) {
    45  	s, err := state.New(token)
    46  	if err != nil {
    47  		return nil, errors.Wrap(err, "failed to create a new session")
    48  	}
    49  
    50  	return NewVoice(s), nil
    51  }
    52  
    53  // NewVoice creates a new Voice repository wrapped around a state. The function
    54  // will also automatically add the GuildVoiceStates intent, as that is required.
    55  func NewVoice(s *state.State) *Voice {
    56  	v := &Voice{
    57  		State:    s,
    58  		sessions: make(map[discord.GuildID]*Session),
    59  		ErrorLog: defaultErrorHandler,
    60  	}
    61  
    62  	// Add the required event handlers to the session.
    63  	v.closers = []func(){
    64  		s.AddHandler(v.onVoiceStateUpdate),
    65  		s.AddHandler(v.onVoiceServerUpdate),
    66  	}
    67  
    68  	return v
    69  }
    70  
    71  // onVoiceStateUpdate receives VoiceStateUpdateEvents from the gateway
    72  // to keep track of the current user's voice state.
    73  func (v *Voice) onVoiceStateUpdate(e *gateway.VoiceStateUpdateEvent) {
    74  	// Get the current user.
    75  	me, err := v.Me()
    76  	if err != nil {
    77  		v.ErrorLog(err)
    78  		return
    79  	}
    80  
    81  	// Ignore the event if it is an update from another user.
    82  	if me.ID != e.UserID {
    83  		return
    84  	}
    85  
    86  	// Get the stored voice session for the given guild.
    87  	vs, ok := v.GetSession(e.GuildID)
    88  	if !ok {
    89  		return
    90  	}
    91  
    92  	// Do what we must.
    93  	vs.UpdateState(e)
    94  
    95  	// Remove the connection if the current user has disconnected.
    96  	if e.ChannelID == 0 {
    97  		v.RemoveSession(e.GuildID)
    98  	}
    99  }
   100  
   101  // onVoiceServerUpdate receives VoiceServerUpdateEvents from the gateway
   102  // to manage the current user's voice connections.
   103  func (v *Voice) onVoiceServerUpdate(e *gateway.VoiceServerUpdateEvent) {
   104  	// Get the stored voice session for the given guild.
   105  	vs, ok := v.GetSession(e.GuildID)
   106  	if !ok {
   107  		return
   108  	}
   109  
   110  	// Do what we must.
   111  	vs.UpdateServer(e)
   112  }
   113  
   114  // GetSession gets a session for a guild with a read lock.
   115  func (v *Voice) GetSession(guildID discord.GuildID) (*Session, bool) {
   116  	v.mapmutex.Lock()
   117  	defer v.mapmutex.Unlock()
   118  
   119  	// For some reason you cannot just put `return v.sessions[]` and return a bool D:
   120  	conn, ok := v.sessions[guildID]
   121  	return conn, ok
   122  }
   123  
   124  // RemoveSession removes a session.
   125  func (v *Voice) RemoveSession(guildID discord.GuildID) {
   126  	v.mapmutex.Lock()
   127  	defer v.mapmutex.Unlock()
   128  
   129  	// Ensure that the session is disconnected.
   130  	if ses, ok := v.sessions[guildID]; ok {
   131  		ses.Disconnect()
   132  	}
   133  
   134  	delete(v.sessions, guildID)
   135  }
   136  
   137  // JoinChannel joins the specified channel in the specified guild.
   138  func (v *Voice) JoinChannel(gID discord.GuildID, cID discord.ChannelID, muted, deafened bool) (*Session, error) {
   139  	// Get the stored voice session for the given guild.
   140  	conn, ok := v.GetSession(gID)
   141  
   142  	// Create a new voice session if one does not exist.
   143  	if !ok {
   144  		u, err := v.Me()
   145  		if err != nil {
   146  			return nil, errors.Wrap(err, "failed to get self")
   147  		}
   148  
   149  		conn = NewSession(v.Session, u.ID)
   150  		conn.ErrorLog = v.ErrorLog
   151  
   152  		v.mapmutex.Lock()
   153  		v.sessions[gID] = conn
   154  		v.mapmutex.Unlock()
   155  	}
   156  
   157  	// Connect.
   158  	return conn, conn.JoinChannel(gID, cID, muted, deafened)
   159  }
   160  
   161  func (v *Voice) Close() error {
   162  	err := &CloseError{
   163  		SessionErrors: make(map[discord.GuildID]error),
   164  	}
   165  
   166  	v.mapmutex.Lock()
   167  	defer v.mapmutex.Unlock()
   168  
   169  	// Remove all callback handlers.
   170  	for _, fn := range v.closers {
   171  		fn()
   172  	}
   173  
   174  	for gID, s := range v.sessions {
   175  		log.Println("closing", gID)
   176  		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   177  		if dErr := s.DisconnectCtx(ctx); dErr != nil {
   178  			err.SessionErrors[gID] = dErr
   179  		}
   180  		cancel()
   181  		log.Println("closed", gID)
   182  	}
   183  
   184  	err.StateErr = v.State.Close()
   185  	if err.HasError() {
   186  		return err
   187  	}
   188  
   189  	return nil
   190  }
   191  
// CloseError is the error returned by (*Voice).Close when at least one voice
// session failed to disconnect or closing the underlying State failed.
//
// SessionErrors maps the guild ID of each session that failed to disconnect
// to the error it returned. StateErr holds the error from closing the State,
// if any.
type CloseError struct {
	SessionErrors map[discord.GuildID]error
	StateErr      error
}
   196  
   197  func (e *CloseError) HasError() bool {
   198  	if e.StateErr != nil {
   199  		return true
   200  	}
   201  
   202  	return len(e.SessionErrors) > 0
   203  }
   204  
   205  func (e *CloseError) Error() string {
   206  	if e.StateErr != nil {
   207  		return e.StateErr.Error()
   208  	}
   209  
   210  	if len(e.SessionErrors) < 1 {
   211  		return ""
   212  	}
   213  
   214  	return strconv.Itoa(len(e.SessionErrors)) + " voice sessions returned errors while attempting to disconnect"
   215  }