github.com/google/martian/v3@v3.3.3/context.go (about) 1 // Copyright 2015 Google Inc. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package martian 16 17 import ( 18 "bufio" 19 "crypto/rand" 20 "encoding/hex" 21 "fmt" 22 "net" 23 "net/http" 24 "sync" 25 ) 26 27 // Context provides information and storage for a single request/response pair. 28 // Contexts are linked to shared session that is used for multiple requests on 29 // a single connection. 30 type Context struct { 31 session *Session 32 id string 33 34 mu sync.RWMutex 35 vals map[string]interface{} 36 skipRoundTrip bool 37 skipLogging bool 38 apiRequest bool 39 } 40 41 // Session provides information and storage about a connection. 42 type Session struct { 43 mu sync.RWMutex 44 id string 45 secure bool 46 hijacked bool 47 conn net.Conn 48 brw *bufio.ReadWriter 49 vals map[string]interface{} 50 } 51 52 var ( 53 ctxmu sync.RWMutex 54 ctxs = make(map[*http.Request]*Context) 55 ) 56 57 // NewContext returns a context for the in-flight HTTP request. 58 func NewContext(req *http.Request) *Context { 59 ctxmu.RLock() 60 defer ctxmu.RUnlock() 61 62 return ctxs[req] 63 } 64 65 // TestContext builds a new session and associated context and returns the 66 // context and a function to remove the associated context. If it fails to 67 // generate either a new session or a new context it will return an error. 68 // Intended for tests only. 69 func TestContext(req *http.Request, conn net.Conn, bw *bufio.ReadWriter) (ctx *Context, remove func(), err error) { 70 ctxmu.Lock() 71 defer ctxmu.Unlock() 72 73 ctx, ok := ctxs[req] 74 if ok { 75 return ctx, func() { unlink(req) }, nil 76 } 77 78 s, err := newSession(conn, bw) 79 if err != nil { 80 return nil, nil, err 81 } 82 83 ctx, err = withSession(s) 84 if err != nil { 85 return nil, nil, err 86 } 87 88 ctxs[req] = ctx 89 90 return ctx, func() { unlink(req) }, nil 91 } 92 93 // ID returns the session ID. 94 func (s *Session) ID() string { 95 s.mu.RLock() 96 defer s.mu.RUnlock() 97 98 return s.id 99 } 100 101 // IsSecure returns whether the current session is from a secure connection, 102 // such as when receiving requests from a TLS connection that has been MITM'd. 103 func (s *Session) IsSecure() bool { 104 s.mu.RLock() 105 defer s.mu.RUnlock() 106 107 return s.secure 108 } 109 110 // MarkSecure marks the session as secure. 111 func (s *Session) MarkSecure() { 112 s.mu.Lock() 113 defer s.mu.Unlock() 114 115 s.secure = true 116 } 117 118 // MarkInsecure marks the session as insecure. 119 func (s *Session) MarkInsecure() { 120 s.mu.Lock() 121 defer s.mu.Unlock() 122 123 s.secure = false 124 } 125 126 // Hijack takes control of the connection from the proxy. No further action 127 // will be taken by the proxy and the connection will be closed following the 128 // return of the hijacker. 129 func (s *Session) Hijack() (net.Conn, *bufio.ReadWriter, error) { 130 s.mu.Lock() 131 defer s.mu.Unlock() 132 133 if s.hijacked { 134 return nil, nil, fmt.Errorf("martian: session has already been hijacked") 135 } 136 s.hijacked = true 137 138 return s.conn, s.brw, nil 139 } 140 141 // Hijacked returns whether the connection has been hijacked. 142 func (s *Session) Hijacked() bool { 143 s.mu.RLock() 144 defer s.mu.RUnlock() 145 146 return s.hijacked 147 } 148 149 // setConn resets the underlying connection and bufio.ReadWriter of the 150 // session. Used by the proxy when the connection is upgraded to TLS. 151 func (s *Session) setConn(conn net.Conn, brw *bufio.ReadWriter) { 152 s.mu.Lock() 153 defer s.mu.Unlock() 154 155 s.conn = conn 156 s.brw = brw 157 } 158 159 // Get takes key and returns the associated value from the session. 160 func (s *Session) Get(key string) (interface{}, bool) { 161 s.mu.RLock() 162 defer s.mu.RUnlock() 163 164 val, ok := s.vals[key] 165 166 return val, ok 167 } 168 169 // Set takes a key and associates it with val in the session. The value is 170 // persisted for the entire session across multiple requests and responses. 171 func (s *Session) Set(key string, val interface{}) { 172 s.mu.Lock() 173 defer s.mu.Unlock() 174 175 s.vals[key] = val 176 } 177 178 // Session returns the session for the context. 179 func (ctx *Context) Session() *Session { 180 return ctx.session 181 } 182 183 // ID returns the context ID. 184 func (ctx *Context) ID() string { 185 return ctx.id 186 } 187 188 // Get takes key and returns the associated value from the context. 189 func (ctx *Context) Get(key string) (interface{}, bool) { 190 ctx.mu.RLock() 191 defer ctx.mu.RUnlock() 192 193 val, ok := ctx.vals[key] 194 195 return val, ok 196 } 197 198 // Set takes a key and associates it with val in the context. The value is 199 // persisted for the duration of the request and is removed on the following 200 // request. 201 func (ctx *Context) Set(key string, val interface{}) { 202 ctx.mu.Lock() 203 defer ctx.mu.Unlock() 204 205 ctx.vals[key] = val 206 } 207 208 // SkipRoundTrip skips the round trip for the current request. 209 func (ctx *Context) SkipRoundTrip() { 210 ctx.mu.Lock() 211 defer ctx.mu.Unlock() 212 213 ctx.skipRoundTrip = true 214 } 215 216 // SkippingRoundTrip returns whether the current round trip will be skipped. 217 func (ctx *Context) SkippingRoundTrip() bool { 218 ctx.mu.RLock() 219 defer ctx.mu.RUnlock() 220 221 return ctx.skipRoundTrip 222 } 223 224 // SkipLogging skips logging by Martian loggers for the current request. 225 func (ctx *Context) SkipLogging() { 226 ctx.mu.Lock() 227 defer ctx.mu.Unlock() 228 229 ctx.skipLogging = true 230 } 231 232 // SkippingLogging returns whether the current request / response pair will be logged. 233 func (ctx *Context) SkippingLogging() bool { 234 ctx.mu.RLock() 235 defer ctx.mu.RUnlock() 236 237 return ctx.skipLogging 238 } 239 240 // APIRequest marks the requests as a request to the proxy API. 241 func (ctx *Context) APIRequest() { 242 ctx.mu.Lock() 243 defer ctx.mu.Unlock() 244 245 ctx.apiRequest = true 246 } 247 248 // IsAPIRequest returns true when the request patterns matches a pattern in the proxy 249 // mux. The mux is usually defined as a parameter to the api.Forwarder, which uses 250 // http.DefaultServeMux by default. 251 func (ctx *Context) IsAPIRequest() bool { 252 ctx.mu.RLock() 253 defer ctx.mu.RUnlock() 254 255 return ctx.apiRequest 256 } 257 258 // newID creates a new 16 character random hex ID; note these are not UUIDs. 259 func newID() (string, error) { 260 src := make([]byte, 8) 261 if _, err := rand.Read(src); err != nil { 262 return "", err 263 } 264 265 return hex.EncodeToString(src), nil 266 } 267 268 // link associates the context with request. 269 func link(req *http.Request, ctx *Context) { 270 ctxmu.Lock() 271 defer ctxmu.Unlock() 272 273 ctxs[req] = ctx 274 } 275 276 // unlink removes the context for request. 277 func unlink(req *http.Request) { 278 ctxmu.Lock() 279 defer ctxmu.Unlock() 280 281 delete(ctxs, req) 282 } 283 284 // newSession builds a new session. 285 func newSession(conn net.Conn, brw *bufio.ReadWriter) (*Session, error) { 286 sid, err := newID() 287 if err != nil { 288 return nil, err 289 } 290 291 return &Session{ 292 id: sid, 293 conn: conn, 294 brw: brw, 295 vals: make(map[string]interface{}), 296 }, nil 297 } 298 299 // withSession builds a new context from an existing session. Session must be 300 // non-nil. 301 func withSession(s *Session) (*Context, error) { 302 cid, err := newID() 303 if err != nil { 304 return nil, err 305 } 306 307 return &Context{ 308 session: s, 309 id: cid, 310 vals: make(map[string]interface{}), 311 }, nil 312 }