github.com/MontFerret/ferret@v0.18.0/pkg/drivers/cdp/network/manager_test.go (about) 1 package network_test 2 3 import ( 4 "context" 5 "os" 6 "testing" 7 "time" 8 9 "github.com/mafredri/cdp" 10 "github.com/mafredri/cdp/protocol/fetch" 11 network2 "github.com/mafredri/cdp/protocol/network" 12 "github.com/mafredri/cdp/protocol/page" 13 "github.com/rs/zerolog" 14 . "github.com/smartystreets/goconvey/convey" 15 "github.com/stretchr/testify/mock" 16 17 "github.com/MontFerret/ferret/pkg/drivers" 18 "github.com/MontFerret/ferret/pkg/drivers/cdp/network" 19 ) 20 21 type ( 22 PageAPI struct { 23 mock.Mock 24 cdp.Page 25 frameNavigated func(ctx context.Context) (page.FrameNavigatedClient, error) 26 } 27 28 NetworkAPI struct { 29 mock.Mock 30 cdp.Network 31 responseReceived func(ctx context.Context) (network2.ResponseReceivedClient, error) 32 setExtraHTTPHeaders func(ctx context.Context, args *network2.SetExtraHTTPHeadersArgs) error 33 } 34 35 FetchAPI struct { 36 mock.Mock 37 cdp.Fetch 38 enable func(context.Context, *fetch.EnableArgs) error 39 disable func(context.Context) error 40 requestPaused func(context.Context) (fetch.RequestPausedClient, error) 41 } 42 43 TestEventStream struct { 44 mock.Mock 45 ready chan struct{} 46 message chan interface{} 47 } 48 49 FrameNavigatedClient struct { 50 *TestEventStream 51 } 52 53 ResponseReceivedClient struct { 54 *TestEventStream 55 } 56 57 RequestPausedClient struct { 58 *TestEventStream 59 } 60 ) 61 62 func (api *PageAPI) FrameNavigated(ctx context.Context) (page.FrameNavigatedClient, error) { 63 return api.frameNavigated(ctx) 64 } 65 66 func (api *NetworkAPI) ResponseReceived(ctx context.Context) (network2.ResponseReceivedClient, error) { 67 return api.responseReceived(ctx) 68 } 69 70 func (api *NetworkAPI) SetExtraHTTPHeaders(ctx context.Context, args *network2.SetExtraHTTPHeadersArgs) error { 71 return api.setExtraHTTPHeaders(ctx, args) 72 } 73 74 func (api *FetchAPI) Enable(ctx context.Context, args *fetch.EnableArgs) error { 75 if api.enable == nil { 76 return nil 77 } 78 79 return api.enable(ctx, args) 80 } 81 82 func (api *FetchAPI) Disable(ctx context.Context) error { 83 if api.disable == nil { 84 return nil 85 } 86 87 return api.disable(ctx) 88 } 89 90 func (api *FetchAPI) RequestPaused(ctx context.Context) (fetch.RequestPausedClient, error) { 91 return api.requestPaused(ctx) 92 } 93 94 func NewTestEventStream() *TestEventStream { 95 return NewBufferedTestEventStream(0) 96 } 97 98 func NewBufferedTestEventStream(buffer int) *TestEventStream { 99 es := new(TestEventStream) 100 es.ready = make(chan struct{}, buffer) 101 es.message = make(chan interface{}, buffer) 102 return es 103 } 104 105 func (stream *TestEventStream) Ready() <-chan struct{} { 106 return stream.ready 107 } 108 109 func (stream *TestEventStream) RecvMsg(i interface{}) error { 110 return nil 111 } 112 113 func (stream *TestEventStream) Message() interface{} { 114 return <-stream.message 115 } 116 117 func (stream *TestEventStream) Close() error { 118 stream.Called() 119 close(stream.message) 120 close(stream.ready) 121 return nil 122 } 123 124 func (stream *TestEventStream) Emit(msg interface{}) { 125 stream.ready <- struct{}{} 126 stream.message <- msg 127 } 128 129 func NewFrameNavigatedClient() *FrameNavigatedClient { 130 return &FrameNavigatedClient{ 131 TestEventStream: NewTestEventStream(), 132 } 133 } 134 135 func (stream *FrameNavigatedClient) Recv() (*page.FrameNavigatedReply, error) { 136 <-stream.Ready() 137 msg := stream.Message() 138 139 repl, ok := msg.(*page.FrameNavigatedReply) 140 141 if !ok { 142 panic("Invalid message type") 143 } 144 145 return repl, nil 146 } 147 148 func NewResponseReceivedClient() *ResponseReceivedClient { 149 return &ResponseReceivedClient{ 150 TestEventStream: NewTestEventStream(), 151 } 152 } 153 154 func (stream *ResponseReceivedClient) Recv() (*network2.ResponseReceivedReply, error) { 155 <-stream.Ready() 156 msg := stream.Message() 157 158 repl, ok := msg.(*network2.ResponseReceivedReply) 159 160 if !ok { 161 panic("Invalid message type") 162 } 163 164 return repl, nil 165 } 166 167 func NewRequestPausedClient() *RequestPausedClient { 168 return &RequestPausedClient{ 169 TestEventStream: NewTestEventStream(), 170 } 171 } 172 173 func (stream *RequestPausedClient) Recv() (*fetch.RequestPausedReply, error) { 174 <-stream.Ready() 175 msg := stream.Message() 176 177 repl, ok := msg.(*fetch.RequestPausedReply) 178 179 if !ok { 180 panic("Invalid message type") 181 } 182 183 return repl, nil 184 } 185 186 func TestManager(t *testing.T) { 187 Convey("Network manager", t, func() { 188 189 Convey("New", func() { 190 Convey("Should close all resources on Close", func() { 191 responseReceivedClient := NewResponseReceivedClient() 192 responseReceivedClient.On("Close", mock.Anything).Once().Return(nil) 193 networkAPI := new(NetworkAPI) 194 networkAPI.responseReceived = func(ctx context.Context) (network2.ResponseReceivedClient, error) { 195 return responseReceivedClient, nil 196 } 197 networkAPI.setExtraHTTPHeaders = func(ctx context.Context, args *network2.SetExtraHTTPHeadersArgs) error { 198 return nil 199 } 200 201 requestPausedClient := NewRequestPausedClient() 202 requestPausedClient.On("Close", mock.Anything).Once().Return(nil) 203 fetchAPI := new(FetchAPI) 204 fetchAPI.enable = func(ctx context.Context, args *fetch.EnableArgs) error { 205 return nil 206 } 207 fetchAPI.requestPaused = func(ctx context.Context) (fetch.RequestPausedClient, error) { 208 return requestPausedClient, nil 209 } 210 211 client := &cdp.Client{ 212 Network: networkAPI, 213 Fetch: fetchAPI, 214 } 215 216 mgr, err := network.New( 217 zerolog.New(os.Stdout).Level(zerolog.Disabled), 218 client, 219 network.Options{ 220 Headers: drivers.NewHTTPHeadersWith(map[string][]string{"x-correlation-id": {"foo"}}), 221 Filter: &network.Filter{ 222 Patterns: []drivers.ResourceFilter{ 223 { 224 URL: "http://google.com", 225 Type: "img", 226 }, 227 }, 228 }, 229 }, 230 ) 231 232 So(err, ShouldBeNil) 233 So(mgr.Close(), ShouldBeNil) 234 235 time.Sleep(time.Duration(100) * time.Millisecond) 236 237 responseReceivedClient.AssertExpectations(t) 238 requestPausedClient.AssertExpectations(t) 239 }) 240 }) 241 }) 242 }