github.com/hernad/nomad@v1.6.112/e2e/vaultcompat/vault_test.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package vaultcompat 5 6 import ( 7 "archive/zip" 8 "bytes" 9 "encoding/json" 10 "flag" 11 "fmt" 12 "io" 13 "net/http" 14 "os" 15 "path/filepath" 16 "runtime" 17 "sort" 18 "testing" 19 "time" 20 21 "github.com/hashicorp/go-version" 22 "github.com/hernad/nomad/command/agent" 23 "github.com/hernad/nomad/helper/pointer" 24 "github.com/hernad/nomad/nomad/structs/config" 25 "github.com/hernad/nomad/testutil" 26 vapi "github.com/hashicorp/vault/api" 27 "github.com/stretchr/testify/require" 28 ) 29 30 var ( 31 integration = flag.Bool("integration", false, "run integration tests") 32 minVaultVer = version.Must(version.NewVersion("0.6.2")) 33 ) 34 35 // syncVault discovers available versions of Vault, downloads the binaries, 36 // returns a map of version to binary path as well as a sorted list of 37 // versions. 38 func syncVault(t *testing.T) ([]*version.Version, map[string]string) { 39 40 binDir := filepath.Join(os.TempDir(), "vault-bins/") 41 42 urls := vaultVersions(t) 43 44 sorted, versions, err := pruneVersions(urls) 45 require.NoError(t, err) 46 47 // Get the binaries we need to download 48 missing, err := missingVault(binDir, versions) 49 require.NoError(t, err) 50 51 // Create the directory for the binaries 52 require.NoError(t, createBinDir(binDir)) 53 54 // Download in parallel 55 start := time.Now() 56 errCh := make(chan error, len(missing)) 57 for ver, url := range missing { 58 go func(dst, url string) { 59 errCh <- getVault(dst, url) 60 }(filepath.Join(binDir, ver), url) 61 } 62 for i := 0; i < len(missing); i++ { 63 select { 64 case err := <-errCh: 65 require.NoError(t, err) 66 case <-time.After(5 * time.Minute): 67 require.Fail(t, "timed out downloading Vault binaries") 68 } 69 } 70 if n := len(missing); n > 0 { 71 t.Logf("Downloaded %d versions of Vault in %s", n, time.Now().Sub(start)) 72 } 73 74 binaries := make(map[string]string, len(versions)) 75 for ver := range versions { 76 binaries[ver] = filepath.Join(binDir, ver) 77 } 78 return sorted, binaries 79 } 80 81 // vaultVersions discovers available Vault versions from releases.hashicorp.com 82 // and returns a map of version to url. 83 func vaultVersions(t *testing.T) map[string]string { 84 resp, err := http.Get("https://releases.hashicorp.com/vault/index.json") 85 require.NoError(t, err) 86 87 respJson := struct { 88 Versions map[string]struct { 89 Builds []struct { 90 Version string `json:"version"` 91 Os string `json:"os"` 92 Arch string `json:"arch"` 93 URL string `json:"url"` 94 } `json:"builds"` 95 } 96 }{} 97 require.NoError(t, json.NewDecoder(resp.Body).Decode(&respJson)) 98 require.NoError(t, resp.Body.Close()) 99 100 versions := map[string]string{} 101 for vk, vv := range respJson.Versions { 102 gover, err := version.NewVersion(vk) 103 if err != nil { 104 t.Logf("error parsing Vault version %q -> %v", vk, err) 105 continue 106 } 107 108 // Skip ancient versions 109 if gover.LessThan(minVaultVer) { 110 continue 111 } 112 113 // Skip prerelease and enterprise versions 114 if gover.Prerelease() != "" || gover.Metadata() != "" { 115 continue 116 } 117 118 url := "" 119 for _, b := range vv.Builds { 120 buildver, err := version.NewVersion(b.Version) 121 if err != nil { 122 t.Logf("error parsing Vault build version %q -> %v", b.Version, err) 123 continue 124 } 125 126 if buildver.Prerelease() != "" { 127 continue 128 } 129 130 if buildver.Metadata() != "" { 131 continue 132 } 133 134 if b.Os != runtime.GOOS { 135 continue 136 } 137 138 if b.Arch != runtime.GOARCH { 139 continue 140 } 141 142 // Match! 143 url = b.URL 144 break 145 } 146 147 if url != "" { 148 versions[vk] = url 149 } 150 } 151 152 return versions 153 } 154 155 // pruneVersions only takes the latest Z for each X.Y.Z release. Returns a 156 // sorted list and map of kept versions. 157 func pruneVersions(all map[string]string) ([]*version.Version, map[string]string, error) { 158 if len(all) == 0 { 159 return nil, nil, fmt.Errorf("0 Vault versions") 160 } 161 162 sorted := make([]*version.Version, 0, len(all)) 163 164 for k := range all { 165 sorted = append(sorted, version.Must(version.NewVersion(k))) 166 } 167 168 sort.Sort(version.Collection(sorted)) 169 170 keep := make([]*version.Version, 0, len(all)) 171 172 for _, v := range sorted { 173 segments := v.Segments() 174 if len(segments) < 3 { 175 // Drop malformed versions 176 continue 177 } 178 179 if len(keep) == 0 { 180 keep = append(keep, v) 181 continue 182 } 183 184 last := keep[len(keep)-1].Segments() 185 186 if segments[0] == last[0] && segments[1] == last[1] { 187 // current X.Y == last X.Y, replace last with current 188 keep[len(keep)-1] = v 189 } else { 190 // current X.Y != last X.Y, append 191 keep = append(keep, v) 192 } 193 } 194 195 // Create a new map of canonicalized versions to urls 196 urls := make(map[string]string, len(keep)) 197 for _, v := range keep { 198 origURL := all[v.Original()] 199 if origURL == "" { 200 return nil, nil, fmt.Errorf("missing version %s", v.Original()) 201 } 202 urls[v.String()] = origURL 203 } 204 205 return keep, urls, nil 206 } 207 208 // createBinDir creates the binary directory 209 func createBinDir(binDir string) error { 210 // Check if the directory exists, otherwise create it 211 f, err := os.Stat(binDir) 212 if err != nil && !os.IsNotExist(err) { 213 return fmt.Errorf("failed to stat directory: %v", err) 214 } 215 216 if f != nil && f.IsDir() { 217 return nil 218 } else if f != nil { 219 if err := os.RemoveAll(binDir); err != nil { 220 return fmt.Errorf("failed to remove file at directory path: %v", err) 221 } 222 } 223 224 // Create the directory 225 if err := os.Mkdir(binDir, 075); err != nil { 226 return fmt.Errorf("failed to make directory: %v", err) 227 } 228 if err := os.Chmod(binDir, 0755); err != nil { 229 return fmt.Errorf("failed to chmod: %v", err) 230 } 231 232 return nil 233 } 234 235 // missingVault returns the binaries that must be downloaded. versions key must 236 // be the Vault version. 237 func missingVault(binDir string, versions map[string]string) (map[string]string, error) { 238 files, err := os.ReadDir(binDir) 239 if err != nil { 240 if os.IsNotExist(err) { 241 return versions, nil 242 } 243 244 return nil, fmt.Errorf("failed to stat directory: %v", err) 245 } 246 247 // Copy versions so we don't mutate it 248 missingSet := make(map[string]string, len(versions)) 249 for k, v := range versions { 250 missingSet[k] = v 251 } 252 253 for _, f := range files { 254 delete(missingSet, f.Name()) 255 } 256 257 return missingSet, nil 258 } 259 260 // getVault downloads the given Vault binary 261 func getVault(dst, url string) error { 262 resp, err := http.Get(url) 263 if err != nil { 264 return err 265 } 266 defer resp.Body.Close() 267 268 // Wrap in an in-mem buffer 269 b := bytes.NewBuffer(nil) 270 if _, err := io.Copy(b, resp.Body); err != nil { 271 return fmt.Errorf("error reading response body: %v", err) 272 } 273 resp.Body.Close() 274 275 zreader, err := zip.NewReader(bytes.NewReader(b.Bytes()), resp.ContentLength) 276 if err != nil { 277 return err 278 } 279 280 if l := len(zreader.File); l != 1 { 281 return fmt.Errorf("unexpected number of files in zip: %v", l) 282 } 283 284 // Copy the file to its destination 285 out, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0777) 286 if err != nil { 287 return err 288 } 289 defer out.Close() 290 291 zfile, err := zreader.File[0].Open() 292 if err != nil { 293 return fmt.Errorf("failed to open zip file: %v", err) 294 } 295 296 if _, err := io.Copy(out, zfile); err != nil { 297 return fmt.Errorf("failed to decompress file to destination: %v", err) 298 } 299 300 return nil 301 } 302 303 // TestVaultCompatibility tests compatibility across Vault versions 304 func TestVaultCompatibility(t *testing.T) { 305 if !*integration { 306 t.Skip("skipping test in non-integration mode: add -integration flag to run") 307 } 308 309 sorted, vaultBinaries := syncVault(t) 310 311 for _, v := range sorted { 312 ver := v.String() 313 bin := vaultBinaries[ver] 314 require.NotZerof(t, bin, "missing version: %s", ver) 315 t.Run(ver, func(t *testing.T) { 316 testVaultCompatibility(t, bin, ver) 317 }) 318 } 319 } 320 321 // testVaultCompatibility tests compatibility with the given vault binary 322 func testVaultCompatibility(t *testing.T, vault string, version string) { 323 require := require.New(t) 324 325 // Create a Vault server 326 v := testutil.NewTestVaultFromPath(t, vault) 327 defer v.Stop() 328 329 token := setupVault(t, v.Client, version) 330 331 // Create a Nomad agent using the created vault 332 nomad := agent.NewTestAgent(t, t.Name(), func(c *agent.Config) { 333 if c.Vault == nil { 334 c.Vault = &config.VaultConfig{} 335 } 336 c.Vault.Enabled = pointer.Of(true) 337 c.Vault.Token = token 338 c.Vault.Role = "nomad-cluster" 339 c.Vault.AllowUnauthenticated = pointer.Of(true) 340 c.Vault.Addr = v.HTTPAddr 341 }) 342 defer nomad.Shutdown() 343 344 // Submit the Nomad job that requests a Vault token and cats that the Vault 345 // token is there 346 c := nomad.Client() 347 j := c.Jobs() 348 _, _, err := j.Register(job, nil) 349 require.NoError(err) 350 351 // Wait for there to be an allocation terminated successfully 352 //var allocID string 353 testutil.WaitForResult(func() (bool, error) { 354 // Get the allocations for the job 355 allocs, _, err := j.Allocations(*job.ID, false, nil) 356 if err != nil { 357 return false, err 358 } 359 l := len(allocs) 360 switch l { 361 case 0: 362 return false, fmt.Errorf("want one alloc; got zero") 363 case 1: 364 default: 365 // exit early 366 require.Fail("too many allocations; something failed") 367 } 368 alloc := allocs[0] 369 //allocID = alloc.ID 370 if alloc.ClientStatus == "complete" { 371 return true, nil 372 } 373 374 return false, fmt.Errorf("client status %q", alloc.ClientStatus) 375 }, func(err error) { 376 require.NoError(err, "allocation did not finish") 377 }) 378 379 } 380 381 // setupVault takes the Vault client and creates the required policies and 382 // roles. It returns the token that should be used by Nomad 383 func setupVault(t *testing.T, client *vapi.Client, vaultVersion string) string { 384 // Write the policy 385 sys := client.Sys() 386 387 // pre-0.9.0 vault servers do not work with our new vault client for the policy endpoint 388 // perform this using a raw HTTP request 389 newApi := version.Must(version.NewVersion("0.9.0")) 390 testVersion := version.Must(version.NewVersion(vaultVersion)) 391 if testVersion.LessThan(newApi) { 392 body := map[string]string{ 393 "rules": policy, 394 } 395 request := client.NewRequest("PUT", "/v1/sys/policy/nomad-server") 396 if err := request.SetJSONBody(body); err != nil { 397 require.NoError(t, err, "failed to set JSON body on legacy policy creation") 398 } 399 if _, err := client.RawRequest(request); err != nil { 400 require.NoError(t, err, "failed to create legacy policy") 401 } 402 } else { 403 if err := sys.PutPolicy("nomad-server", policy); err != nil { 404 require.NoError(t, err, "failed to create policy") 405 } 406 } 407 408 // Build the role 409 l := client.Logical() 410 l.Write("auth/token/roles/nomad-cluster", role) 411 412 // Create a new token with the role 413 a := client.Auth().Token() 414 req := vapi.TokenCreateRequest{ 415 Policies: []string{"nomad-server"}, 416 Period: "72h", 417 NoParent: true, 418 } 419 s, err := a.Create(&req) 420 if err != nil { 421 require.NoError(t, err, "failed to create child token") 422 } 423 424 // Get the client token 425 if s == nil || s.Auth == nil { 426 require.NoError(t, err, "bad secret response") 427 } 428 429 return s.Auth.ClientToken 430 }