github.com/MontFerret/ferret@v0.18.0/pkg/drivers/cdp/network/interceptor.go (about)

     1  package network
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  
     7  	"github.com/gobwas/glob"
     8  	"github.com/mafredri/cdp"
     9  	"github.com/mafredri/cdp/protocol/fetch"
    10  	"github.com/mafredri/cdp/protocol/network"
    11  	"github.com/rs/zerolog"
    12  
    13  	"github.com/MontFerret/ferret/pkg/drivers"
    14  	"github.com/MontFerret/ferret/pkg/drivers/cdp/events"
    15  	"github.com/MontFerret/ferret/pkg/runtime/logging"
    16  )
    17  
    18  type (
    19  	ResourceFilter struct {
    20  		URL          glob.Glob
    21  		ResourceType string
    22  	}
    23  
    24  	Interceptor struct {
    25  		mu      sync.RWMutex
    26  		running bool
    27  		logger  zerolog.Logger
    28  		client  *cdp.Client
    29  		filters map[string]*InterceptorFilter
    30  		loop    *events.Loop
    31  	}
    32  
    33  	InterceptorFilter struct {
    34  		resources []ResourceFilter
    35  	}
    36  
    37  	InterceptorListener func(ctx context.Context, msg *fetch.RequestPausedReply) bool
    38  )
    39  
    40  func NewInterceptorFilter(filter *Filter) (*InterceptorFilter, error) {
    41  	interFilter := new(InterceptorFilter)
    42  	interFilter.resources = make([]ResourceFilter, 0, len(filter.Patterns))
    43  
    44  	for _, pattern := range filter.Patterns {
    45  		rf := ResourceFilter{
    46  			ResourceType: pattern.Type,
    47  		}
    48  
    49  		if pattern.URL != "" {
    50  			p, err := glob.Compile(pattern.URL)
    51  
    52  			if err != nil {
    53  				return nil, err
    54  			}
    55  
    56  			rf.URL = p
    57  		}
    58  
    59  		if rf.ResourceType != "" && rf.URL != nil {
    60  			interFilter.resources = append(interFilter.resources, rf)
    61  		}
    62  	}
    63  
    64  	return interFilter, nil
    65  }
    66  
    67  func (f *InterceptorFilter) Filter(rt network.ResourceType, req network.Request) bool {
    68  	var result bool
    69  
    70  	for _, pattern := range f.resources {
    71  		if pattern.ResourceType != "" && pattern.URL != nil {
    72  			result = string(rt) == pattern.ResourceType && pattern.URL.Match(req.URL)
    73  		} else if pattern.ResourceType != "" {
    74  			result = string(rt) == pattern.ResourceType
    75  		} else if pattern.URL != nil {
    76  			result = pattern.URL.Match(req.URL)
    77  		}
    78  
    79  		if result {
    80  			break
    81  		}
    82  	}
    83  
    84  	return result
    85  }
    86  
    87  func NewInterceptor(logger zerolog.Logger, client *cdp.Client) *Interceptor {
    88  	i := new(Interceptor)
    89  	i.logger = logging.WithName(logger.With(), "network_interceptor").Logger()
    90  	i.client = client
    91  	i.filters = make(map[string]*InterceptorFilter)
    92  	i.loop = events.NewLoop(createRequestPausedStreamFactory(client))
    93  	i.loop.AddListener(requestPausedEvent, events.Always(i.filter))
    94  
    95  	return i
    96  }
    97  
    98  func (i *Interceptor) IsRunning() bool {
    99  	i.mu.Lock()
   100  	defer i.mu.Unlock()
   101  
   102  	return i.running
   103  }
   104  
   105  func (i *Interceptor) AddFilter(name string, filter *Filter) error {
   106  	i.mu.Lock()
   107  	defer i.mu.Unlock()
   108  
   109  	f, err := NewInterceptorFilter(filter)
   110  
   111  	if err != nil {
   112  		return err
   113  	}
   114  
   115  	i.filters[name] = f
   116  
   117  	return nil
   118  }
   119  
   120  func (i *Interceptor) RemoveFilter(name string) {
   121  	i.mu.Lock()
   122  	defer i.mu.Unlock()
   123  
   124  	delete(i.filters, name)
   125  }
   126  
   127  func (i *Interceptor) AddListener(listener InterceptorListener) events.ListenerID {
   128  	i.mu.Lock()
   129  	defer i.mu.Unlock()
   130  
   131  	return i.loop.AddListener(requestPausedEvent, func(ctx context.Context, message interface{}) bool {
   132  		msg, ok := message.(*fetch.RequestPausedReply)
   133  
   134  		if !ok {
   135  			return true
   136  		}
   137  
   138  		return listener(ctx, msg)
   139  	})
   140  }
   141  
   142  func (i *Interceptor) RemoveListener(id events.ListenerID) {
   143  	i.mu.Lock()
   144  	defer i.mu.Unlock()
   145  
   146  	i.loop.RemoveListener(requestPausedEvent, id)
   147  }
   148  
   149  func (i *Interceptor) Run(ctx context.Context) error {
   150  	i.mu.Lock()
   151  	defer i.mu.Unlock()
   152  
   153  	if i.running {
   154  		return nil
   155  	}
   156  
   157  	err := i.client.Fetch.Enable(ctx, fetch.NewEnableArgs())
   158  	i.running = err == nil
   159  
   160  	if err != nil {
   161  		return err
   162  	}
   163  
   164  	if err := i.loop.Run(ctx); err != nil {
   165  		if e := i.client.Fetch.Disable(ctx); e != nil {
   166  			i.logger.Err(err).Msg("failed to disable fetch")
   167  		}
   168  
   169  		i.running = false
   170  
   171  		return err
   172  	}
   173  
   174  	go func() {
   175  		<-ctx.Done()
   176  
   177  		nested, cancel := context.WithTimeout(context.Background(), drivers.DefaultWaitTimeout)
   178  		defer cancel()
   179  
   180  		i.stop(nested)
   181  	}()
   182  
   183  	return nil
   184  }
   185  
   186  func (i *Interceptor) stop(ctx context.Context) {
   187  	err := i.client.Fetch.Disable(ctx)
   188  	i.running = false
   189  
   190  	if err != nil {
   191  		i.logger.Err(err).Msg("failed to stop interceptor")
   192  	}
   193  }
   194  
   195  func (i *Interceptor) filter(ctx context.Context, message interface{}) {
   196  	i.mu.Lock()
   197  	defer i.mu.Unlock()
   198  
   199  	msg, ok := message.(*fetch.RequestPausedReply)
   200  
   201  	if !ok {
   202  		return
   203  	}
   204  
   205  	log := i.logger.With().
   206  		Str("request_id", string(msg.RequestID)).
   207  		Str("frame_id", string(msg.FrameID)).
   208  		Str("resource_type", string(msg.ResourceType)).
   209  		Str("url", msg.Request.URL).
   210  		Logger()
   211  
   212  	log.Trace().Msg("trying to block resource loading")
   213  
   214  	var reject bool
   215  
   216  	for _, filter := range i.filters {
   217  		reject = filter.Filter(msg.ResourceType, msg.Request)
   218  
   219  		if reject {
   220  			break
   221  		}
   222  	}
   223  
   224  	if !reject {
   225  		err := i.client.Fetch.ContinueRequest(ctx, fetch.NewContinueRequestArgs(msg.RequestID))
   226  
   227  		if err != nil {
   228  			i.logger.Err(err).Msg("failed to allow resource loading")
   229  		}
   230  
   231  		log.Trace().Msg("succeeded to allow resource loading")
   232  
   233  		return
   234  	}
   235  
   236  	err := i.client.Fetch.FailRequest(ctx, fetch.NewFailRequestArgs(msg.RequestID, network.ErrorReasonBlockedByClient))
   237  
   238  	if err != nil {
   239  		log.Trace().Err(err).Msg("failed to block resource loading")
   240  	}
   241  
   242  	log.Trace().Msg("succeeded to block resource loading")
   243  }