github.com/mholt/caddy-l4@v0.0.0-20241104153248-ec8fae209322/layer4/connection.go (about)

     1  // Copyright 2020 Matthew Holt
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package layer4
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"net"
    21  	"sync"
    22  	"time"
    23  
    24  	"github.com/caddyserver/caddy/v2"
    25  	"go.uber.org/zap"
    26  )
    27  
    28  // WrapConnection wraps an underlying connection into a layer4 connection that
    29  // supports recording and rewinding, as well as adding context with a replacer
    30  // and variable table. This function is intended for use at the start of a
    31  // connection handler chain where the underlying connection is not yet a layer4
    32  // Connection value.
    33  func WrapConnection(underlying net.Conn, buf []byte, logger *zap.Logger) *Connection {
    34  	repl := caddy.NewReplacer()
    35  	repl.Set("l4.conn.remote_addr", underlying.RemoteAddr())
    36  	repl.Set("l4.conn.local_addr", underlying.LocalAddr())
    37  	repl.Set("l4.conn.wrap_time", time.Now().UTC())
    38  
    39  	ctx := context.Background()
    40  	ctx = context.WithValue(ctx, VarsCtxKey, make(map[string]interface{}))
    41  	ctx = context.WithValue(ctx, ReplacerCtxKey, repl)
    42  
    43  	return &Connection{
    44  		Conn:    underlying,
    45  		Context: ctx,
    46  		Logger:  logger,
    47  		buf:     buf,
    48  	}
    49  }
    50  
    51  // Connection contains information about the connection as it
    52  // passes through various handlers. It also has the capability
    53  // of recording and rewinding when necessary.
    54  //
    55  // A Connection can be used as a net.Conn because it embeds a
    56  // net.Conn; but when wrapping underlying connections, usually
    57  // you want to be careful to replace the embedded Conn, not
    58  // this entire Connection value.
    59  //
    60  // Connection structs are NOT safe for concurrent use.
    61  type Connection struct {
    62  	// The underlying connection.
    63  	net.Conn
    64  
    65  	// The context for the connection.
    66  	Context context.Context
    67  
    68  	Logger *zap.Logger
    69  
    70  	buf          []byte // stores matching data
    71  	offset       int
    72  	frozenOffset int
    73  	matching     bool
    74  
    75  	bytesRead, bytesWritten uint64
    76  }
    77  
    78  var ErrConsumedAllPrefetchedBytes = errors.New("consumed all prefetched bytes")
    79  var ErrMatchingBufferFull = errors.New("matching buffer is full")
    80  
    81  // Read implements io.Reader in such a way that reads first
    82  // deplete any associated buffer from the prior recording,
    83  // and once depleted (or if there isn't one), it continues
    84  // reading from the underlying connection.
    85  func (cx *Connection) Read(p []byte) (n int, err error) {
    86  	// if we are matching and consumed the buffer exit with error
    87  	if cx.matching && (len(cx.buf) == 0 || len(cx.buf) == cx.offset) {
    88  		return 0, ErrConsumedAllPrefetchedBytes
    89  	}
    90  
    91  	// if there is a buffer we should read from, start
    92  	// with that; we only read from the underlying conn
    93  	// after the buffer has been "depleted"
    94  	if len(cx.buf) > 0 && cx.offset < len(cx.buf) {
    95  		n := copy(p, cx.buf[cx.offset:])
    96  		cx.offset += n
    97  		if !cx.matching && cx.offset == len(cx.buf) {
    98  			// if we are not in matching mode reset buf automatically after it was consumed
    99  			cx.offset = 0
   100  			cx.buf = cx.buf[:0]
   101  		}
   102  		return n, nil
   103  	}
   104  
   105  	// buffer has been "depleted" so read from
   106  	// underlying connection
   107  	n, err = cx.Conn.Read(p)
   108  	cx.bytesRead += uint64(n)
   109  
   110  	return
   111  }
   112  
   113  func (cx *Connection) Write(p []byte) (n int, err error) {
   114  	n, err = cx.Conn.Write(p)
   115  	cx.bytesWritten += uint64(n)
   116  	return
   117  }
   118  
   119  // Wrap wraps conn in a new Connection based on cx (reusing
   120  // cx's existing buffer and context). This is useful after
   121  // a connection is wrapped by a package that does not support
   122  // our Connection type (for example, `tls.Server()`).
   123  func (cx *Connection) Wrap(conn net.Conn) *Connection {
   124  	return &Connection{
   125  		Conn:         conn,
   126  		Context:      cx.Context,
   127  		Logger:       cx.Logger,
   128  		buf:          cx.buf,
   129  		offset:       cx.offset,
   130  		matching:     cx.matching,
   131  		bytesRead:    cx.bytesRead,
   132  		bytesWritten: cx.bytesWritten,
   133  	}
   134  }
   135  
   136  // prefetch tries to read all bytes that a client initially sent us without blocking.
   137  func (cx *Connection) prefetch() (err error) {
   138  	var n int
   139  
   140  	// read once
   141  	if len(cx.buf) < MaxMatchingBytes {
   142  		free := cap(cx.buf) - len(cx.buf)
   143  		if free >= prefetchChunkSize {
   144  			n, err = cx.Conn.Read(cx.buf[len(cx.buf) : len(cx.buf)+prefetchChunkSize])
   145  			cx.buf = cx.buf[:len(cx.buf)+n]
   146  		} else {
   147  			var tmp []byte
   148  			tmp = bufPool.Get().([]byte)
   149  			tmp = tmp[:prefetchChunkSize]
   150  			defer bufPool.Put(tmp)
   151  
   152  			n, err = cx.Conn.Read(tmp)
   153  			cx.buf = append(cx.buf, tmp[:n]...)
   154  		}
   155  
   156  		cx.bytesRead += uint64(n)
   157  
   158  		if err != nil {
   159  			return err
   160  		}
   161  
   162  		if cx.Logger.Core().Enabled(zap.DebugLevel) {
   163  			cx.Logger.Debug("prefetched",
   164  				zap.String("remote", cx.RemoteAddr().String()),
   165  				zap.Int("bytes", len(cx.buf)),
   166  			)
   167  		}
   168  
   169  		return nil
   170  	}
   171  
   172  	return ErrMatchingBufferFull
   173  }
   174  
   175  // freeze activates the matching mode that only reads from cx.buf.
   176  func (cx *Connection) freeze() {
   177  	cx.matching = true
   178  	cx.frozenOffset = cx.offset
   179  }
   180  
   181  // unfreeze stops the matching mode and resets the buffer offset
   182  // so that the next reads come from the buffer first.
   183  func (cx *Connection) unfreeze() {
   184  	cx.matching = false
   185  	cx.offset = cx.frozenOffset
   186  }
   187  
   188  // SetVar sets a value in the context's variable table with
   189  // the given key. It overwrites any previous value with the
   190  // same key.
   191  func (cx *Connection) SetVar(key string, value interface{}) {
   192  	varMap, ok := cx.Context.Value(VarsCtxKey).(map[string]interface{})
   193  	if !ok {
   194  		return
   195  	}
   196  	varMap[key] = value
   197  }
   198  
   199  // GetVar gets a value from the context's variable table with
   200  // the given key. It returns the value if found, and true if
   201  // it found a value with that key; false otherwise.
   202  func (cx *Connection) GetVar(key string) interface{} {
   203  	varMap, ok := cx.Context.Value(VarsCtxKey).(map[string]interface{})
   204  	if !ok {
   205  		return nil
   206  	}
   207  	return varMap[key]
   208  }
   209  
   210  // MatchingBytes returns all bytes currently available for matching. This is only intended for reading.
   211  // Do not write into the slice. It's a view of the internal buffer, and you will likely mess up the connection.
   212  // Use of this for matching purpose should be accompanied by corresponding error value,
   213  // ErrConsumedAllPrefetchedBytes and ErrMatchingBufferFull, if not matched.
   214  func (cx *Connection) MatchingBytes() []byte {
   215  	return cx.buf[cx.offset:]
   216  }
   217  
   218  var (
   219  	// VarsCtxKey is the key used to store the variables table
   220  	// in a Connection's context.
   221  	VarsCtxKey caddy.CtxKey = "vars"
   222  
   223  	// ReplacerCtxKey is the key used to store the replacer.
   224  	ReplacerCtxKey caddy.CtxKey = "replacer"
   225  
   226  	// listenerCtxKey is the key used to get the listener from a handler
   227  	listenerCtxKey caddy.CtxKey = "listener"
   228  )
   229  
   230  // the prefetch chunk size is a very large 2kb, in order to completely fetch the ~1.7kb X25519Kyber768Draft00 based TLS ClientHello. https://pq.cloudflareresearch.com/
   231  const prefetchChunkSize = 2048
   232  
   233  // MaxMatchingBytes is the amount of bytes that are at most prefetched during matching.
   234  // This is probably most relevant for the http matcher since http requests do not have a size limit.
   235  // 8 KiB should cover most use-cases and is similar to popular webservers.
   236  const MaxMatchingBytes = 8 * 1024
   237  
   238  var bufPool = sync.Pool{
   239  	New: func() interface{} {
   240  		return make([]byte, 0, prefetchChunkSize)
   241  	},
   242  }