github.com/google/martian/v3@v3.3.3/header/via_modifier.go (about) 1 // Copyright 2015 Google Inc. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package header 16 17 import ( 18 "crypto/rand" 19 "fmt" 20 "io" 21 "net/http" 22 "regexp" 23 "strings" 24 25 "github.com/google/martian/v3" 26 ) 27 28 const viaLoopKey = "via.LoopDetection" 29 30 var whitespace = regexp.MustCompile("[\t ]+") 31 32 // ViaModifier is a header modifier that checks for proxy redirect loops. 33 type ViaModifier struct { 34 requestedBy string 35 boundary string 36 } 37 38 // NewViaModifier returns a new Via modifier. 39 func NewViaModifier(requestedBy string) *ViaModifier { 40 return &ViaModifier{ 41 requestedBy: requestedBy, 42 boundary: randomBoundary(), 43 } 44 } 45 46 // ModifyRequest sets the Via header and provides loop-detection. If Via is 47 // already present, it will be appended to the existing value. If a loop is 48 // detected an error is added to the context and the request round trip is 49 // skipped. 
50 // 51 // http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.9 52 func (m *ViaModifier) ModifyRequest(req *http.Request) error { 53 via := fmt.Sprintf("%d.%d %s-%s", req.ProtoMajor, req.ProtoMinor, m.requestedBy, m.boundary) 54 55 if v := req.Header.Get("Via"); v != "" { 56 if m.hasLoop(v) { 57 err := fmt.Errorf("via: detected request loop, header contains %s", via) 58 59 ctx := martian.NewContext(req) 60 ctx.Set(viaLoopKey, err) 61 ctx.SkipRoundTrip() 62 63 return err 64 } 65 66 via = fmt.Sprintf("%s, %s", v, via) 67 } 68 69 req.Header.Set("Via", via) 70 71 return nil 72 } 73 74 // ModifyResponse sets the status code to 400 Bad Request if a loop was 75 // detected in the request. 76 func (m *ViaModifier) ModifyResponse(res *http.Response) error { 77 ctx := martian.NewContext(res.Request) 78 79 if err, _ := ctx.Get(viaLoopKey); err != nil { 80 res.StatusCode = 400 81 res.Status = http.StatusText(400) 82 83 return err.(error) 84 } 85 86 return nil 87 } 88 89 // hasLoop parses via and attempts to match requestedBy against the contained 90 // pseudonyms/host:port pairs. 91 func (m *ViaModifier) hasLoop(via string) bool { 92 for _, v := range strings.Split(via, ",") { 93 parts := whitespace.Split(strings.TrimSpace(v), 3) 94 95 // No pseudonym or host:port, assume there is no loop. 96 if len(parts) < 2 { 97 continue 98 } 99 100 if fmt.Sprintf("%s-%s", m.requestedBy, m.boundary) == parts[1] { 101 return true 102 } 103 } 104 105 return false 106 } 107 108 // SetBoundary sets the boundary string (random 10 character by default) used to 109 // disabiguate Martians that are chained together with identical requestedBy values. 110 // This should only be used for testing. 111 func (m *ViaModifier) SetBoundary(boundary string) { 112 m.boundary = boundary 113 } 114 115 // randomBoundary generates a 10 character string to ensure that Martians that 116 // are chained together with the same requestedBy value do not collide. 
// randomBoundary returns a 20-character lowercase hexadecimal string encoding
// 10 bytes read from crypto/rand. It ensures that Martians chained together
// with the same requestedBy value do not collide. It panics if
// io.ReadFull on crypto/rand.Reader fails.
func randomBoundary() string {
	var buf [10]byte
	if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil {
		panic(err)
	}
	return fmt.Sprintf("%x", buf[:])
}