github.com/google/martian/v3@v3.3.3/trafficshape/conn.go (about)

     1  // Copyright 2015 Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package trafficshape
    16  
    17  import (
    18  	"io"
    19  	"net"
    20  	"sort"
    21  	"sync"
    22  	"time"
    23  
    24  	"github.com/google/martian/v3/log"
    25  )
    26  
    27  // Conn wraps a net.Conn and simulates connection latency and bandwidth
    28  // charateristics.
    29  type Conn struct {
    30  	Context *Context
    31  
    32  	// Shapes represents the traffic shape map inherited from the listener.
    33  	Shapes        *urlShapes
    34  	GlobalBuckets map[string]*Bucket
    35  	// LocalBuckets represents a map from the url_regexes to their dedicated buckets.
    36  	LocalBuckets map[string]*Buckets
    37  	Established  time.Time
    38  	// Established is the time that the connection is established.
    39  	DefaultBandwidth Bandwidth
    40  	Listener         *Listener
    41  	ReadBucket       *Bucket // Shared by listener.
    42  	WriteBucket      *Bucket // Shared by listener.
    43  
    44  	conn    net.Conn
    45  	latency time.Duration
    46  	ronce   sync.Once
    47  	wonce   sync.Once
    48  }
    49  
    50  // Read reads bytes from connection into b, optionally simulating connection
    51  // latency and throttling read throughput based on desired bandwidth
    52  // constraints.
    53  func (c *Conn) Read(b []byte) (int, error) {
    54  	c.ronce.Do(c.sleepLatency)
    55  
    56  	n, err := c.ReadBucket.FillThrottle(func(remaining int64) (int64, error) {
    57  		max := remaining
    58  		if l := int64(len(b)); max > l {
    59  			max = l
    60  		}
    61  
    62  		n, err := c.conn.Read(b[:max])
    63  		return int64(n), err
    64  	})
    65  	if err != nil && err != io.EOF {
    66  		log.Errorf("trafficshape: error on throttled read: %v", err)
    67  	}
    68  
    69  	return int(n), err
    70  }
    71  
    72  // ReadFrom reads data from r until EOF or error, optionally simulating
    73  // connection latency and throttling read throughput based on desired bandwidth
    74  // constraints.
    75  func (c *Conn) ReadFrom(r io.Reader) (int64, error) {
    76  	c.ronce.Do(c.sleepLatency)
    77  
    78  	var total int64
    79  	for {
    80  		n, err := c.ReadBucket.FillThrottle(func(remaining int64) (int64, error) {
    81  			return io.CopyN(c.conn, r, remaining)
    82  		})
    83  
    84  		total += n
    85  
    86  		if err == io.EOF {
    87  			log.Debugf("trafficshape: exhausted reader successfully")
    88  			return total, nil
    89  		} else if err != nil {
    90  			log.Errorf("trafficshape: failed copying from reader: %v", err)
    91  			return total, err
    92  		}
    93  	}
    94  }
    95  
    96  // Close closes the connection.
    97  // Any blocked Read or Write operations will be unblocked and return errors.
    98  func (c *Conn) Close() error {
    99  	return c.conn.Close()
   100  }
   101  
   102  // LocalAddr returns the local network address.
   103  func (c *Conn) LocalAddr() net.Addr {
   104  	return c.conn.LocalAddr()
   105  }
   106  
   107  // RemoteAddr returns the remote network address.
   108  func (c *Conn) RemoteAddr() net.Addr {
   109  	return c.conn.RemoteAddr()
   110  }
   111  
   112  // SetDeadline sets the read and write deadlines associated
   113  // with the connection. It is equivalent to calling both
   114  // SetReadDeadline and SetWriteDeadline.
   115  //
   116  // A deadline is an absolute time after which I/O operations
   117  // fail with a timeout (see type Error) instead of
   118  // blocking. The deadline applies to all future and pending
   119  // I/O, not just the immediately following call to Read or
   120  // Write. After a deadline has been exceeded, the connection
   121  // can be refreshed by setting a deadline in the future.
   122  //
   123  // An idle timeout can be implemented by repeatedly extending
   124  // the deadline after successful Read or Write calls.
   125  //
   126  // A zero value for t means I/O operations will not time out.
   127  //
   128  // Note that if a TCP connection has keep-alive turned on,
   129  // which is the default unless overridden by Dialer.KeepAlive
   130  // or ListenConfig.KeepAlive, then a keep-alive failure may
   131  // also return a timeout error. On Unix systems a keep-alive
   132  // failure on I/O can be detected using
   133  // errors.Is(err, syscall.ETIMEDOUT).
   134  func (c *Conn) SetDeadline(t time.Time) error {
   135  	return c.conn.SetDeadline(t)
   136  }
   137  
   138  // SetReadDeadline sets the deadline for future Read calls
   139  // and any currently-blocked Read call.
   140  // A zero value for t means Read will not time out.
   141  func (c *Conn) SetReadDeadline(t time.Time) error {
   142  	return c.conn.SetReadDeadline(t)
   143  }
   144  
   145  // SetWriteDeadline sets the deadline for future Write calls
   146  // and any currently-blocked Write call.
   147  // Even if write times out, it may return n > 0, indicating that
   148  // some of the data was successfully written.
   149  // A zero value for t means Write will not time out.
   150  func (c *Conn) SetWriteDeadline(t time.Time) error {
   151  	return c.conn.SetWriteDeadline(t)
   152  }
   153  
   154  // GetWrappedConn returns the undrelying trafficshaped net.Conn.
   155  func (c *Conn) GetWrappedConn() net.Conn {
   156  	return c.conn
   157  }
   158  
   159  // WriteTo writes data to w from the connection, optionally simulating
   160  // connection latency and throttling write throughput based on desired
   161  // bandwidth constraints.
   162  func (c *Conn) WriteTo(w io.Writer) (int64, error) {
   163  	c.wonce.Do(c.sleepLatency)
   164  
   165  	var total int64
   166  	for {
   167  		n, err := c.WriteBucket.FillThrottle(func(remaining int64) (int64, error) {
   168  			return io.CopyN(w, c.conn, remaining)
   169  		})
   170  
   171  		total += n
   172  
   173  		if err != nil {
   174  			if err != io.EOF {
   175  				log.Errorf("trafficshape: failed copying to writer: %v", err)
   176  			}
   177  			return total, err
   178  		}
   179  	}
   180  }
   181  
   182  func min(x, y int64) int64 {
   183  	if x < y {
   184  		return x
   185  	}
   186  	return y
   187  }
   188  
   189  // CheckExistenceAndValidity checks that the current url regex is present in the map, and that
   190  // the connection was established before the url shape map was last updated. We do not allow the
   191  // updated url shape map to traffic shape older connections.
   192  // Important: Assumes you have acquired the required locks and will release them youself.
   193  func (c *Conn) CheckExistenceAndValidity(URLRegex string) bool {
   194  	shapeStillValid := c.Shapes.LastModifiedTime.Before(c.Established)
   195  	_, p := c.Shapes.M[URLRegex]
   196  	return p && shapeStillValid
   197  }
   198  
   199  // GetCurrentThrottle uses binary search to determine if the current byte offset ('start')
   200  // lies within a throttle interval. If so, also returns the bandwidth specified for that interval.
   201  func (c *Conn) GetCurrentThrottle(start int64) *ThrottleContext {
   202  	c.Shapes.RLock()
   203  	defer c.Shapes.RUnlock()
   204  
   205  	if !c.CheckExistenceAndValidity(c.Context.URLRegex) {
   206  		log.Debugf("existence check failed")
   207  		return &ThrottleContext{
   208  			ThrottleNow: false,
   209  		}
   210  	}
   211  
   212  	c.Shapes.M[c.Context.URLRegex].RLock()
   213  	defer c.Shapes.M[c.Context.URLRegex].RUnlock()
   214  
   215  	throttles := c.Shapes.M[c.Context.URLRegex].Shape.Throttles
   216  
   217  	if l := len(throttles); l != 0 {
   218  		// ind is the first index in throttles with ByteStart > start.
   219  		// Once we get ind, we can check the previous throttle, if any,
   220  		// to see if its ByteEnd is after 'start'.
   221  		ind := sort.Search(len(throttles),
   222  			func(i int) bool { return throttles[i].ByteStart > start })
   223  
   224  		// All throttles have Bytestart > start, hence not in throttle.
   225  		if ind == 0 {
   226  			return &ThrottleContext{
   227  				ThrottleNow: false,
   228  			}
   229  		}
   230  
   231  		// No throttle has Bytestart > start, so check the last throttle to
   232  		// see if it ends after 'start'. Note: the last throttle is special
   233  		// since it can have -1 (meaning infinity) as the ByteEnd.
   234  		if ind == l {
   235  			if throttles[l-1].ByteEnd > start || throttles[l-1].ByteEnd == -1 {
   236  				return &ThrottleContext{
   237  					ThrottleNow: true,
   238  					Bandwidth:   throttles[l-1].Bandwidth,
   239  				}
   240  			}
   241  			return &ThrottleContext{
   242  				ThrottleNow: false,
   243  			}
   244  		}
   245  
   246  		// Check the previous throttle to see if it ends after 'start'.
   247  		if throttles[ind-1].ByteEnd > start {
   248  			return &ThrottleContext{
   249  				ThrottleNow: true,
   250  				Bandwidth:   throttles[ind-1].Bandwidth,
   251  			}
   252  		}
   253  
   254  		return &ThrottleContext{
   255  			ThrottleNow: false,
   256  		}
   257  	}
   258  
   259  	return &ThrottleContext{
   260  		ThrottleNow: false,
   261  	}
   262  }
   263  
   264  // GetNextActionFromByte takes in a byte offset and uses binary search to determine the upcoming
   265  // action, i.e the first action after the byte that still has a non zero count.
   266  func (c *Conn) GetNextActionFromByte(start int64) *NextActionInfo {
   267  	c.Shapes.RLock()
   268  	defer c.Shapes.RUnlock()
   269  
   270  	if !c.CheckExistenceAndValidity(c.Context.URLRegex) {
   271  		log.Debugf("existence check failed")
   272  		return &NextActionInfo{
   273  			ActionNext: false,
   274  		}
   275  	}
   276  
   277  	c.Shapes.M[c.Context.URLRegex].RLock()
   278  	defer c.Shapes.M[c.Context.URLRegex].RUnlock()
   279  
   280  	actions := c.Shapes.M[c.Context.URLRegex].Shape.Actions
   281  
   282  	if l := len(actions); l != 0 {
   283  		ind := sort.Search(len(actions),
   284  			func(i int) bool { return actions[i].getByte() >= start })
   285  
   286  		return c.GetNextActionFromIndex(int64(ind))
   287  	}
   288  
   289  	return &NextActionInfo{
   290  		ActionNext: false,
   291  	}
   292  }
   293  
   294  // GetNextActionFromIndex takes in an index and returns the first action after the index that
   295  // has a non zero count, if there is one.
   296  func (c *Conn) GetNextActionFromIndex(ind int64) *NextActionInfo {
   297  	c.Shapes.RLock()
   298  	defer c.Shapes.RUnlock()
   299  
   300  	if !c.CheckExistenceAndValidity(c.Context.URLRegex) {
   301  		return &NextActionInfo{
   302  			ActionNext: false,
   303  		}
   304  	}
   305  
   306  	c.Shapes.M[c.Context.URLRegex].RLock()
   307  	defer c.Shapes.M[c.Context.URLRegex].RUnlock()
   308  
   309  	actions := c.Shapes.M[c.Context.URLRegex].Shape.Actions
   310  
   311  	if l := int64(len(actions)); l != 0 {
   312  
   313  		for ind < l && (actions[ind].getCount() == 0) {
   314  			ind++
   315  		}
   316  
   317  		if ind >= l {
   318  			return &NextActionInfo{
   319  				ActionNext: false,
   320  			}
   321  		}
   322  		return &NextActionInfo{
   323  			ActionNext: true,
   324  			Index:      ind,
   325  			ByteOffset: actions[ind].getByte(),
   326  		}
   327  	}
   328  	return &NextActionInfo{
   329  		ActionNext: false,
   330  	}
   331  }
   332  
   333  // WriteDefaultBuckets writes bytes from b to the connection, optionally simulating
   334  // connection latency and throttling write throughput based on desired
   335  // bandwidth constraints. It uses the WriteBucket inherited from the listener.
   336  func (c *Conn) WriteDefaultBuckets(b []byte) (int, error) {
   337  	c.wonce.Do(c.sleepLatency)
   338  
   339  	var total int64
   340  	for len(b) > 0 {
   341  		var max int64
   342  
   343  		n, err := c.WriteBucket.FillThrottle(func(remaining int64) (int64, error) {
   344  			max = remaining
   345  			if l := int64(len(b)); remaining >= l {
   346  				max = l
   347  			}
   348  
   349  			n, err := c.conn.Write(b[:max])
   350  			return int64(n), err
   351  		})
   352  
   353  		total += n
   354  
   355  		if err != nil {
   356  			if err != io.EOF {
   357  				log.Errorf("trafficshape: failed write: %v", err)
   358  			}
   359  			return int(total), err
   360  		}
   361  
   362  		b = b[max:]
   363  	}
   364  
   365  	return int(total), nil
   366  }
   367  
// Write writes bytes from b to the connection, while enforcing throttles and performing actions.
// It uses and updates the Context in the connection.
//
// When c.Context.Shaping is false, it delegates entirely to WriteDefaultBuckets.
// Otherwise it proceeds in four steps:
//  1. It sleeps for the simulated latency once on the write path (via wonce).
//  2. It writes any outstanding header bytes (HeaderLen - HeaderBytesWritten)
//     straight to the underlying connection. This write is not throttled and
//     does not advance ByteOffset.
//  3. It writes the rest of b in chunks that never cross the byte offset of the
//     next pending action. Each chunk is charged against both the local
//     c.Context.Buckets.WriteBucket and c.Context.GlobalBucket.
//  4. When ByteOffset reaches the next action's offset, it re-validates the
//     shape map, decrements the action's count and performs the action, then
//     looks up the following action:
//     - Halt sleeps for Duration milliseconds.
//     - CloseConnection returns *ErrForceClose.
//     - ChangeBandwidth sets the local write bucket's capacity.
//
// If the shape map is missing or stale at an action boundary, shaping is turned
// off for this connection. The remainder of b, and all later writes, then go
// through WriteDefaultBuckets.
//
// Locking: c.Shapes is read-locked, and the per-URL shape write-locked, only
// around the action bookkeeping. Both locks are released before sleeping,
// changing bandwidth, returning, or calling GetNextActionFromIndex (which
// acquires the same locks itself).
func (c *Conn) Write(b []byte) (int, error) {
	if !c.Context.Shaping {
		return c.WriteDefaultBuckets(b)
	}
	c.wonce.Do(c.sleepLatency)
	var total int64

	// Write the header if needed, without enforcing any traffic shaping, and without updating
	// ByteOffset.
	if headerToWrite := c.Context.HeaderLen - c.Context.HeaderBytesWritten; headerToWrite > 0 {
		writeAmount := min(int64(len(b)), headerToWrite)

		n, err := c.conn.Write(b[:writeAmount])

		if err != nil {
			if err != io.EOF {
				log.Errorf("trafficshape: failed write: %v", err)
			}
			// total is still 0 here, so n alone is the count written so far.
			return int(n), err
		}
		// On a nil error, io.Writer guarantees n == writeAmount.
		c.Context.HeaderBytesWritten += writeAmount
		total += writeAmount
		b = b[writeAmount:]
	}

	var amountToWrite int64

	for len(b) > 0 {
		// max is the number of bytes of b handed to c.conn.Write in this
		// iteration. It is narrowed first by the local bucket's allowance and
		// then by the global bucket's.
		var max int64

		// Determine the amount to be written up till the next action.
		amountToWrite = int64(len(b))
		if c.Context.NextActionInfo.ActionNext {
			amountTillNextAction := c.Context.NextActionInfo.ByteOffset - c.Context.ByteOffset
			if amountTillNextAction <= amountToWrite {
				amountToWrite = amountTillNextAction
			}
		}
		// NOTE(review): amountToWrite is 0 when an action sits exactly at the
		// current offset; the max == 0 early return below handles that case.
		// A negative value would make b[:max] panic. That relies on ByteOffset
		// never passing the next action's offset, which holds because chunks
		// are capped at amountTillNextAction.

		// Write into both the local and global buckets, as well as the underlying connection.
		n, err := c.Context.Buckets.WriteBucket.FillThrottleLocked(func(remaining int64) (int64, error) {
			max = min(remaining, amountToWrite)

			if max == 0 {
				// Nothing to write before the pending action; the action
				// check below runs it without consuming any bytes.
				return 0, nil
			}

			return c.Context.GlobalBucket.FillThrottleLocked(func(rem int64) (int64, error) {
				max = min(rem, max)
				n, err := c.conn.Write(b[:max])

				return int64(n), err
			})
		})

		if err != nil {
			if err != io.EOF {
				log.Errorf("trafficshape: failed write: %v", err)
			}
			return int(total), err
		}

		// Update the current byte offset.
		c.Context.ByteOffset += n
		total += n

		// A successful write consumed all max bytes handed to it.
		b = b[max:]

		// Check if there was an upcoming action, and that the byte offset matches the action's byte.
		if c.Context.NextActionInfo.ActionNext &&
			c.Context.ByteOffset >= c.Context.NextActionInfo.ByteOffset {
			// Note here, we check again that the url shape map is still valid and that the action still has
			// a non zero count, since that could have been modified since the last time we checked.
			ind := c.Context.NextActionInfo.Index
			c.Shapes.RLock()
			if !c.CheckExistenceAndValidity(c.Context.URLRegex) {
				c.Shapes.RUnlock()
				// Write the remaining b using default buckets, and set Shaping as false
				// so that subsequent calls to Write() also use default buckets
				// without performing any actions.
				c.Context.Shaping = false
				writeTotal, e := c.WriteDefaultBuckets(b)
				return int(total) + writeTotal, e
			}
			// Exclusive lock on the per-URL shape: decrementCount below mutates
			// state that other connections share.
			c.Shapes.M[c.Context.URLRegex].Lock()
			actions := c.Shapes.M[c.Context.URLRegex].Shape.Actions
			if actions[ind].getCount() != 0 {
				// Update the action count, determine the type of action and perform it.
				// Every case releases both locks before doing anything slow.
				actions[ind].decrementCount()
				switch action := actions[ind].(type) {
				case *Halt:
					d := action.Duration
					log.Debugf("trafficshape: Sleeping for time %d ms for urlregex %s at byte offset %d",
						d, c.Context.URLRegex, c.Context.ByteOffset)
					c.Shapes.M[c.Context.URLRegex].Unlock()
					c.Shapes.RUnlock()
					time.Sleep(time.Duration(d) * time.Millisecond)
				case *CloseConnection:
					log.Infof("trafficshape: Closing connection for urlregex %s at byte offset %d",
						c.Context.URLRegex, c.Context.ByteOffset)
					c.Shapes.M[c.Context.URLRegex].Unlock()
					c.Shapes.RUnlock()
					return int(total), &ErrForceClose{message: "Forcing close connection"}
				case *ChangeBandwidth:
					bw := action.Bandwidth
					log.Infof("trafficshape: Changing connection bandwidth to %d for urlregex %s at byte offset %d",
						bw, c.Context.URLRegex, c.Context.ByteOffset)
					c.Shapes.M[c.Context.URLRegex].Unlock()
					c.Shapes.RUnlock()
					c.Context.Buckets.WriteBucket.SetCapacity(bw)
				default:
					c.Shapes.M[c.Context.URLRegex].Unlock()
					c.Shapes.RUnlock()
				}
			} else {
				// The action was exhausted since it was scheduled; skip it.
				c.Shapes.M[c.Context.URLRegex].Unlock()
				c.Shapes.RUnlock()
			}
			// Get the next action to be performed, if any.
			// Both locks are released at this point, so this call can
			// re-acquire them.
			c.Context.NextActionInfo = c.GetNextActionFromIndex(ind + 1)
		}
	}
	return int(total), nil
}
   494  
   495  func (c *Conn) sleepLatency() {
   496  	log.Debugf("trafficshape: simulating latency: %s", c.latency)
   497  	time.Sleep(c.latency)
   498  }