github.com/opentofu/opentofu@v1.7.1/internal/communicator/ssh/ssh_test.go (about)

     1  // Copyright (c) The OpenTofu Authors
     2  // SPDX-License-Identifier: MPL-2.0
     3  // Copyright (c) 2023 HashiCorp, Inc.
     4  // SPDX-License-Identifier: MPL-2.0
     5  
     6  package ssh
     7  
     8  import (
     9  	"bytes"
    10  	"crypto/rand"
    11  	"crypto/rsa"
    12  	"crypto/x509"
    13  	"encoding/pem"
    14  	"os"
    15  	"path/filepath"
    16  	"testing"
    17  
    18  	"golang.org/x/crypto/ssh"
    19  )
    20  
    21  // verify that we can locate public key data
    22  func TestFindKeyData(t *testing.T) {
    23  	// set up a test directory
    24  	td := t.TempDir()
    25  	cwd, err := os.Getwd()
    26  	if err != nil {
    27  		t.Fatal(err)
    28  	}
    29  	if err := os.Chdir(td); err != nil {
    30  		t.Fatal(err)
    31  	}
    32  	defer os.Chdir(cwd)
    33  
    34  	id := "provisioner_id"
    35  
    36  	pub := generateSSHKey(t, id)
    37  	pubData := pub.Marshal()
    38  
    39  	// backup the pub file, and replace it with a broken file to ensure we
    40  	// extract the public key from the private key.
    41  	if err := os.Rename(id+".pub", "saved.pub"); err != nil {
    42  		t.Fatal(err)
    43  	}
    44  	if err := os.WriteFile(id+".pub", []byte("not a public key"), 0600); err != nil {
    45  		t.Fatal(err)
    46  	}
    47  
    48  	foundData := findIDPublicKey(id)
    49  	if !bytes.Equal(foundData, pubData) {
    50  		t.Fatalf("public key %q does not match", foundData)
    51  	}
    52  
    53  	// move the pub file back, and break the private key file to simulate an
    54  	// encrypted private key
    55  	if err := os.Rename("saved.pub", id+".pub"); err != nil {
    56  		t.Fatal(err)
    57  	}
    58  
    59  	if err := os.WriteFile(id, []byte("encrypted private key"), 0600); err != nil {
    60  		t.Fatal(err)
    61  	}
    62  
    63  	foundData = findIDPublicKey(id)
    64  	if !bytes.Equal(foundData, pubData) {
    65  		t.Fatalf("public key %q does not match", foundData)
    66  	}
    67  
    68  	// check the file by path too
    69  	foundData = findIDPublicKey(filepath.Join(".", id))
    70  	if !bytes.Equal(foundData, pubData) {
    71  		t.Fatalf("public key %q does not match", foundData)
    72  	}
    73  }
    74  
    75  func generateSSHKey(t *testing.T, idFile string) ssh.PublicKey {
    76  	t.Helper()
    77  
    78  	priv, err := rsa.GenerateKey(rand.Reader, 2048)
    79  	if err != nil {
    80  		t.Fatal(err)
    81  	}
    82  
    83  	privFile, err := os.OpenFile(idFile, os.O_RDWR|os.O_CREATE, 0600)
    84  	if err != nil {
    85  		t.Fatal(err)
    86  	}
    87  	defer privFile.Close()
    88  	privPEM := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}
    89  	if err := pem.Encode(privFile, privPEM); err != nil {
    90  		t.Fatal(err)
    91  	}
    92  
    93  	// generate and write public key
    94  	pub, err := ssh.NewPublicKey(&priv.PublicKey)
    95  	if err != nil {
    96  		t.Fatal(err)
    97  	}
    98  
    99  	err = os.WriteFile(idFile+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
   100  	if err != nil {
   101  		t.Fatal(err)
   102  	}
   103  
   104  	return pub
   105  }