github.com/livekit/protocol@v1.39.3/sip/sip.go (about)

     1  // Copyright 2023 LiveKit, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sip
    16  
    17  import (
    18  	"crypto/sha256"
    19  	"encoding/hex"
    20  	"fmt"
    21  	"io"
    22  	"maps"
    23  	"math"
    24  	"net/netip"
    25  	"regexp"
    26  	"sort"
    27  	"strings"
    28  
    29  	"github.com/dennwc/iters"
    30  	"github.com/twitchtv/twirp"
    31  	"golang.org/x/exp/slices"
    32  
    33  	"github.com/livekit/protocol/livekit"
    34  	"github.com/livekit/protocol/logger"
    35  	"github.com/livekit/protocol/rpc"
    36  	"github.com/livekit/protocol/utils"
    37  	"github.com/livekit/protocol/utils/guid"
    38  )
    39  
    40  //go:generate stringer -type TrunkFilteredReason -trimprefix TrunkFiltered
    41  //go:generate stringer -type TrunkConflictReason -trimprefix TrunkConflict
    42  //go:generate stringer -type DispatchRuleConflictReason -trimprefix DispatchRuleConflict
    43  
    44  func NewCallID() string {
    45  	return guid.New(utils.SIPCallPrefix)
    46  }
    47  
    48  type ErrNoDispatchMatched struct {
    49  	NoRules      bool
    50  	NoTrunks     bool
    51  	CalledNumber string
    52  }
    53  
    54  func (e *ErrNoDispatchMatched) Error() string {
    55  	if e.NoRules {
    56  		return "No SIP Dispatch Rules defined"
    57  	}
    58  	if e.NoTrunks {
    59  		return fmt.Sprintf("No SIP Trunk or Dispatch Rules matched for %q", e.CalledNumber)
    60  	}
    61  	return fmt.Sprintf("No SIP Dispatch Rules matched for %q", e.CalledNumber)
    62  }
    63  
    64  // DispatchRulePriority returns sorting priority for dispatch rules. Lower value means higher priority.
    65  func DispatchRulePriority(info *livekit.SIPDispatchRuleInfo) int32 {
    66  	// In all these cases, prefer pin-protected rules and rules for specific calling number.
    67  	// Thus, the order will be the following:
    68  	// - 0: Direct or Pin (both pin-protected)
    69  	// - 1: Caller, aka Individual (pin-protected)
    70  	// - 2: Callee (pin-protected)
    71  	// - 100: Direct (open)
    72  	// - 101: Caller, aka Individual (open)
    73  	// - 102: Callee (open)
    74  	// Also, add 1K penalty for not specifying the calling number.
    75  	const (
    76  		last = math.MaxInt32
    77  	)
    78  	// TODO: Maybe allow setting specific priorities for dispatch rules?
    79  	priority := int32(0)
    80  	switch rule := info.GetRule().GetRule().(type) {
    81  	default:
    82  		return last
    83  	case *livekit.SIPDispatchRule_DispatchRuleDirect:
    84  		if rule.DispatchRuleDirect.GetPin() != "" {
    85  			priority = 0
    86  		} else {
    87  			priority = 100
    88  		}
    89  	case *livekit.SIPDispatchRule_DispatchRuleIndividual:
    90  		if rule.DispatchRuleIndividual.GetPin() != "" {
    91  			priority = 1
    92  		} else {
    93  			priority = 101
    94  		}
    95  	case *livekit.SIPDispatchRule_DispatchRuleCallee:
    96  		if rule.DispatchRuleCallee.GetPin() != "" {
    97  			priority = 2
    98  		} else {
    99  			priority = 102
   100  		}
   101  	}
   102  	if len(info.InboundNumbers) == 0 {
   103  		priority += 1000
   104  	}
   105  	return priority
   106  }
   107  
   108  func hasHigherPriority(r1, r2 *livekit.SIPDispatchRuleInfo) bool {
   109  	p1, p2 := DispatchRulePriority(r1), DispatchRulePriority(r2)
   110  	if p1 < p2 {
   111  		return true
   112  	} else if p1 > p2 {
   113  		return false
   114  	}
   115  	// For predictable sorting order.
   116  	room1, _, _ := GetPinAndRoom(r1)
   117  	room2, _, _ := GetPinAndRoom(r2)
   118  	return room1 < room2
   119  }
   120  
   121  // SortDispatchRules predictably sorts dispatch rules by priority (first one is highest).
   122  func SortDispatchRules(rules []*livekit.SIPDispatchRuleInfo) {
   123  	sort.Slice(rules, func(i, j int) bool {
   124  		return hasHigherPriority(rules[i], rules[j])
   125  	})
   126  }
   127  
   128  func printID(s string) string {
   129  	if s == "" {
   130  		return "<new>"
   131  	}
   132  	return s
   133  }
   134  
   135  // ValidateDispatchRules checks a set of dispatch rules for conflicts.
   136  //
   137  // Deprecated: use ValidateDispatchRulesIter
   138  func ValidateDispatchRules(rules []*livekit.SIPDispatchRuleInfo, opts ...MatchDispatchRuleOpt) error {
   139  	_, err := ValidateDispatchRulesIter(iters.Slice(rules), opts...)
   140  	return err
   141  }
   142  
   143  // ValidateDispatchRulesIter checks a set of dispatch rules for conflicts.
   144  func ValidateDispatchRulesIter(it iters.Iter[*livekit.SIPDispatchRuleInfo], opts ...MatchDispatchRuleOpt) (best *livekit.SIPDispatchRuleInfo, _ error) {
   145  	it = NewDispatchRuleValidator(opts...).ValidateIter(it)
   146  	defer it.Close()
   147  	for {
   148  		r, err := it.Next()
   149  		if err == io.EOF {
   150  			break
   151  		} else if err != nil {
   152  			return best, err
   153  		}
   154  		if best == nil || hasHigherPriority(r, best) {
   155  			best = r
   156  		}
   157  	}
   158  	return best, nil
   159  }
   160  
   161  func NewDispatchRuleValidator(opts ...MatchDispatchRuleOpt) *DispatchRuleValidator {
   162  	var opt matchDispatchRuleOpts
   163  	for _, fnc := range opts {
   164  		fnc(&opt)
   165  	}
   166  	opt.defaults()
   167  	return &DispatchRuleValidator{
   168  		opt:       opt,
   169  		byRuleKey: make(map[dispatchRuleKey]*livekit.SIPDispatchRuleInfo),
   170  	}
   171  }
   172  
   173  type dispatchRuleKey struct {
   174  	Pin    string
   175  	Trunk  string
   176  	Number string
   177  }
   178  
   179  type DispatchRuleValidator struct {
   180  	opt       matchDispatchRuleOpts
   181  	byRuleKey map[dispatchRuleKey]*livekit.SIPDispatchRuleInfo
   182  }
   183  
   184  func (v *DispatchRuleValidator) ValidateIter(it iters.Iter[*livekit.SIPDispatchRuleInfo]) iters.Iter[*livekit.SIPDispatchRuleInfo] {
   185  	return &dispatchRuleValidatorIter{v: v, it: it}
   186  }
   187  
   188  func (v *DispatchRuleValidator) Validate(r *livekit.SIPDispatchRuleInfo) error {
   189  	_, pin, err := GetPinAndRoom(r)
   190  	if err != nil {
   191  		return err
   192  	}
   193  	trunks := r.TrunkIds
   194  	if len(trunks) == 0 {
   195  		// This rule matches all trunks, but collides only with other default ones (specific rules take priority).
   196  		trunks = []string{""}
   197  	}
   198  	numbers := r.InboundNumbers
   199  	if len(numbers) == 0 {
   200  		// This rule matches all numbers, but collides only with other default ones (specific rules take priority).
   201  		numbers = []string{""}
   202  	}
   203  	for _, trunk := range trunks {
   204  		for _, number := range numbers {
   205  			key := dispatchRuleKey{Pin: pin, Trunk: trunk, Number: NormalizeNumber(number)}
   206  			r2 := v.byRuleKey[key]
   207  			if r2 != nil {
   208  				v.opt.Conflict(r, r2, DispatchRuleConflictGeneric)
   209  				if v.opt.AllowConflicts {
   210  					continue
   211  				}
   212  				return twirp.NewErrorf(twirp.InvalidArgument, "Conflicting SIP Dispatch Rules: same Trunk+Number+PIN combination for for %q and %q",
   213  					printID(r.SipDispatchRuleId), printID(r2.SipDispatchRuleId))
   214  			}
   215  			v.byRuleKey[key] = r
   216  		}
   217  	}
   218  	return nil
   219  }
   220  
   221  type dispatchRuleValidatorIter struct {
   222  	v  *DispatchRuleValidator
   223  	it iters.Iter[*livekit.SIPDispatchRuleInfo]
   224  }
   225  
   226  func (v *dispatchRuleValidatorIter) Next() (*livekit.SIPDispatchRuleInfo, error) {
   227  	r, err := v.it.Next()
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  	r = v.v.opt.Replace(r)
   232  	if err = v.v.Validate(r); err != nil {
   233  		return nil, err
   234  	}
   235  	return r, nil
   236  }
   237  
   238  func (v *dispatchRuleValidatorIter) Close() {
   239  	v.it.Close()
   240  }
   241  
   242  // SelectDispatchRule takes a list of dispatch rules, and takes the decision which one should be selected.
   243  // It returns an error if there are conflicting rules. Returns nil if no rules match.
   244  //
   245  // Deprecated: use MatchDispatchRuleIter
   246  func SelectDispatchRule(rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest, opts ...MatchDispatchRuleOpt) (*livekit.SIPDispatchRuleInfo, error) {
   247  	// Sorting will do the selection for us. We already filtered out irrelevant ones in MatchDispatchRule and above.
   248  	// Nil is fine here. We will report "no rules matched" later.
   249  	return ValidateDispatchRulesIter(iters.Slice(rules), opts...)
   250  }
   251  
   252  // GetPinAndRoom returns a room name/prefix and the pin for a dispatch rule. Just a convenience wrapper.
   253  func GetPinAndRoom(info *livekit.SIPDispatchRuleInfo) (room, pin string, err error) {
   254  	// TODO: Could probably add methods on SIPDispatchRuleInfo struct instead.
   255  	switch rule := info.GetRule().GetRule().(type) {
   256  	default:
   257  		return "", "", fmt.Errorf("Unsupported SIP Dispatch Rule: %T", rule)
   258  	case *livekit.SIPDispatchRule_DispatchRuleDirect:
   259  		pin = rule.DispatchRuleDirect.GetPin()
   260  		room = rule.DispatchRuleDirect.GetRoomName()
   261  	case *livekit.SIPDispatchRule_DispatchRuleIndividual:
   262  		pin = rule.DispatchRuleIndividual.GetPin()
   263  		room = rule.DispatchRuleIndividual.GetRoomPrefix()
   264  	case *livekit.SIPDispatchRule_DispatchRuleCallee:
   265  		pin = rule.DispatchRuleCallee.GetPin()
   266  		room = rule.DispatchRuleCallee.GetRoomPrefix()
   267  	}
   268  	return room, pin, nil
   269  }
   270  
   271  func printNumbers(numbers []string) string {
   272  	if len(numbers) == 0 {
   273  		return "<any>"
   274  	}
   275  	return fmt.Sprintf("%q", numbers)
   276  }
   277  
   278  var (
   279  	reNumber     = regexp.MustCompile(`^\+?[\d\- ()]+$`)
   280  	reNumberRepl = strings.NewReplacer(
   281  		" ", "",
   282  		"-", "",
   283  		"(", "",
   284  		")", "",
   285  	)
   286  )
   287  
   288  func NormalizeNumber(num string) string {
   289  	if num == "" {
   290  		return ""
   291  	}
   292  	if !reNumber.MatchString(num) {
   293  		return num
   294  	}
   295  	num = reNumberRepl.Replace(num)
   296  	if !strings.HasPrefix(num, "+") {
   297  		return "+" + num
   298  	}
   299  	return num
   300  }
   301  
   302  func validateTrunkInbound(byInbound map[string]*livekit.SIPInboundTrunkInfo, t *livekit.SIPInboundTrunkInfo, opt *matchTrunkOpts) error {
   303  	if len(t.AllowedNumbers) == 0 {
   304  		if t2 := byInbound[""]; t2 != nil {
   305  			opt.Conflict(t, t2, TrunkConflictCalledNumber)
   306  			if opt.AllowConflicts {
   307  				return nil
   308  			}
   309  			return twirp.NewErrorf(twirp.InvalidArgument, "Conflicting inbound SIP Trunks: %q and %q, using the same number(s) %s without AllowedNumbers set",
   310  				printID(t.SipTrunkId), printID(t2.SipTrunkId), printNumbers(t.Numbers))
   311  		}
   312  		byInbound[""] = t
   313  	} else {
   314  		for _, num := range t.AllowedNumbers {
   315  			inboundKey := NormalizeNumber(num)
   316  			t2 := byInbound[inboundKey]
   317  			if t2 != nil {
   318  				opt.Conflict(t, t2, TrunkConflictCallingNumber)
   319  				if opt.AllowConflicts {
   320  					continue
   321  				}
   322  				return twirp.NewErrorf(twirp.InvalidArgument, "Conflicting inbound SIP Trunks: %q and %q, using the same number(s) %s and AllowedNumber %q",
   323  					printID(t.SipTrunkId), printID(t2.SipTrunkId), printNumbers(t.Numbers), num)
   324  			}
   325  			byInbound[inboundKey] = t
   326  		}
   327  	}
   328  	return nil
   329  }
   330  
   331  // ValidateTrunks checks a set of trunks for conflicts.
   332  //
   333  // Deprecated: use ValidateTrunksIter
   334  func ValidateTrunks(trunks []*livekit.SIPInboundTrunkInfo, opts ...MatchTrunkOpt) error {
   335  	return ValidateTrunksIter(iters.Slice(trunks), opts...)
   336  }
   337  
   338  // ValidateTrunksIter checks a set of trunks for conflicts.
   339  func ValidateTrunksIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], opts ...MatchTrunkOpt) error {
   340  	defer it.Close()
   341  	var opt matchTrunkOpts
   342  	for _, fnc := range opts {
   343  		fnc(&opt)
   344  	}
   345  	opt.defaults()
   346  	byOutboundAndInbound := make(map[string]map[string]*livekit.SIPInboundTrunkInfo)
   347  	for {
   348  		t, err := it.Next()
   349  		if err == io.EOF {
   350  			break
   351  		} else if err != nil {
   352  			return err
   353  		}
   354  		t = opt.Replace(t)
   355  		if len(t.Numbers) == 0 {
   356  			byInbound := byOutboundAndInbound[""]
   357  			if byInbound == nil {
   358  				byInbound = make(map[string]*livekit.SIPInboundTrunkInfo)
   359  				byOutboundAndInbound[""] = byInbound
   360  			}
   361  			if err := validateTrunkInbound(byInbound, t, &opt); err != nil {
   362  				return err
   363  			}
   364  		} else {
   365  			for _, num := range t.Numbers {
   366  				byInbound := byOutboundAndInbound[num]
   367  				if byInbound == nil {
   368  					byInbound = make(map[string]*livekit.SIPInboundTrunkInfo)
   369  					byOutboundAndInbound[num] = byInbound
   370  				}
   371  				if err := validateTrunkInbound(byInbound, t, &opt); err != nil {
   372  					return err
   373  				}
   374  			}
   375  		}
   376  	}
   377  	return nil
   378  }
   379  
   380  func isValidMask(mask string) bool {
   381  	// Allowed formats:
   382  	// - 1.2.3.4
   383  	// - 1.2.3.4/8
   384  	// - [::]
   385  	// - [::]/8
   386  	// - some.host.name
   387  	if strings.ContainsAny(mask, "()+*;, \t\n\r") {
   388  		return false
   389  	}
   390  	if strings.Contains(mask, "://") {
   391  		return false
   392  	}
   393  	return true
   394  }
   395  
   396  func filterInvalidAddrMasks(masks []string) []string {
   397  	if len(masks) == 0 {
   398  		return nil
   399  	}
   400  	out := make([]string, 0, len(masks))
   401  	for _, m := range masks {
   402  		if isValidMask(m) {
   403  			out = append(out, m)
   404  		}
   405  	}
   406  	return out
   407  }
   408  
   409  func matchAddrMask(ip netip.Addr, mask string) bool {
   410  	if !strings.Contains(mask, "/") {
   411  		expIP, err := netip.ParseAddr(mask)
   412  		if err != nil {
   413  			return false
   414  		}
   415  		return ip == expIP
   416  	}
   417  	pref, err := netip.ParsePrefix(mask)
   418  	if err != nil {
   419  		return false
   420  	}
   421  	return pref.Contains(ip)
   422  }
   423  
   424  func matchAddrMasks(addr string, host string, masks []string) bool {
   425  	ip, err := netip.ParseAddr(addr)
   426  	if err != nil {
   427  		return true
   428  	}
   429  	masks = filterInvalidAddrMasks(masks)
   430  	if len(masks) == 0 {
   431  		return true
   432  	}
   433  	for _, mask := range masks {
   434  		if mask == host || matchAddrMask(ip, mask) {
   435  			return true
   436  		}
   437  	}
   438  	return false
   439  }
   440  
   441  func matchNumbers(num string, allowed []string) bool {
   442  	if len(allowed) == 0 {
   443  		return true
   444  	}
   445  	norm := NormalizeNumber(num)
   446  	for _, allow := range allowed {
   447  		if num == allow || norm == NormalizeNumber(allow) {
   448  			return true
   449  		}
   450  	}
   451  	return false
   452  }
   453  
   454  // TrunkMatchType indicates how a trunk was matched
   455  type TrunkMatchType int
   456  
   457  const (
   458  	// TrunkMatchEmpty indicates no trunks were defined
   459  	TrunkMatchEmpty TrunkMatchType = iota
   460  	// TrunkMatchNone indicates trunks exist but none matched
   461  	TrunkMatchNone
   462  	// TrunkMatchDefault indicates only a default trunk (with no specific numbers) matched
   463  	TrunkMatchDefault
   464  	// TrunkMatchSpecific indicates a trunk with specific numbers matched
   465  	TrunkMatchSpecific
   466  )
   467  
   468  // TrunkMatchResult provides detailed information about the trunk matching process
   469  type TrunkMatchResult struct {
   470  	// The matched trunk, if any
   471  	Trunk *livekit.SIPInboundTrunkInfo
   472  	// How the trunk was matched
   473  	MatchType TrunkMatchType
   474  	// Number of default trunks found
   475  	DefaultTrunkCount int
   476  }
   477  
   478  // MatchTrunk finds a SIP Trunk definition matching the request.
   479  // Returns nil if no rules matched or an error if there are conflicting definitions.
   480  //
   481  // Deprecated: use MatchTrunkIter
   482  func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, call *rpc.SIPCall, opts ...MatchTrunkOpt) (*livekit.SIPInboundTrunkInfo, error) {
   483  	return MatchTrunkIter(iters.Slice(trunks), call, opts...)
   484  }
   485  
   486  // MatchTrunkDetailed is like MatchTrunkIter but returns detailed match information
   487  func MatchTrunkDetailed(it iters.Iter[*livekit.SIPInboundTrunkInfo], call *rpc.SIPCall, opts ...MatchTrunkOpt) (*TrunkMatchResult, error) {
   488  	defer it.Close()
   489  	var opt matchTrunkOpts
   490  	for _, fnc := range opts {
   491  		fnc(&opt)
   492  	}
   493  	opt.defaults()
   494  
   495  	result := &TrunkMatchResult{
   496  		MatchType: TrunkMatchEmpty, // Start with assumption it's empty
   497  	}
   498  
   499  	var (
   500  		selectedTrunk    *livekit.SIPInboundTrunkInfo
   501  		defaultTrunk     *livekit.SIPInboundTrunkInfo
   502  		defaultTrunkPrev *livekit.SIPInboundTrunkInfo
   503  		sawAnyTrunk      bool
   504  	)
   505  	calledNorm := NormalizeNumber(call.To.User)
   506  	for {
   507  		tr, err := it.Next()
   508  		if err == io.EOF {
   509  			break
   510  		} else if err != nil {
   511  			return nil, err
   512  		}
   513  		if !sawAnyTrunk {
   514  			sawAnyTrunk = true
   515  			result.MatchType = TrunkMatchNone // We have trunks but haven't matched any yet
   516  		}
   517  		tr = opt.Replace(tr)
   518  		// Do not consider it if number doesn't match.
   519  		if !matchNumbers(call.From.User, tr.AllowedNumbers) {
   520  			if !opt.Filtered(tr, TrunkFilteredCallingNumberDisallowed) {
   521  				continue
   522  			}
   523  		}
   524  		if !matchAddrMasks(call.SourceIp, call.From.Host, tr.AllowedAddresses) {
   525  			if !opt.Filtered(tr, TrunkFilteredSourceAddressDisallowed) {
   526  				continue
   527  			}
   528  		}
   529  		if len(tr.Numbers) == 0 {
   530  			// Default/wildcard trunk.
   531  			defaultTrunkPrev = defaultTrunk
   532  			defaultTrunk = tr
   533  			result.DefaultTrunkCount++
   534  		} else {
   535  			for _, num := range tr.Numbers {
   536  				if num == call.To.User || NormalizeNumber(num) == calledNorm {
   537  					// Trunk specific to the number.
   538  					if selectedTrunk != nil {
   539  						opt.Conflict(selectedTrunk, tr, TrunkConflictCalledNumber)
   540  						if opt.AllowConflicts {
   541  							// This path is unreachable, since we pick the first trunk. Kept for completeness.
   542  							continue
   543  						}
   544  						return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple SIP Trunks matched for %q", call.To.User)
   545  					}
   546  					selectedTrunk = tr
   547  					if opt.AllowConflicts {
   548  						// Pick the first match as soon as it's found. We don't care about conflicts.
   549  						result.Trunk = selectedTrunk
   550  						result.MatchType = TrunkMatchSpecific
   551  						return result, nil
   552  					}
   553  					// Keep searching! We want to know if there are any conflicting Trunk definitions.
   554  				} else {
   555  					opt.Filtered(tr, TrunkFilteredCalledNumberDisallowed)
   556  				}
   557  			}
   558  		}
   559  	}
   560  
   561  	if selectedTrunk != nil {
   562  		result.Trunk = selectedTrunk
   563  		result.MatchType = TrunkMatchSpecific
   564  		return result, nil
   565  	}
   566  	if result.DefaultTrunkCount > 1 {
   567  		opt.Conflict(defaultTrunk, defaultTrunkPrev, TrunkConflictDefault)
   568  		if !opt.AllowConflicts {
   569  			return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple default SIP Trunks matched for %q", call.To.User)
   570  		}
   571  	}
   572  	if defaultTrunk != nil {
   573  		result.Trunk = defaultTrunk
   574  		result.MatchType = TrunkMatchDefault
   575  	}
   576  	return result, nil
   577  }
   578  
   579  type matchTrunkOpts struct {
   580  	AllowConflicts bool
   581  	Filtered       TrunkFilteredFunc
   582  	Conflict       TrunkConflictFunc
   583  	Replace        TrunkReplaceFunc
   584  }
   585  
   586  func (opt *matchTrunkOpts) defaults() {
   587  	if opt.Filtered == nil {
   588  		opt.Filtered = func(_ *livekit.SIPInboundTrunkInfo, _ TrunkFilteredReason) bool {
   589  			return false
   590  		}
   591  	}
   592  	if opt.Conflict == nil {
   593  		opt.Conflict = func(_, _ *livekit.SIPInboundTrunkInfo, _ TrunkConflictReason) {}
   594  	}
   595  	if opt.Replace == nil {
   596  		opt.Replace = func(t *livekit.SIPInboundTrunkInfo) *livekit.SIPInboundTrunkInfo {
   597  			return t
   598  		}
   599  	}
   600  }
   601  
   602  type MatchTrunkOpt func(opt *matchTrunkOpts)
   603  
   604  type TrunkFilteredReason int
   605  
   606  const (
   607  	TrunkFilteredInvalid = TrunkFilteredReason(iota)
   608  	TrunkFilteredCallingNumberDisallowed
   609  	TrunkFilteredCalledNumberDisallowed
   610  	TrunkFilteredSourceAddressDisallowed
   611  )
   612  
   613  type TrunkFilteredFunc func(tr *livekit.SIPInboundTrunkInfo, reason TrunkFilteredReason) bool
   614  
   615  // WithTrunkFiltered sets a callback that is called when selected Trunk(s) doesn't match the call.
   616  // If the callback returns true, trunk will not be filtered.
   617  func WithTrunkFiltered(fnc TrunkFilteredFunc) MatchTrunkOpt {
   618  	return func(opt *matchTrunkOpts) {
   619  		opt.Filtered = fnc
   620  	}
   621  }
   622  
   623  type TrunkConflictReason int
   624  
   625  const (
   626  	TrunkConflictDefault = TrunkConflictReason(iota)
   627  	TrunkConflictCalledNumber
   628  	TrunkConflictCallingNumber
   629  )
   630  
   631  type TrunkConflictFunc func(t1, t2 *livekit.SIPInboundTrunkInfo, reason TrunkConflictReason)
   632  
   633  // WithAllowTrunkConflicts allows conflicting Trunk definitions by picking the first match.
   634  //
   635  // Using this option will prevent TrunkConflictFunc from firing, since the first match will be returned immediately.
   636  func WithAllowTrunkConflicts() MatchTrunkOpt {
   637  	return func(opt *matchTrunkOpts) {
   638  		opt.AllowConflicts = true
   639  	}
   640  }
   641  
   642  // WithTrunkConflict sets a callback that is called when two Trunks conflict.
   643  func WithTrunkConflict(fnc TrunkConflictFunc) MatchTrunkOpt {
   644  	return func(opt *matchTrunkOpts) {
   645  		opt.Conflict = fnc
   646  	}
   647  }
   648  
   649  type TrunkReplaceFunc func(t *livekit.SIPInboundTrunkInfo) *livekit.SIPInboundTrunkInfo
   650  
   651  // WithTrunkReplace sets a callback that is called to potentially replace trunks before matching runs.
   652  func WithTrunkReplace(fnc TrunkReplaceFunc) MatchTrunkOpt {
   653  	return func(opt *matchTrunkOpts) {
   654  		opt.Replace = fnc
   655  	}
   656  }
   657  
   658  // MatchTrunkIter finds a SIP Trunk definition matching the request.
   659  // Returns nil if no rules matched or an error if there are conflicting definitions.
   660  func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], call *rpc.SIPCall, opts ...MatchTrunkOpt) (*livekit.SIPInboundTrunkInfo, error) {
   661  	result, err := MatchTrunkDetailed(it, call, opts...)
   662  	if err != nil {
   663  		return nil, err
   664  	}
   665  	return result.Trunk, nil
   666  }
   667  
   668  // MatchDispatchRule finds the best dispatch rule matching the request parameters. Returns an error if no rule matched.
   669  // Trunk parameter can be nil, in which case only wildcard dispatch rules will be effective (ones without Trunk IDs).
   670  //
   671  // Deprecated: use MatchDispatchRuleIter
   672  func MatchDispatchRule(trunk *livekit.SIPInboundTrunkInfo, rules []*livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest, opts ...MatchDispatchRuleOpt) (*livekit.SIPDispatchRuleInfo, error) {
   673  	return MatchDispatchRuleIter(trunk, iters.Slice(rules), req, opts...)
   674  }
   675  
   676  type matchDispatchRuleOpts struct {
   677  	AllowConflicts bool
   678  	Conflict       DispatchRuleConflictFunc
   679  	Replace        DispatchRuleReplaceFunc
   680  }
   681  
   682  func (opt *matchDispatchRuleOpts) defaults() {
   683  	if opt.Conflict == nil {
   684  		opt.Conflict = func(_, _ *livekit.SIPDispatchRuleInfo, _ DispatchRuleConflictReason) {}
   685  	}
   686  	if opt.Replace == nil {
   687  		opt.Replace = func(r *livekit.SIPDispatchRuleInfo) *livekit.SIPDispatchRuleInfo {
   688  			return r
   689  		}
   690  	}
   691  }
   692  
   693  type MatchDispatchRuleOpt func(opt *matchDispatchRuleOpts)
   694  
   695  type DispatchRuleConflictReason int
   696  
   697  const (
   698  	DispatchRuleConflictGeneric = DispatchRuleConflictReason(iota)
   699  )
   700  
   701  type DispatchRuleConflictFunc func(r1, r2 *livekit.SIPDispatchRuleInfo, reason DispatchRuleConflictReason)
   702  
   703  // WithAllowDispatchRuleConflicts allows conflicting DispatchRule definitions.
   704  func WithAllowDispatchRuleConflicts() MatchDispatchRuleOpt {
   705  	return func(opt *matchDispatchRuleOpts) {
   706  		opt.AllowConflicts = true
   707  	}
   708  }
   709  
   710  // WithDispatchRuleConflict sets a callback that is called when two DispatchRules conflict.
   711  func WithDispatchRuleConflict(fnc DispatchRuleConflictFunc) MatchDispatchRuleOpt {
   712  	return func(opt *matchDispatchRuleOpts) {
   713  		opt.Conflict = fnc
   714  	}
   715  }
   716  
   717  type DispatchRuleReplaceFunc func(r *livekit.SIPDispatchRuleInfo) *livekit.SIPDispatchRuleInfo
   718  
   719  // WithDispatchRuleReplace sets a callback that is called to potentially replace dispatch rules before matching runs.
   720  func WithDispatchRuleReplace(fnc DispatchRuleReplaceFunc) MatchDispatchRuleOpt {
   721  	return func(opt *matchDispatchRuleOpts) {
   722  		opt.Replace = fnc
   723  	}
   724  }
   725  
   726  // MatchDispatchRuleIter finds the best dispatch rule matching the request parameters. Returns an error if no rule matched.
   727  // Trunk parameter can be nil, in which case only wildcard dispatch rules will be effective (ones without Trunk IDs).
   728  func MatchDispatchRuleIter(trunk *livekit.SIPInboundTrunkInfo, rules iters.Iter[*livekit.SIPDispatchRuleInfo], req *rpc.EvaluateSIPDispatchRulesRequest, opts ...MatchDispatchRuleOpt) (*livekit.SIPDispatchRuleInfo, error) {
   729  	rules = NewDispatchRuleValidator(opts...).ValidateIter(rules)
   730  	defer rules.Close()
   731  	// Trunk can still be nil here in case none matched or were defined.
   732  	// This is still fine, but only in case we'll match exactly one wildcard dispatch rule.
   733  
   734  	// We split the matched dispatch rules into two sets in relation to Trunks: specific and default (aka wildcard).
   735  	// First, attempt to match any of the specific rules, where we did match the Trunk ID.
   736  	// If nothing matches there - fallback to default/wildcard rules, where no Trunk IDs were mentioned.
   737  	var (
   738  		specificRule    *livekit.SIPDispatchRuleInfo
   739  		specificRuleCnt int
   740  		defaultRule     *livekit.SIPDispatchRuleInfo
   741  		defaultRuleCnt  int
   742  	)
   743  	noPin := req.NoPin
   744  	sentPin := req.GetPin()
   745  	for {
   746  		info, err := rules.Next()
   747  		if err == io.EOF {
   748  			break
   749  		} else if err != nil {
   750  			return nil, err
   751  		}
   752  		if len(info.InboundNumbers) != 0 && !slices.Contains(info.InboundNumbers, req.CallingNumber) {
   753  			continue
   754  		}
   755  		_, rulePin, err := GetPinAndRoom(info)
   756  		if err != nil {
   757  			logger.Errorw("Invalid SIP Dispatch Rule", err, "dispatchRuleID", info.SipDispatchRuleId)
   758  			continue
   759  		}
   760  		// Filter heavily on the Pin, so that only relevant rules remain.
   761  		if noPin {
   762  			if rulePin != "" {
   763  				// Skip pin-protected rules if no pin mode requested.
   764  				continue
   765  			}
   766  		} else if sentPin != "" {
   767  			if rulePin == "" {
   768  				// Pin already sent, skip non-pin-protected rules.
   769  				continue
   770  			}
   771  			if sentPin != rulePin {
   772  				// Pin doesn't match. Don't return an error here, just wait for other rule to match (or none at all).
   773  				// Note that we will NOT match non-pin-protected rules, thus it will not fallback to open rules.
   774  				continue
   775  			}
   776  		}
   777  		if len(info.TrunkIds) == 0 {
   778  			// Default/wildcard dispatch rule.
   779  			defaultRuleCnt++
   780  			if defaultRule == nil || hasHigherPriority(info, defaultRule) {
   781  				defaultRule = info
   782  			}
   783  			continue
   784  		}
   785  		// Specific dispatch rules. Require a Trunk associated with the number.
   786  		if trunk == nil {
   787  			continue
   788  		}
   789  		if !slices.Contains(info.TrunkIds, trunk.SipTrunkId) {
   790  			continue
   791  		}
   792  		specificRuleCnt++
   793  		if specificRule == nil || hasHigherPriority(info, specificRule) {
   794  			specificRule = info
   795  		}
   796  	}
   797  	if specificRuleCnt == 0 && defaultRuleCnt == 0 {
   798  		err := &ErrNoDispatchMatched{NoRules: true, NoTrunks: trunk == nil, CalledNumber: req.CalledNumber}
   799  		return nil, twirp.WrapError(twirp.NewErrorf(twirp.FailedPrecondition, err.Error()), err)
   800  	}
   801  	if specificRule != nil {
   802  		return specificRule, nil
   803  	}
   804  	if defaultRule != nil {
   805  		return defaultRule, nil
   806  	}
   807  	err := &ErrNoDispatchMatched{NoRules: false, NoTrunks: trunk == nil, CalledNumber: req.CalledNumber}
   808  	return nil, twirp.WrapError(twirp.NewErrorf(twirp.FailedPrecondition, err.Error()), err)
   809  }
   810  
   811  // EvaluateDispatchRule checks a selected Dispatch Rule against the provided request.
   812  func EvaluateDispatchRule(projectID string, trunk *livekit.SIPInboundTrunkInfo, rule *livekit.SIPDispatchRuleInfo, req *rpc.EvaluateSIPDispatchRulesRequest) (*rpc.EvaluateSIPDispatchRulesResponse, error) {
   813  	call := req.SIPCall()
   814  	sentPin := req.GetPin()
   815  
   816  	trunkID := req.SipTrunkId
   817  	if trunk != nil {
   818  		trunkID = trunk.SipTrunkId
   819  	}
   820  	enc := livekit.SIPMediaEncryption_SIP_MEDIA_ENCRYPT_DISABLE
   821  	if trunk != nil {
   822  		enc = trunk.MediaEncryption
   823  	}
   824  	attrs := maps.Clone(rule.Attributes)
   825  	if attrs == nil {
   826  		attrs = make(map[string]string)
   827  	}
   828  	for k, v := range req.ExtraAttributes {
   829  		attrs[k] = v
   830  	}
   831  	attrs[livekit.AttrSIPCallID] = call.LkCallId
   832  	attrs[livekit.AttrSIPTrunkID] = trunkID
   833  
   834  	to := call.To.User
   835  	from := call.From.User
   836  	fromName := "Phone " + from
   837  	fromID := "sip_" + from
   838  	if rule.HidePhoneNumber {
   839  		// Mask the phone number, hash identity. Omit number in attrs.
   840  		h := sha256.Sum256([]byte(call.From.User))
   841  		fromID = "sip_" + hex.EncodeToString(h[:8])
   842  		// TODO: Maybe keep regional code, but mask all but 4 last digits?
   843  		n := 4
   844  		if len(from) <= 4 {
   845  			n = 1
   846  		}
   847  		from = from[len(from)-n:]
   848  		fromName = "Phone " + from
   849  	} else {
   850  		attrs[livekit.AttrSIPPhoneNumber] = call.From.User
   851  		attrs[livekit.AttrSIPHostName] = call.From.Host
   852  		attrs[livekit.AttrSIPTrunkNumber] = call.To.User
   853  	}
   854  
   855  	room, rulePin, err := GetPinAndRoom(rule)
   856  	if err != nil {
   857  		return nil, err
   858  	}
   859  	if rulePin != "" {
   860  		if sentPin == "" {
   861  			return &rpc.EvaluateSIPDispatchRulesResponse{
   862  				ProjectId:         projectID,
   863  				SipTrunkId:        trunkID,
   864  				SipDispatchRuleId: rule.SipDispatchRuleId,
   865  				Result:            rpc.SIPDispatchResult_REQUEST_PIN,
   866  				MediaEncryption:   enc,
   867  				RequestPin:        true,
   868  			}, nil
   869  		}
   870  		if rulePin != sentPin {
   871  			// This should never happen in practice, because matchSIPDispatchRule should remove rules with the wrong pin.
   872  			return nil, twirp.NewError(twirp.PermissionDenied, "Incorrect PIN for SIP room")
   873  		}
   874  	} else {
   875  		// Pin was sent, but room doesn't require one. Assume user accidentally pressed phone button.
   876  	}
   877  	switch rule := rule.GetRule().GetRule().(type) {
   878  	case *livekit.SIPDispatchRule_DispatchRuleIndividual:
   879  		// TODO: Remove "_" if the prefix is empty for consistency with Callee dispatch rule.
   880  		// TODO: Do we need to escape specific characters in the number?
   881  		// TODO: Include actual SIP call ID in the room name?
   882  		room = fmt.Sprintf("%s_%s_%s", rule.DispatchRuleIndividual.GetRoomPrefix(), from, guid.New(""))
   883  	case *livekit.SIPDispatchRule_DispatchRuleCallee:
   884  		room = to
   885  		if pref := rule.DispatchRuleCallee.GetRoomPrefix(); pref != "" {
   886  			room = pref + "_" + to
   887  		}
   888  		if rule.DispatchRuleCallee.Randomize {
   889  			room += "_" + guid.New("")
   890  		}
   891  	}
   892  	attrs[livekit.AttrSIPDispatchRuleID] = rule.SipDispatchRuleId
   893  	resp := &rpc.EvaluateSIPDispatchRulesResponse{
   894  		ProjectId:             projectID,
   895  		SipTrunkId:            trunkID,
   896  		SipDispatchRuleId:     rule.SipDispatchRuleId,
   897  		Result:                rpc.SIPDispatchResult_ACCEPT,
   898  		RoomName:              room,
   899  		ParticipantIdentity:   fromID,
   900  		ParticipantName:       fromName,
   901  		ParticipantMetadata:   rule.Metadata,
   902  		ParticipantAttributes: attrs,
   903  		RoomPreset:            rule.RoomPreset,
   904  		RoomConfig:            rule.RoomConfig,
   905  		MediaEncryption:       enc,
   906  	}
   907  	krispEnabled := false
   908  	if trunk != nil {
   909  		resp.Headers = trunk.Headers
   910  		resp.HeadersToAttributes = trunk.HeadersToAttributes
   911  		resp.AttributesToHeaders = trunk.AttributesToHeaders
   912  		resp.IncludeHeaders = trunk.IncludeHeaders
   913  		resp.RingingTimeout = trunk.RingingTimeout
   914  		resp.MaxCallDuration = trunk.MaxCallDuration
   915  		krispEnabled = krispEnabled || trunk.KrispEnabled
   916  	}
   917  	if rule != nil {
   918  		krispEnabled = krispEnabled || rule.KrispEnabled
   919  		if rule.MediaEncryption != 0 {
   920  			resp.MediaEncryption = rule.MediaEncryption
   921  		}
   922  	}
   923  	if krispEnabled {
   924  		resp.EnabledFeatures = append(resp.EnabledFeatures, livekit.SIPFeature_KRISP_ENABLED)
   925  	}
   926  	return resp, nil
   927  }