github.com/google/martian/v3@v3.3.3/priority/priority_group_test.go (about) 1 // Copyright 2015 Google Inc. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package priority 16 17 import ( 18 "errors" 19 "net/http" 20 "reflect" 21 "testing" 22 23 "github.com/google/martian/v3/martiantest" 24 "github.com/google/martian/v3/parse" 25 "github.com/google/martian/v3/proxyutil" 26 27 // Import to register header.Modifier with JSON parser. 28 _ "github.com/google/martian/v3/header" 29 ) 30 31 func TestPriorityGroupModifyRequest(t *testing.T) { 32 var order []string 33 34 pg := NewGroup() 35 36 tm50 := martiantest.NewModifier() 37 tm50.RequestFunc(func(*http.Request) { 38 order = append(order, "tm50") 39 }) 40 pg.AddRequestModifier(tm50, 50) 41 42 tm100a := martiantest.NewModifier() 43 tm100a.RequestFunc(func(*http.Request) { 44 order = append(order, "tm100a") 45 }) 46 pg.AddRequestModifier(tm100a, 100) 47 48 tm100b := martiantest.NewModifier() 49 tm100b.RequestFunc(func(*http.Request) { 50 order = append(order, "tm100b") 51 }) 52 pg.AddRequestModifier(tm100b, 100) 53 54 tm75 := martiantest.NewModifier() 55 tm75.RequestFunc(func(*http.Request) { 56 order = append(order, "tm75") 57 }) 58 59 if err := pg.RemoveRequestModifier(tm75); err != ErrModifierNotFound { 60 t.Fatalf("RemoveRequestModifier(): got %v, want ErrModifierNotFound", err) 61 } 62 63 pg.AddRequestModifier(tm75, 100) 64 65 if err := pg.RemoveRequestModifier(tm75); err != nil { 66 t.Fatalf("RemoveRequestModifier(): got %v, want no error", err) 67 } 68 69 req, err := http.NewRequest("GET", "http://example.com/", nil) 70 if err != nil { 71 t.Fatalf("http.NewRequest(): got %v, want no error", err) 72 } 73 if err := pg.ModifyRequest(req); err != nil { 74 t.Fatalf("ModifyRequest(): got %v, want no error", err) 75 } 76 if got, want := order, []string{"tm100b", "tm100a", "tm50"}; !reflect.DeepEqual(got, want) { 77 t.Fatalf("reflect.DeepEqual(%v, %v): got false, want true", got, want) 78 } 79 } 80 81 func TestPriorityGroupModifyRequestHaltsOnError(t *testing.T) { 82 pg := NewGroup() 83 84 reqerr := errors.New("request error") 85 tm := martiantest.NewModifier() 86 tm.RequestError(reqerr) 87 88 pg.AddRequestModifier(tm, 100) 89 90 tm2 := martiantest.NewModifier() 91 pg.AddRequestModifier(tm2, 75) 92 93 req, err := http.NewRequest("GET", "http://example.com/", nil) 94 if err != nil { 95 t.Fatalf("http.NewRequest(): got %v, want no error", err) 96 } 97 if err := pg.ModifyRequest(req); err != reqerr { 98 t.Fatalf("ModifyRequest(): got %v, want %v", err, reqerr) 99 } 100 101 if tm2.RequestModified() { 102 t.Error("tm2.RequestModified(): got true, want false") 103 } 104 } 105 106 func TestPriorityGroupModifyResponse(t *testing.T) { 107 var order []string 108 109 pg := NewGroup() 110 111 tm50 := martiantest.NewModifier() 112 tm50.ResponseFunc(func(*http.Response) { 113 order = append(order, "tm50") 114 }) 115 pg.AddResponseModifier(tm50, 50) 116 117 tm100a := martiantest.NewModifier() 118 tm100a.ResponseFunc(func(*http.Response) { 119 order = append(order, "tm100a") 120 }) 121 pg.AddResponseModifier(tm100a, 100) 122 123 tm100b := martiantest.NewModifier() 124 tm100b.ResponseFunc(func(*http.Response) { 125 order = append(order, "tm100b") 126 }) 127 pg.AddResponseModifier(tm100b, 100) 128 129 tm75 := martiantest.NewModifier() 130 tm75.ResponseFunc(func(*http.Response) { 131 order = append(order, "tm75") 132 }) 133 134 if err := pg.RemoveResponseModifier(tm75); err != ErrModifierNotFound { 135 t.Fatalf("RemoveResponseModifier(): got %v, want ErrModifierNotFound", err) 136 } 137 138 pg.AddResponseModifier(tm75, 100) 139 140 if err := pg.RemoveResponseModifier(tm75); err != nil { 141 t.Fatalf("RemoveResponseModifier(): got %v, want no error", err) 142 } 143 144 res := proxyutil.NewResponse(200, nil, nil) 145 if err := pg.ModifyResponse(res); err != nil { 146 t.Fatalf("ModifyResponse(): got %v, want no error", err) 147 } 148 if got, want := order, []string{"tm100b", "tm100a", "tm50"}; !reflect.DeepEqual(got, want) { 149 t.Fatalf("reflect.DeepEqual(%v, %v): got false, want true", got, want) 150 } 151 } 152 153 func TestPriorityGroupModifyResponseHaltsOnError(t *testing.T) { 154 pg := NewGroup() 155 156 reserr := errors.New("response error") 157 tm := martiantest.NewModifier() 158 tm.ResponseError(reserr) 159 160 pg.AddResponseModifier(tm, 100) 161 162 tm2 := martiantest.NewModifier() 163 pg.AddResponseModifier(tm2, 75) 164 165 res := proxyutil.NewResponse(200, nil, nil) 166 if err := pg.ModifyResponse(res); err != reserr { 167 t.Fatalf("ModifyRequest(): got %v, want %v", err, reserr) 168 } 169 170 if tm2.ResponseModified() { 171 t.Error("tm2.ResponseModified(): got true, want false") 172 } 173 } 174 175 func TestGroupFromJSON(t *testing.T) { 176 msg := []byte(`{ 177 "priority.Group": { 178 "scope": ["request", "response"], 179 "modifiers": [ 180 { 181 "priority": 100, 182 "modifier": { 183 "header.Modifier": { 184 "scope": ["request", "response"], 185 "name": "X-Testing", 186 "value": "true" 187 } 188 } 189 }, 190 { 191 "priority": 0, 192 "modifier": { 193 "header.Modifier": { 194 "scope": ["request", "response"], 195 "name": "Y-Testing", 196 "value": "true" 197 } 198 } 199 } 200 ] 201 } 202 }`) 203 204 r, err := parse.FromJSON(msg) 205 if err != nil { 206 t.Fatalf("parse.FromJSON(): got %v, want no error", err) 207 } 208 209 reqmod := r.RequestModifier() 210 if reqmod == nil { 211 t.Fatal("reqmod: got nil, want not nil") 212 } 213 214 req, err := http.NewRequest("GET", "http://example.com", nil) 215 if err != nil { 216 t.Fatalf("http.NewRequest(): got %v, want no error", err) 217 } 218 if err := reqmod.ModifyRequest(req); err != nil { 219 t.Fatalf("ModifyRequest(): got %v, want no error", err) 220 } 221 if got, want := req.Header.Get("X-Testing"), "true"; got != want { 222 t.Errorf("req.Header.Get(%q): got %q, want %q", "X-Testing", got, want) 223 } 224 if got, want := req.Header.Get("Y-Testing"), "true"; got != want { 225 t.Errorf("req.Header.Get(%q): got %q, want %q", "Y-Testing", got, want) 226 } 227 228 resmod := r.ResponseModifier() 229 if resmod == nil { 230 t.Fatal("resmod: got nil, want not nil") 231 } 232 233 res := proxyutil.NewResponse(200, nil, req) 234 if err := resmod.ModifyResponse(res); err != nil { 235 t.Fatalf("ModifyResponse(): got %v, want no error", err) 236 } 237 if got, want := res.Header.Get("X-Testing"), "true"; got != want { 238 t.Errorf("res.Header.Get(%q): got %q, want %q", "X-Testing", got, want) 239 } 240 if got, want := res.Header.Get("Y-Testing"), "true"; got != want { 241 t.Errorf("res.Header.Get(%q): got %q, want %q", "Y-Testing", got, want) 242 } 243 }