github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/chat/unfurl/extractor.go (about)

     1  package unfurl
     2  
     3  import (
     4  	"context"
     5  	"regexp"
     6  	"sync"
     7  
     8  	"mvdan.cc/xurls/v2"
     9  
    10  	"github.com/keybase/client/go/chat/globals"
    11  	"github.com/keybase/client/go/chat/types"
    12  	"github.com/keybase/client/go/chat/utils"
    13  	"github.com/keybase/client/go/protocol/chat1"
    14  	"github.com/keybase/client/go/protocol/gregor1"
    15  )
    16  
    17  type ExtractorHitTyp int
    18  
    19  const (
    20  	ExtractorHitUnfurl ExtractorHitTyp = iota
    21  	ExtractorHitPrompt
    22  )
    23  
    24  type ExtractorHit struct {
    25  	URL string
    26  	Typ ExtractorHitTyp
    27  }
    28  
    29  type Extractor struct {
    30  	utils.DebugLabeler
    31  
    32  	urlRegexp      *regexp.Regexp
    33  	quoteRegexp    *regexp.Regexp
    34  	maxHits        int
    35  	exemptionsLock sync.Mutex
    36  	exemptions     map[string]*WhitelistExemptionList
    37  }
    38  
    39  func NewExtractor(g *globals.Context) *Extractor {
    40  	return &Extractor{
    41  		DebugLabeler: utils.NewDebugLabeler(g.ExternalG(), "Extractor", false),
    42  		urlRegexp:    xurls.Strict(),
    43  		quoteRegexp:  regexp.MustCompile("`[^`]*`"),
    44  		exemptions:   make(map[string]*WhitelistExemptionList),
    45  		maxHits:      5,
    46  	}
    47  }
    48  
    49  func (e *Extractor) getExemptionList(uid gregor1.UID) (res *WhitelistExemptionList) {
    50  	e.exemptionsLock.Lock()
    51  	defer e.exemptionsLock.Unlock()
    52  	var ok bool
    53  	res, ok = e.exemptions[uid.String()]
    54  	if !ok {
    55  		res = NewWhitelistExemptionList()
    56  		e.exemptions[uid.String()] = res
    57  	}
    58  	return res
    59  }
    60  
    61  func (e *Extractor) isAutoWhitelist(domain string) bool {
    62  	switch domain {
    63  	case "giphy.com", types.MapsDomain:
    64  		return true
    65  	}
    66  	return false
    67  }
    68  
    69  func (e *Extractor) isAutoWhitelistFromHit(ctx context.Context, hit string) bool {
    70  	domain, err := GetDomain(hit)
    71  	if err != nil {
    72  		e.Debug(ctx, "isAutoWhitelistFromHit: failed to get domain: %s", err)
    73  		return false
    74  	}
    75  	return e.isAutoWhitelist(domain)
    76  }
    77  
    78  func (e *Extractor) isWhitelistHit(ctx context.Context, convID chat1.ConversationID, msgID chat1.MessageID,
    79  	hit string, whitelist map[string]bool, exemptions *WhitelistExemptionList) bool {
    80  	domain, err := GetDomain(hit)
    81  	if err != nil {
    82  		e.Debug(ctx, "isWhitelistHit: failed to get domain: %s", err)
    83  		return false
    84  	}
    85  	if e.isAutoWhitelist(domain) || whitelist[domain] {
    86  		return true
    87  	}
    88  	// Check exemptions
    89  	if exemptions.Use(convID, msgID, domain) {
    90  		e.Debug(ctx, "isWhitelistHit: hit exemption for domain, letting through")
    91  		return true
    92  	}
    93  	return false
    94  }
    95  
    96  func (e *Extractor) Extract(ctx context.Context, uid gregor1.UID, convID chat1.ConversationID,
    97  	msgID chat1.MessageID, body string, userSettings *Settings) (res []ExtractorHit, err error) {
    98  	defer e.Trace(ctx, &err, "Extract")()
    99  	body = e.quoteRegexp.ReplaceAllString(body, "")
   100  	hits := e.urlRegexp.FindAllString(body, -1)
   101  	if len(hits) == 0 {
   102  		return res, nil
   103  	}
   104  	settings, err := userSettings.Get(ctx, uid)
   105  	if err != nil {
   106  		return res, err
   107  	}
   108  	for _, h := range hits {
   109  		ehit := ExtractorHit{
   110  			URL: h,
   111  			Typ: ExtractorHitPrompt,
   112  		}
   113  		switch settings.Mode {
   114  		case chat1.UnfurlMode_ALWAYS:
   115  			ehit.Typ = ExtractorHitUnfurl
   116  		case chat1.UnfurlMode_WHITELISTED:
   117  			if e.isWhitelistHit(ctx, convID, msgID, h, settings.Whitelist, e.getExemptionList(uid)) {
   118  				ehit.Typ = ExtractorHitUnfurl
   119  			}
   120  		case chat1.UnfurlMode_NEVER:
   121  			if e.isAutoWhitelistFromHit(ctx, h) {
   122  				ehit.Typ = ExtractorHitUnfurl
   123  			} else {
   124  				continue
   125  			}
   126  		}
   127  		res = append(res, ehit)
   128  		if len(res) >= e.maxHits {
   129  			e.Debug(ctx, "Extract: max hits reached, aborting")
   130  			break
   131  		}
   132  	}
   133  	return res, nil
   134  }
   135  
   136  func (e *Extractor) AddWhitelistExemption(ctx context.Context, uid gregor1.UID,
   137  	exemption types.WhitelistExemption) {
   138  	defer e.Trace(ctx, nil, "AddWhitelistExemption")()
   139  	e.getExemptionList(uid).Add(exemption)
   140  }