github.com/network-quality/goresponsiveness@v0.0.0-20240129151524-343954285090/lgc/download.go (about)

     1  /*
     2   * This file is part of Go Responsiveness.
     3   *
     4   * Go Responsiveness is free software: you can redistribute it and/or modify it under
     5   * the terms of the GNU General Public License as published by the Free Software Foundation,
     6   * either version 2 of the License, or (at your option) any later version.
     7   * Go Responsiveness is distributed in the hope that it will be useful, but WITHOUT ANY
     8   * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
     9   * PARTICULAR PURPOSE. See the GNU General Public License for more details.
    10   *
    11   * You should have received a copy of the GNU General Public License along
    12   * with Go Responsiveness. If not, see <https://www.gnu.org/licenses/>.
    13   */
    14  
    15  package lgc
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"fmt"
    21  	"io"
    22  	"net/http"
    23  	"net/http/httptrace"
    24  	"os"
    25  	"sync"
    26  	"sync/atomic"
    27  	"time"
    28  
    29  	"github.com/network-quality/goresponsiveness/debug"
    30  	"github.com/network-quality/goresponsiveness/l4s"
    31  	"github.com/network-quality/goresponsiveness/stats"
    32  	"github.com/network-quality/goresponsiveness/traceable"
    33  	"github.com/network-quality/goresponsiveness/utilities"
    34  )
    35  
// LoadGeneratingConnectionDownload generates download load over a single HTTP
// GET of URL. Start builds the HTTP client and launches doDownload in a
// goroutine; doDownload drains the response body, adding every byte received
// to `downloaded`, and callers harvest that count with TransferredInInterval.
//
// Field notes:
//   - downloaded: bytes received since the last TransferredInInterval call;
//     accessed with sync/atomic.
//   - lastIntervalEnd: end of the previous measurement interval, in
//     nanoseconds since downloadStartTime; accessed with sync/atomic.
//   - ConnectToAddr: handed to utilities.OverrideHostTransport when the client
//     is built (presumably the address to dial instead of URL's host — TODO
//     confirm against utilities).
//   - downloadStartTime: set by doDownload just before the request is sent.
//   - clientId: assigned by Start via utilities.GenerateUniqueId.
//   - stats: filled in by the httptrace callbacks (the Set* methods).
//   - status: lifecycle state; written under statusLock, with statusWaiter
//     broadcast on every change.
//   - congestionControl: when non-nil, applied to the new connection through
//     l4s.SetL4S once the connection is obtained.
//
// Alignment: downloaded and lastIntervalEnd are deliberately the first fields.
// On 32-bit platforms 64-bit atomic operations require 64-bit alignment, and
// only the first word of an allocated struct is guaranteed to have it (see the
// "Bugs" section of the sync/atomic documentation). Do not move other fields
// above them.
type LoadGeneratingConnectionDownload struct {
	downloaded         uint64
	lastIntervalEnd    int64
	ConnectToAddr      string
	URL                string
	downloadStartTime  time.Time
	client             *http.Client
	debug              debug.DebugLevel
	InsecureSkipVerify bool
	KeyLogger          io.Writer
	clientId           uint64
	tracer             *httptrace.ClientTrace
	stats              stats.TraceStats
	status             LgcStatus
	congestionControl  *string
	statusLock         *sync.Mutex
	statusWaiter       *sync.Cond
}
    56  
    57  func NewLoadGeneratingConnectionDownload(url string, keyLogger io.Writer, connectToAddr string, insecureSkipVerify bool, congestionControl *string) LoadGeneratingConnectionDownload {
    58  	lgd := LoadGeneratingConnectionDownload{
    59  		URL:                url,
    60  		KeyLogger:          keyLogger,
    61  		ConnectToAddr:      connectToAddr,
    62  		InsecureSkipVerify: insecureSkipVerify,
    63  		congestionControl:  congestionControl,
    64  		statusLock:         &sync.Mutex{},
    65  	}
    66  	lgd.statusWaiter = sync.NewCond(lgd.statusLock)
    67  	return lgd
    68  }
    69  
    70  func (lgd *LoadGeneratingConnectionDownload) Direction() LgcDirection {
    71  	return LGC_DOWN
    72  }
    73  
    74  func (lgd *LoadGeneratingConnectionDownload) WaitUntilStarted(ctxt context.Context) bool {
    75  	conditional := func() bool { return lgd.status != LGC_STATUS_NOT_STARTED }
    76  	go utilities.ContextSignaler(ctxt, 500*time.Millisecond, &conditional, lgd.statusWaiter)
    77  	return utilities.WaitWithContext(ctxt, &conditional, lgd.statusLock, lgd.statusWaiter)
    78  }
    79  
    80  func (lgd *LoadGeneratingConnectionDownload) SetDnsStartTimeInfo(
    81  	now time.Time,
    82  	dnsStartInfo httptrace.DNSStartInfo,
    83  ) {
    84  	lgd.stats.DnsStartTime = now
    85  	lgd.stats.DnsStart = dnsStartInfo
    86  	if debug.IsDebug(lgd.debug) {
    87  		fmt.Printf(
    88  			"DNS Start for %v: %v\n",
    89  			lgd.ClientId(),
    90  			dnsStartInfo,
    91  		)
    92  	}
    93  }
    94  
    95  func (lgd *LoadGeneratingConnectionDownload) SetDnsDoneTimeInfo(
    96  	now time.Time,
    97  	dnsDoneInfo httptrace.DNSDoneInfo,
    98  ) {
    99  	lgd.stats.DnsDoneTime = now
   100  	lgd.stats.DnsDone = dnsDoneInfo
   101  	if debug.IsDebug(lgd.debug) {
   102  		fmt.Printf(
   103  			"DNS Done for %v: %v\n",
   104  			lgd.ClientId(),
   105  			lgd.stats.DnsDone,
   106  		)
   107  	}
   108  }
   109  
   110  func (lgd *LoadGeneratingConnectionDownload) SetConnectStartTime(
   111  	now time.Time,
   112  ) {
   113  	lgd.stats.ConnectStartTime = now
   114  	if debug.IsDebug(lgd.debug) {
   115  		fmt.Printf(
   116  			"TCP Start for %v at %v\n",
   117  			lgd.ClientId(),
   118  			lgd.stats.ConnectStartTime,
   119  		)
   120  	}
   121  }
   122  
   123  func (lgd *LoadGeneratingConnectionDownload) SetConnectDoneTimeError(
   124  	now time.Time,
   125  	err error,
   126  ) {
   127  	lgd.stats.ConnectDoneTime = now
   128  	lgd.stats.ConnectDoneError = err
   129  	if debug.IsDebug(lgd.debug) {
   130  		fmt.Printf(
   131  			"TCP Done for %v (with error %v) @ %v\n",
   132  			lgd.ClientId(),
   133  			lgd.stats.ConnectDoneError,
   134  			lgd.stats.ConnectDoneTime,
   135  		)
   136  	}
   137  }
   138  
   139  func (lgd *LoadGeneratingConnectionDownload) SetGetConnTime(now time.Time) {
   140  	lgd.stats.GetConnectionStartTime = now
   141  	if debug.IsDebug(lgd.debug) {
   142  		fmt.Printf(
   143  			"Started getting connection for %v @ %v\n",
   144  			lgd.ClientId(),
   145  			lgd.stats.GetConnectionStartTime,
   146  		)
   147  	}
   148  }
   149  
   150  func (lgd *LoadGeneratingConnectionDownload) SetGotConnTimeInfo(
   151  	now time.Time,
   152  	gotConnInfo httptrace.GotConnInfo,
   153  ) {
   154  	if gotConnInfo.Reused {
   155  		fmt.Printf("Unexpectedly reusing a connection!\n")
   156  		panic(!gotConnInfo.Reused)
   157  	}
   158  	lgd.stats.GetConnectionDoneTime = now
   159  	lgd.stats.ConnInfo = gotConnInfo
   160  	if debug.IsDebug(lgd.debug) {
   161  		fmt.Printf(
   162  			"Got connection for %v at %v with info %v\n",
   163  			lgd.ClientId(),
   164  			lgd.stats.GetConnectionDoneTime,
   165  			lgd.stats.ConnInfo,
   166  		)
   167  	}
   168  
   169  	if lgd.congestionControl != nil {
   170  		if debug.IsDebug(lgd.debug) {
   171  			fmt.Printf(
   172  				"Attempting to set congestion control algorithm to %v for connection %v at %v with info %v\n",
   173  				*lgd.congestionControl,
   174  				lgd.ClientId(),
   175  				lgd.stats.GetConnectionDoneTime,
   176  				lgd.stats.ConnInfo,
   177  			)
   178  		}
   179  		if err := l4s.SetL4S(lgd.stats.ConnInfo.Conn, lgd.congestionControl); err != nil {
   180  			fmt.Fprintf(
   181  				os.Stderr,
   182  				"Error setting L4S for %v at %v: %v\n",
   183  				lgd.ClientId(),
   184  				lgd.stats.GetConnectionDoneTime,
   185  				err.Error(),
   186  			)
   187  		}
   188  	}
   189  }
   190  
   191  func (lgd *LoadGeneratingConnectionDownload) SetTLSHandshakeStartTime(
   192  	now time.Time,
   193  ) {
   194  	lgd.stats.TLSStartTime = utilities.Some(now)
   195  	if debug.IsDebug(lgd.debug) {
   196  		fmt.Printf(
   197  			"Started TLS Handshake for %v @ %v\n",
   198  			lgd.ClientId(),
   199  			lgd.stats.TLSStartTime,
   200  		)
   201  	}
   202  }
   203  
   204  func (lgd *LoadGeneratingConnectionDownload) SetTLSHandshakeDoneTimeState(
   205  	now time.Time,
   206  	connectionState tls.ConnectionState,
   207  ) {
   208  	lgd.stats.TLSDoneTime = utilities.Some(now)
   209  	lgd.stats.TLSConnInfo = connectionState
   210  	if debug.IsDebug(lgd.debug) {
   211  		fmt.Printf(
   212  			"Completed TLS handshake for %v at %v with info %v\n",
   213  			lgd.ClientId(),
   214  			lgd.stats.TLSDoneTime,
   215  			lgd.stats.TLSConnInfo,
   216  		)
   217  	}
   218  }
   219  
   220  func (lgd *LoadGeneratingConnectionDownload) SetHttpWroteRequestTimeInfo(
   221  	now time.Time,
   222  	info httptrace.WroteRequestInfo,
   223  ) {
   224  	lgd.stats.HttpWroteRequestTime = now
   225  	lgd.stats.HttpInfo = info
   226  	if debug.IsDebug(lgd.debug) {
   227  		fmt.Printf(
   228  			"(lgd) Http finished writing request for %v at %v with info %v\n",
   229  			lgd.ClientId(),
   230  			lgd.stats.HttpWroteRequestTime,
   231  			lgd.stats.HttpInfo,
   232  		)
   233  	}
   234  }
   235  
   236  func (lgd *LoadGeneratingConnectionDownload) SetHttpResponseReadyTime(
   237  	now time.Time,
   238  ) {
   239  	lgd.stats.HttpResponseReadyTime = now
   240  	if debug.IsDebug(lgd.debug) {
   241  		fmt.Printf(
   242  			"Got the first byte of HTTP response headers for %v at %v\n",
   243  			lgd.ClientId(),
   244  			lgd.stats.HttpResponseReadyTime,
   245  		)
   246  	}
   247  }
   248  
   249  func (lgd *LoadGeneratingConnectionDownload) ClientId() uint64 {
   250  	return lgd.clientId
   251  }
   252  
   253  func (lgd *LoadGeneratingConnectionDownload) TransferredInInterval() (uint64, time.Duration) {
   254  	transferred := atomic.SwapUint64(&lgd.downloaded, 0)
   255  	newIntervalEnd := (time.Now().Sub(lgd.downloadStartTime)).Nanoseconds()
   256  	previousIntervalEnd := atomic.SwapInt64(&lgd.lastIntervalEnd, newIntervalEnd)
   257  	intervalLength := time.Duration(newIntervalEnd - previousIntervalEnd)
   258  	if debug.IsDebug(lgd.debug) {
   259  		fmt.Printf("download: Transferred: %v bytes in %v.\n", transferred, intervalLength)
   260  	}
   261  	return transferred, intervalLength
   262  }
   263  
   264  func (lgd *LoadGeneratingConnectionDownload) Client() *http.Client {
   265  	return lgd.client
   266  }
   267  
// loadGeneratingConnectionDownloadReader is the io.Reader that doDownload
// drains with io.Copy. It wraps the HTTP response body and counts the bytes
// that pass through it.
//
// Fields:
//   - n: byte counter incremented with atomic.AddUint64 on every Read
//     (doDownload points it at lgd.downloaded).
//   - ctx: checked before every Read; once it is done, Read returns io.EOF so
//     that draining stops.
//   - readable: the underlying source (the response body in doDownload).
//   - lgd: the owning connection, whose status Read sets to
//     LGC_STATUS_RUNNING.
type loadGeneratingConnectionDownloadReader struct {
	n        *uint64
	ctx      context.Context
	readable io.Reader
	lgd      *LoadGeneratingConnectionDownload
}
   274  
   275  func (cr *loadGeneratingConnectionDownloadReader) Read(p []byte) (n int, err error) {
   276  	if cr.ctx.Err() != nil {
   277  		return 0, io.EOF
   278  	}
   279  
   280  	if *cr.n == 0 {
   281  		cr.lgd.statusLock.Lock()
   282  		cr.lgd.status = LGC_STATUS_RUNNING
   283  		cr.lgd.statusWaiter.Broadcast()
   284  		cr.lgd.statusLock.Unlock()
   285  	}
   286  
   287  	n, err = cr.readable.Read(p)
   288  	atomic.AddUint64(cr.n, uint64(n))
   289  	return
   290  }
   291  
   292  func (lgd *LoadGeneratingConnectionDownload) Start(
   293  	parentCtx context.Context,
   294  	debugLevel debug.DebugLevel,
   295  ) bool {
   296  	lgd.downloaded = 0
   297  	lgd.debug = debugLevel
   298  	lgd.clientId = utilities.GenerateUniqueId()
   299  
   300  	transport := &http.Transport{
   301  		Proxy: http.ProxyFromEnvironment,
   302  		TLSClientConfig: &tls.Config{
   303  			InsecureSkipVerify: lgd.InsecureSkipVerify,
   304  		},
   305  	}
   306  
   307  	if !utilities.IsInterfaceNil(lgd.KeyLogger) {
   308  		if debug.IsDebug(lgd.debug) {
   309  			fmt.Printf(
   310  				"Using an SSL Key Logger for this load-generating download.\n",
   311  			)
   312  		}
   313  
   314  		// The presence of a custom TLSClientConfig in a *generic* `transport`
   315  		// means that go will default to HTTP/1.1 and cowardly avoid HTTP/2:
   316  		// https://github.com/golang/go/blob/7ca6902c171b336d98adbb103d701a013229c806/src/net/http/transport.go#L278
   317  		// Also, it would appear that the API's choice of HTTP vs HTTP2 can
   318  		// depend on whether the url contains
   319  		// https:// or http://:
   320  		// https://github.com/golang/go/blob/7ca6902c171b336d98adbb103d701a013229c806/src/net/http/transport.go#L74
   321  		transport.TLSClientConfig.KeyLogWriter = lgd.KeyLogger
   322  	}
   323  	transport.TLSClientConfig.InsecureSkipVerify = lgd.InsecureSkipVerify
   324  
   325  	utilities.OverrideHostTransport(transport, lgd.ConnectToAddr)
   326  
   327  	lgd.client = &http.Client{Transport: transport}
   328  	lgd.tracer = traceable.GenerateHttpTimingTracer(lgd, lgd.debug)
   329  
   330  	if debug.IsDebug(lgd.debug) {
   331  		fmt.Printf(
   332  			"Started a load-generating download (id: %v).\n",
   333  			lgd.clientId,
   334  		)
   335  	}
   336  
   337  	go lgd.doDownload(parentCtx)
   338  	return true
   339  }
   340  
   341  func (lgd *LoadGeneratingConnectionDownload) Status() LgcStatus {
   342  	return lgd.status
   343  }
   344  
   345  func (lgd *LoadGeneratingConnectionDownload) Stats() *stats.TraceStats {
   346  	return &lgd.stats
   347  }
   348  
   349  func (lgd *LoadGeneratingConnectionDownload) doDownload(ctx context.Context) error {
   350  	var request *http.Request = nil
   351  	var get *http.Response = nil
   352  	var err error = nil
   353  
   354  	if request, err = http.NewRequestWithContext(
   355  		httptrace.WithClientTrace(ctx, lgd.tracer),
   356  		"GET",
   357  		lgd.URL,
   358  		nil,
   359  	); err != nil {
   360  		lgd.statusLock.Lock()
   361  		lgd.status = LGC_STATUS_ERROR
   362  		lgd.statusWaiter.Broadcast()
   363  		lgd.statusLock.Unlock()
   364  		return err
   365  	}
   366  
   367  	// Used to disable compression
   368  	request.Header.Set("Accept-Encoding", "identity")
   369  	request.Header.Set("User-Agent", utilities.UserAgent())
   370  
   371  	lgd.downloadStartTime = time.Now()
   372  	lgd.lastIntervalEnd = 0
   373  
   374  	if get, err = lgd.client.Do(request); err != nil {
   375  		lgd.statusLock.Lock()
   376  		lgd.status = LGC_STATUS_ERROR
   377  		lgd.statusWaiter.Broadcast()
   378  		lgd.statusLock.Unlock()
   379  		return err
   380  	}
   381  
   382  	// Header.Get returns "" when not set
   383  	if get.Header.Get("Content-Encoding") != "" {
   384  		lgd.statusLock.Lock()
   385  		lgd.status = LGC_STATUS_ERROR
   386  		lgd.statusWaiter.Broadcast()
   387  		lgd.statusLock.Unlock()
   388  		fmt.Printf("Content-Encoding header was set (compression not allowed)")
   389  		return fmt.Errorf("Content-Encoding header was set (compression not allowed)")
   390  	}
   391  	cr := &loadGeneratingConnectionDownloadReader{n: &lgd.downloaded, ctx: ctx, lgd: lgd, readable: get.Body}
   392  	_, _ = io.Copy(io.Discard, cr)
   393  
   394  	lgd.statusLock.Lock()
   395  	lgd.status = LGC_STATUS_DONE
   396  	lgd.statusWaiter.Broadcast()
   397  	lgd.statusLock.Unlock()
   398  
   399  	get.Body.Close()
   400  	if debug.IsDebug(lgd.debug) {
   401  		fmt.Printf("Ending a load-generating download.\n")
   402  	}
   403  
   404  	return nil
   405  }