github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/e2e/vault/vault_test.go (about) 1 package vault 2 3 import ( 4 "archive/zip" 5 "bytes" 6 "encoding/json" 7 "flag" 8 "fmt" 9 "io" 10 "io/ioutil" 11 "net/http" 12 "os" 13 "path/filepath" 14 "runtime" 15 "sort" 16 "testing" 17 "time" 18 19 "github.com/hashicorp/go-version" 20 "github.com/hashicorp/nomad/command/agent" 21 "github.com/hashicorp/nomad/helper" 22 "github.com/hashicorp/nomad/nomad/structs/config" 23 "github.com/hashicorp/nomad/testutil" 24 vapi "github.com/hashicorp/vault/api" 25 "github.com/stretchr/testify/require" 26 ) 27 28 var ( 29 integration = flag.Bool("integration", false, "run integration tests") 30 minVaultVer = version.Must(version.NewVersion("0.6.2")) 31 ) 32 33 // syncVault discovers available versions of Vault, downloads the binaries, 34 // returns a map of version to binary path as well as a sorted list of 35 // versions. 36 func syncVault(t *testing.T) ([]*version.Version, map[string]string) { 37 38 binDir := filepath.Join(os.TempDir(), "vault-bins/") 39 40 urls := vaultVersions(t) 41 42 sorted, versions, err := pruneVersions(urls) 43 require.NoError(t, err) 44 45 // Get the binaries we need to download 46 missing, err := missingVault(binDir, versions) 47 require.NoError(t, err) 48 49 // Create the directory for the binaries 50 require.NoError(t, createBinDir(binDir)) 51 52 // Download in parallel 53 start := time.Now() 54 errCh := make(chan error, len(missing)) 55 for ver, url := range missing { 56 go func(dst, url string) { 57 errCh <- getVault(dst, url) 58 }(filepath.Join(binDir, ver), url) 59 } 60 for i := 0; i < len(missing); i++ { 61 select { 62 case err := <-errCh: 63 require.NoError(t, err) 64 case <-time.After(5 * time.Minute): 65 t.Fatalf("timed out downloading Vault binaries") 66 } 67 } 68 if n := len(missing); n > 0 { 69 t.Logf("Downloaded %d versions of Vault in %s", n, time.Now().Sub(start)) 70 } 71 72 binaries := make(map[string]string, len(versions)) 73 for ver, _ := range versions { 74 binaries[ver] = filepath.Join(binDir, ver) 75 } 76 return sorted, binaries 77 } 78 79 // vaultVersions discovers available Vault versions from releases.hashicorp.com 80 // and returns a map of version to url. 81 func vaultVersions(t *testing.T) map[string]string { 82 resp, err := http.Get("https://releases.hashicorp.com/vault/index.json") 83 require.NoError(t, err) 84 85 respJson := struct { 86 Versions map[string]struct { 87 Builds []struct { 88 Version string `json:"version"` 89 Os string `json:"os"` 90 Arch string `json:"arch"` 91 URL string `json:"url"` 92 } `json:"builds"` 93 } 94 }{} 95 require.NoError(t, json.NewDecoder(resp.Body).Decode(&respJson)) 96 require.NoError(t, resp.Body.Close()) 97 98 versions := map[string]string{} 99 for vk, vv := range respJson.Versions { 100 gover, err := version.NewVersion(vk) 101 if err != nil { 102 t.Logf("error parsing Vault version %q -> %v", vk, err) 103 continue 104 } 105 106 // Skip ancient versions 107 if gover.LessThan(minVaultVer) { 108 continue 109 } 110 111 // Skip prerelease and enterprise versions 112 if gover.Prerelease() != "" || gover.Metadata() != "" { 113 continue 114 } 115 116 url := "" 117 for _, b := range vv.Builds { 118 buildver, err := version.NewVersion(b.Version) 119 if err != nil { 120 t.Logf("error parsing Vault build version %q -> %v", b.Version, err) 121 continue 122 } 123 124 if buildver.Prerelease() != "" { 125 continue 126 } 127 128 if buildver.Metadata() != "" { 129 continue 130 } 131 132 if b.Os != runtime.GOOS { 133 continue 134 } 135 136 if b.Arch != runtime.GOARCH { 137 continue 138 } 139 140 // Match! 141 url = b.URL 142 break 143 } 144 145 if url != "" { 146 versions[vk] = url 147 } 148 } 149 150 return versions 151 } 152 153 // pruneVersions only takes the latest Z for each X.Y.Z release. Returns a 154 // sorted list and map of kept versions. 155 func pruneVersions(all map[string]string) ([]*version.Version, map[string]string, error) { 156 if len(all) == 0 { 157 return nil, nil, fmt.Errorf("0 Vault versions") 158 } 159 160 sorted := make([]*version.Version, 0, len(all)) 161 162 for k := range all { 163 sorted = append(sorted, version.Must(version.NewVersion(k))) 164 } 165 166 sort.Sort(version.Collection(sorted)) 167 168 keep := make([]*version.Version, 0, len(all)) 169 170 for _, v := range sorted { 171 segments := v.Segments() 172 if len(segments) < 3 { 173 // Drop malformed versions 174 continue 175 } 176 177 if len(keep) == 0 { 178 keep = append(keep, v) 179 continue 180 } 181 182 last := keep[len(keep)-1].Segments() 183 184 if segments[0] == last[0] && segments[1] == last[1] { 185 // current X.Y == last X.Y, replace last with current 186 keep[len(keep)-1] = v 187 } else { 188 // current X.Y != last X.Y, append 189 keep = append(keep, v) 190 } 191 } 192 193 // Create a new map of canonicalized versions to urls 194 urls := make(map[string]string, len(keep)) 195 for _, v := range keep { 196 origURL := all[v.Original()] 197 if origURL == "" { 198 return nil, nil, fmt.Errorf("missing version %s", v.Original()) 199 } 200 urls[v.String()] = origURL 201 } 202 203 return keep, urls, nil 204 } 205 206 // createBinDir creates the binary directory 207 func createBinDir(binDir string) error { 208 // Check if the directory exists, otherwise create it 209 f, err := os.Stat(binDir) 210 if err != nil && !os.IsNotExist(err) { 211 return fmt.Errorf("failed to stat directory: %v", err) 212 } 213 214 if f != nil && f.IsDir() { 215 return nil 216 } else if f != nil { 217 if err := os.RemoveAll(binDir); err != nil { 218 return fmt.Errorf("failed to remove file at directory path: %v", err) 219 } 220 } 221 222 // Create the directory 223 if err := os.Mkdir(binDir, 075); err != nil { 224 return fmt.Errorf("failed to make directory: %v", err) 225 } 226 if err := os.Chmod(binDir, 0755); err != nil { 227 return fmt.Errorf("failed to chmod: %v", err) 228 } 229 230 return nil 231 } 232 233 // missingVault returns the binaries that must be downloaded. versions key must 234 // be the Vault version. 235 func missingVault(binDir string, versions map[string]string) (map[string]string, error) { 236 files, err := ioutil.ReadDir(binDir) 237 if err != nil { 238 if os.IsNotExist(err) { 239 return versions, nil 240 } 241 242 return nil, fmt.Errorf("failed to stat directory: %v", err) 243 } 244 245 // Copy versions so we don't mutate it 246 missingSet := make(map[string]string, len(versions)) 247 for k, v := range versions { 248 missingSet[k] = v 249 } 250 251 for _, f := range files { 252 delete(missingSet, f.Name()) 253 } 254 255 return missingSet, nil 256 } 257 258 // getVault downloads the given Vault binary 259 func getVault(dst, url string) error { 260 resp, err := http.Get(url) 261 if err != nil { 262 return err 263 } 264 defer resp.Body.Close() 265 266 // Wrap in an in-mem buffer 267 b := bytes.NewBuffer(nil) 268 if _, err := io.Copy(b, resp.Body); err != nil { 269 return fmt.Errorf("error reading response body: %v", err) 270 } 271 resp.Body.Close() 272 273 zreader, err := zip.NewReader(bytes.NewReader(b.Bytes()), resp.ContentLength) 274 if err != nil { 275 return err 276 } 277 278 if l := len(zreader.File); l != 1 { 279 return fmt.Errorf("unexpected number of files in zip: %v", l) 280 } 281 282 // Copy the file to its destination 283 out, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0777) 284 if err != nil { 285 return err 286 } 287 defer out.Close() 288 289 zfile, err := zreader.File[0].Open() 290 if err != nil { 291 return fmt.Errorf("failed to open zip file: %v", err) 292 } 293 294 if _, err := io.Copy(out, zfile); err != nil { 295 return fmt.Errorf("failed to decompress file to destination: %v", err) 296 } 297 298 return nil 299 } 300 301 // TestVaultCompatibility tests compatibility across Vault versions 302 func TestVaultCompatibility(t *testing.T) { 303 if !*integration { 304 t.Skip("skipping test in non-integration mode: add -integration flag to run") 305 } 306 307 sorted, vaultBinaries := syncVault(t) 308 309 for _, v := range sorted { 310 ver := v.String() 311 bin := vaultBinaries[ver] 312 require.NotZerof(t, bin, "missing version: %s", ver) 313 t.Run(ver, func(t *testing.T) { 314 testVaultCompatibility(t, bin, ver) 315 }) 316 } 317 } 318 319 // testVaultCompatibility tests compatibility with the given vault binary 320 func testVaultCompatibility(t *testing.T, vault string, version string) { 321 require := require.New(t) 322 323 // Create a Vault server 324 v := testutil.NewTestVaultFromPath(t, vault) 325 defer v.Stop() 326 327 token := setupVault(t, v.Client, version) 328 329 // Create a Nomad agent using the created vault 330 nomad := agent.NewTestAgent(t, t.Name(), func(c *agent.Config) { 331 if c.Vault == nil { 332 c.Vault = &config.VaultConfig{} 333 } 334 c.Vault.Enabled = helper.BoolToPtr(true) 335 c.Vault.Token = token 336 c.Vault.Role = "nomad-cluster" 337 c.Vault.AllowUnauthenticated = helper.BoolToPtr(true) 338 c.Vault.Addr = v.HTTPAddr 339 }) 340 defer nomad.Shutdown() 341 342 // Submit the Nomad job that requests a Vault token and cats that the Vault 343 // token is there 344 c := nomad.Client() 345 j := c.Jobs() 346 _, _, err := j.Register(job, nil) 347 require.NoError(err) 348 349 // Wait for there to be an allocation terminated successfully 350 //var allocID string 351 testutil.WaitForResult(func() (bool, error) { 352 // Get the allocations for the job 353 allocs, _, err := j.Allocations(*job.ID, false, nil) 354 if err != nil { 355 return false, err 356 } 357 l := len(allocs) 358 switch l { 359 case 0: 360 return false, fmt.Errorf("want one alloc; got zero") 361 case 1: 362 default: 363 // exit early 364 t.Fatalf("too many allocations; something failed") 365 } 366 alloc := allocs[0] 367 //allocID = alloc.ID 368 if alloc.ClientStatus == "complete" { 369 return true, nil 370 } 371 372 return false, fmt.Errorf("client status %q", alloc.ClientStatus) 373 }, func(err error) { 374 t.Fatalf("allocation did not finish: %v", err) 375 }) 376 377 } 378 379 // setupVault takes the Vault client and creates the required policies and 380 // roles. It returns the token that should be used by Nomad 381 func setupVault(t *testing.T, client *vapi.Client, vaultVersion string) string { 382 // Write the policy 383 sys := client.Sys() 384 385 // pre-0.9.0 vault servers do not work with our new vault client for the policy endpoint 386 // perform this using a raw HTTP request 387 newApi := version.Must(version.NewVersion("0.9.0")) 388 testVersion := version.Must(version.NewVersion(vaultVersion)) 389 if testVersion.LessThan(newApi) { 390 body := map[string]string{ 391 "rules": policy, 392 } 393 request := client.NewRequest("PUT", "/v1/sys/policy/nomad-server") 394 if err := request.SetJSONBody(body); err != nil { 395 t.Fatalf("failed to set JSON body on legacy policy creation: %v", err) 396 } 397 if _, err := client.RawRequest(request); err != nil { 398 t.Fatalf("failed to create legacy policy: %v", err) 399 } 400 } else { 401 if err := sys.PutPolicy("nomad-server", policy); err != nil { 402 t.Fatalf("failed to create policy: %v", err) 403 } 404 } 405 406 // Build the role 407 l := client.Logical() 408 l.Write("auth/token/roles/nomad-cluster", role) 409 410 // Create a new token with the role 411 a := client.Auth().Token() 412 req := vapi.TokenCreateRequest{ 413 Policies: []string{"nomad-server"}, 414 Period: "72h", 415 NoParent: true, 416 } 417 s, err := a.Create(&req) 418 if err != nil { 419 t.Fatalf("failed to create child token: %v", err) 420 } 421 422 // Get the client token 423 if s == nil || s.Auth == nil { 424 t.Fatalf("bad secret response: %+v", s) 425 } 426 427 return s.Auth.ClientToken 428 }