github.com/google/martian/v3@v3.3.3/context.go (about)

     1  // Copyright 2015 Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package martian
    16  
    17  import (
    18  	"bufio"
    19  	"crypto/rand"
    20  	"encoding/hex"
    21  	"fmt"
    22  	"net"
    23  	"net/http"
    24  	"sync"
    25  )
    26  
    27  // Context provides information and storage for a single request/response pair.
    28  // Contexts are linked to shared session that is used for multiple requests on
    29  // a single connection.
    30  type Context struct {
    31  	session *Session
    32  	id      string
    33  
    34  	mu            sync.RWMutex
    35  	vals          map[string]interface{}
    36  	skipRoundTrip bool
    37  	skipLogging   bool
    38  	apiRequest    bool
    39  }
    40  
    41  // Session provides information and storage about a connection.
    42  type Session struct {
    43  	mu       sync.RWMutex
    44  	id       string
    45  	secure   bool
    46  	hijacked bool
    47  	conn     net.Conn
    48  	brw      *bufio.ReadWriter
    49  	vals     map[string]interface{}
    50  }
    51  
    52  var (
    53  	ctxmu sync.RWMutex
    54  	ctxs  = make(map[*http.Request]*Context)
    55  )
    56  
    57  // NewContext returns a context for the in-flight HTTP request.
    58  func NewContext(req *http.Request) *Context {
    59  	ctxmu.RLock()
    60  	defer ctxmu.RUnlock()
    61  
    62  	return ctxs[req]
    63  }
    64  
    65  // TestContext builds a new session and associated context and returns the
    66  // context and a function to remove the associated context. If it fails to
    67  // generate either a new session or a new context it will return an error.
    68  // Intended for tests only.
    69  func TestContext(req *http.Request, conn net.Conn, bw *bufio.ReadWriter) (ctx *Context, remove func(), err error) {
    70  	ctxmu.Lock()
    71  	defer ctxmu.Unlock()
    72  
    73  	ctx, ok := ctxs[req]
    74  	if ok {
    75  		return ctx, func() { unlink(req) }, nil
    76  	}
    77  
    78  	s, err := newSession(conn, bw)
    79  	if err != nil {
    80  		return nil, nil, err
    81  	}
    82  
    83  	ctx, err = withSession(s)
    84  	if err != nil {
    85  		return nil, nil, err
    86  	}
    87  
    88  	ctxs[req] = ctx
    89  
    90  	return ctx, func() { unlink(req) }, nil
    91  }
    92  
    93  // ID returns the session ID.
    94  func (s *Session) ID() string {
    95  	s.mu.RLock()
    96  	defer s.mu.RUnlock()
    97  
    98  	return s.id
    99  }
   100  
   101  // IsSecure returns whether the current session is from a secure connection,
   102  // such as when receiving requests from a TLS connection that has been MITM'd.
   103  func (s *Session) IsSecure() bool {
   104  	s.mu.RLock()
   105  	defer s.mu.RUnlock()
   106  
   107  	return s.secure
   108  }
   109  
   110  // MarkSecure marks the session as secure.
   111  func (s *Session) MarkSecure() {
   112  	s.mu.Lock()
   113  	defer s.mu.Unlock()
   114  
   115  	s.secure = true
   116  }
   117  
   118  // MarkInsecure marks the session as insecure.
   119  func (s *Session) MarkInsecure() {
   120  	s.mu.Lock()
   121  	defer s.mu.Unlock()
   122  
   123  	s.secure = false
   124  }
   125  
   126  // Hijack takes control of the connection from the proxy. No further action
   127  // will be taken by the proxy and the connection will be closed following the
   128  // return of the hijacker.
   129  func (s *Session) Hijack() (net.Conn, *bufio.ReadWriter, error) {
   130  	s.mu.Lock()
   131  	defer s.mu.Unlock()
   132  
   133  	if s.hijacked {
   134  		return nil, nil, fmt.Errorf("martian: session has already been hijacked")
   135  	}
   136  	s.hijacked = true
   137  
   138  	return s.conn, s.brw, nil
   139  }
   140  
   141  // Hijacked returns whether the connection has been hijacked.
   142  func (s *Session) Hijacked() bool {
   143  	s.mu.RLock()
   144  	defer s.mu.RUnlock()
   145  
   146  	return s.hijacked
   147  }
   148  
   149  // setConn resets the underlying connection and bufio.ReadWriter of the
   150  // session. Used by the proxy when the connection is upgraded to TLS.
   151  func (s *Session) setConn(conn net.Conn, brw *bufio.ReadWriter) {
   152  	s.mu.Lock()
   153  	defer s.mu.Unlock()
   154  
   155  	s.conn = conn
   156  	s.brw = brw
   157  }
   158  
   159  // Get takes key and returns the associated value from the session.
   160  func (s *Session) Get(key string) (interface{}, bool) {
   161  	s.mu.RLock()
   162  	defer s.mu.RUnlock()
   163  
   164  	val, ok := s.vals[key]
   165  
   166  	return val, ok
   167  }
   168  
   169  // Set takes a key and associates it with val in the session. The value is
   170  // persisted for the entire session across multiple requests and responses.
   171  func (s *Session) Set(key string, val interface{}) {
   172  	s.mu.Lock()
   173  	defer s.mu.Unlock()
   174  
   175  	s.vals[key] = val
   176  }
   177  
   178  // Session returns the session for the context.
   179  func (ctx *Context) Session() *Session {
   180  	return ctx.session
   181  }
   182  
   183  // ID returns the context ID.
   184  func (ctx *Context) ID() string {
   185  	return ctx.id
   186  }
   187  
   188  // Get takes key and returns the associated value from the context.
   189  func (ctx *Context) Get(key string) (interface{}, bool) {
   190  	ctx.mu.RLock()
   191  	defer ctx.mu.RUnlock()
   192  
   193  	val, ok := ctx.vals[key]
   194  
   195  	return val, ok
   196  }
   197  
   198  // Set takes a key and associates it with val in the context. The value is
   199  // persisted for the duration of the request and is removed on the following
   200  // request.
   201  func (ctx *Context) Set(key string, val interface{}) {
   202  	ctx.mu.Lock()
   203  	defer ctx.mu.Unlock()
   204  
   205  	ctx.vals[key] = val
   206  }
   207  
   208  // SkipRoundTrip skips the round trip for the current request.
   209  func (ctx *Context) SkipRoundTrip() {
   210  	ctx.mu.Lock()
   211  	defer ctx.mu.Unlock()
   212  
   213  	ctx.skipRoundTrip = true
   214  }
   215  
   216  // SkippingRoundTrip returns whether the current round trip will be skipped.
   217  func (ctx *Context) SkippingRoundTrip() bool {
   218  	ctx.mu.RLock()
   219  	defer ctx.mu.RUnlock()
   220  
   221  	return ctx.skipRoundTrip
   222  }
   223  
   224  // SkipLogging skips logging by Martian loggers for the current request.
   225  func (ctx *Context) SkipLogging() {
   226  	ctx.mu.Lock()
   227  	defer ctx.mu.Unlock()
   228  
   229  	ctx.skipLogging = true
   230  }
   231  
   232  // SkippingLogging returns whether the current request / response pair will be logged.
   233  func (ctx *Context) SkippingLogging() bool {
   234  	ctx.mu.RLock()
   235  	defer ctx.mu.RUnlock()
   236  
   237  	return ctx.skipLogging
   238  }
   239  
   240  // APIRequest marks the requests as a request to the proxy API.
   241  func (ctx *Context) APIRequest() {
   242  	ctx.mu.Lock()
   243  	defer ctx.mu.Unlock()
   244  
   245  	ctx.apiRequest = true
   246  }
   247  
   248  // IsAPIRequest returns true when the request patterns matches a pattern in the proxy
   249  // mux. The mux is usually defined as a parameter to the api.Forwarder, which uses
   250  // http.DefaultServeMux by default.
   251  func (ctx *Context) IsAPIRequest() bool {
   252  	ctx.mu.RLock()
   253  	defer ctx.mu.RUnlock()
   254  
   255  	return ctx.apiRequest
   256  }
   257  
   258  // newID creates a new 16 character random hex ID; note these are not UUIDs.
   259  func newID() (string, error) {
   260  	src := make([]byte, 8)
   261  	if _, err := rand.Read(src); err != nil {
   262  		return "", err
   263  	}
   264  
   265  	return hex.EncodeToString(src), nil
   266  }
   267  
   268  // link associates the context with request.
   269  func link(req *http.Request, ctx *Context) {
   270  	ctxmu.Lock()
   271  	defer ctxmu.Unlock()
   272  
   273  	ctxs[req] = ctx
   274  }
   275  
   276  // unlink removes the context for request.
   277  func unlink(req *http.Request) {
   278  	ctxmu.Lock()
   279  	defer ctxmu.Unlock()
   280  
   281  	delete(ctxs, req)
   282  }
   283  
   284  // newSession builds a new session.
   285  func newSession(conn net.Conn, brw *bufio.ReadWriter) (*Session, error) {
   286  	sid, err := newID()
   287  	if err != nil {
   288  		return nil, err
   289  	}
   290  
   291  	return &Session{
   292  		id:   sid,
   293  		conn: conn,
   294  		brw:  brw,
   295  		vals: make(map[string]interface{}),
   296  	}, nil
   297  }
   298  
   299  // withSession builds a new context from an existing session. Session must be
   300  // non-nil.
   301  func withSession(s *Session) (*Context, error) {
   302  	cid, err := newID()
   303  	if err != nil {
   304  		return nil, err
   305  	}
   306  
   307  	return &Context{
   308  		session: s,
   309  		id:      cid,
   310  		vals:    make(map[string]interface{}),
   311  	}, nil
   312  }