github.com/google/martian/v3@v3.3.3/auth/auth_filter.go (about) 1 // Copyright 2015 Google Inc. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package auth provides filtering support for a martian.Proxy based on auth 16 // ID. 17 package auth 18 19 import ( 20 "fmt" 21 "net/http" 22 "sync" 23 24 "github.com/google/martian/v3" 25 ) 26 27 // Filter filters RequestModifiers and ResponseModifiers by auth ID. 28 type Filter struct { 29 authRequired bool 30 31 mu sync.RWMutex 32 reqmods map[string]martian.RequestModifier 33 resmods map[string]martian.ResponseModifier 34 } 35 36 // NewFilter returns a new auth.Filter. 37 func NewFilter() *Filter { 38 return &Filter{ 39 reqmods: make(map[string]martian.RequestModifier), 40 resmods: make(map[string]martian.ResponseModifier), 41 } 42 } 43 44 // SetAuthRequired determines whether the auth ID must have an associated 45 // RequestModifier or ResponseModifier. If true, it will set auth error. 46 func (f *Filter) SetAuthRequired(required bool) { 47 f.authRequired = required 48 } 49 50 // SetRequestModifier sets the RequestModifier for the given ID. It will 51 // overwrite any existing modifier with the same ID. 52 func (f *Filter) SetRequestModifier(id string, reqmod martian.RequestModifier) error { 53 f.mu.Lock() 54 defer f.mu.Unlock() 55 56 if reqmod != nil { 57 f.reqmods[id] = reqmod 58 } else { 59 delete(f.reqmods, id) 60 } 61 62 return nil 63 } 64 65 // SetResponseModifier sets the ResponseModifier for the given ID. It will 66 // overwrite any existing modifier with the same ID. 67 func (f *Filter) SetResponseModifier(id string, resmod martian.ResponseModifier) error { 68 f.mu.Lock() 69 defer f.mu.Unlock() 70 71 if resmod != nil { 72 f.resmods[id] = resmod 73 } else { 74 delete(f.resmods, id) 75 } 76 77 return nil 78 } 79 80 // RequestModifier retrieves the RequestModifier for the given ID. Returns nil 81 // if no modifier exists for the given ID. 82 func (f *Filter) RequestModifier(id string) martian.RequestModifier { 83 f.mu.RLock() 84 defer f.mu.RUnlock() 85 86 return f.reqmods[id] 87 } 88 89 // ResponseModifier retrieves the ResponseModifier for the given ID. Returns nil 90 // if no modifier exists for the given ID. 91 func (f *Filter) ResponseModifier(id string) martian.ResponseModifier { 92 f.mu.RLock() 93 defer f.mu.RUnlock() 94 95 return f.resmods[id] 96 } 97 98 // ModifyRequest runs the RequestModifier for the associated auth ID. If no 99 // modifier is found for auth ID then auth error is set. 100 func (f *Filter) ModifyRequest(req *http.Request) error { 101 ctx := martian.NewContext(req) 102 actx := FromContext(ctx) 103 104 if reqmod, ok := f.reqmods[actx.ID()]; ok { 105 return reqmod.ModifyRequest(req) 106 } 107 108 if err := f.requireKnownAuth(actx.ID()); err != nil { 109 actx.SetError(err) 110 } 111 112 return nil 113 } 114 115 // ModifyResponse runs the ResponseModifier for the associated auth ID. If no 116 // modifier is found for the auth ID then the auth error is set. 117 func (f *Filter) ModifyResponse(res *http.Response) error { 118 ctx := martian.NewContext(res.Request) 119 actx := FromContext(ctx) 120 121 if resmod, ok := f.resmods[actx.ID()]; ok { 122 return resmod.ModifyResponse(res) 123 } 124 125 if err := f.requireKnownAuth(actx.ID()); err != nil { 126 actx.SetError(err) 127 } 128 129 return nil 130 } 131 132 func (f *Filter) requireKnownAuth(id string) error { 133 _, reqok := f.reqmods[id] 134 _, resok := f.resmods[id] 135 136 if !reqok && !resok && f.authRequired { 137 return fmt.Errorf("auth: unrecognized credentials: %s", id) 138 } 139 140 return nil 141 }