github.com/google/martian/v3@v3.3.3/header/via_modifier.go (about)

     1  // Copyright 2015 Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package header
    16  
    17  import (
    18  	"crypto/rand"
    19  	"fmt"
    20  	"io"
    21  	"net/http"
    22  	"regexp"
    23  	"strings"
    24  
    25  	"github.com/google/martian/v3"
    26  )
    27  
    28  const viaLoopKey = "via.LoopDetection"
    29  
    30  var whitespace = regexp.MustCompile("[\t ]+")
    31  
    32  // ViaModifier is a header modifier that checks for proxy redirect loops.
    33  type ViaModifier struct {
    34  	requestedBy string
    35  	boundary    string
    36  }
    37  
    38  // NewViaModifier returns a new Via modifier.
    39  func NewViaModifier(requestedBy string) *ViaModifier {
    40  	return &ViaModifier{
    41  		requestedBy: requestedBy,
    42  		boundary:    randomBoundary(),
    43  	}
    44  }
    45  
    46  // ModifyRequest sets the Via header and provides loop-detection. If Via is
    47  // already present, it will be appended to the existing value. If a loop is
    48  // detected an error is added to the context and the request round trip is
    49  // skipped.
    50  //
    51  // http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.9
    52  func (m *ViaModifier) ModifyRequest(req *http.Request) error {
    53  	via := fmt.Sprintf("%d.%d %s-%s", req.ProtoMajor, req.ProtoMinor, m.requestedBy, m.boundary)
    54  
    55  	if v := req.Header.Get("Via"); v != "" {
    56  		if m.hasLoop(v) {
    57  			err := fmt.Errorf("via: detected request loop, header contains %s", via)
    58  
    59  			ctx := martian.NewContext(req)
    60  			ctx.Set(viaLoopKey, err)
    61  			ctx.SkipRoundTrip()
    62  
    63  			return err
    64  		}
    65  
    66  		via = fmt.Sprintf("%s, %s", v, via)
    67  	}
    68  
    69  	req.Header.Set("Via", via)
    70  
    71  	return nil
    72  }
    73  
    74  // ModifyResponse sets the status code to 400 Bad Request if a loop was
    75  // detected in the request.
    76  func (m *ViaModifier) ModifyResponse(res *http.Response) error {
    77  	ctx := martian.NewContext(res.Request)
    78  
    79  	if err, _ := ctx.Get(viaLoopKey); err != nil {
    80  		res.StatusCode = 400
    81  		res.Status = http.StatusText(400)
    82  
    83  		return err.(error)
    84  	}
    85  
    86  	return nil
    87  }
    88  
    89  // hasLoop parses via and attempts to match requestedBy against the contained
    90  // pseudonyms/host:port pairs.
    91  func (m *ViaModifier) hasLoop(via string) bool {
    92  	for _, v := range strings.Split(via, ",") {
    93  		parts := whitespace.Split(strings.TrimSpace(v), 3)
    94  
    95  		// No pseudonym or host:port, assume there is no loop.
    96  		if len(parts) < 2 {
    97  			continue
    98  		}
    99  
   100  		if fmt.Sprintf("%s-%s", m.requestedBy, m.boundary) == parts[1] {
   101  			return true
   102  		}
   103  	}
   104  
   105  	return false
   106  }
   107  
   108  // SetBoundary sets the boundary string (random 10 character by default) used to
   109  // disabiguate Martians that are chained together with identical requestedBy values.
   110  // This should only be used for testing.
   111  func (m *ViaModifier) SetBoundary(boundary string) {
   112  	m.boundary = boundary
   113  }
   114  
   115  // randomBoundary generates a 10 character string to ensure that Martians that
   116  // are chained together with the same requestedBy value do not collide.  This func
   117  // panics if io.Readfull fails.
   118  func randomBoundary() string {
   119  	var buf [10]byte
   120  	_, err := io.ReadFull(rand.Reader, buf[:])
   121  	if err != nil {
   122  		panic(err)
   123  	}
   124  	return fmt.Sprintf("%x", buf[:])
   125  }