github.com/mponton/terratest@v0.44.0/modules/ssh/agent.go (about)

     1  package ssh
     2  
     3  import (
     4  	"crypto/x509"
     5  	"encoding/pem"
     6  	"io"
     7  	"io/ioutil"
     8  	"net"
     9  	"os"
    10  	"path/filepath"
    11  
    12  	"github.com/mponton/terratest/modules/logger"
    13  	"github.com/mponton/terratest/modules/testing"
    14  	"golang.org/x/crypto/ssh/agent"
    15  )
    16  
    17  type SshAgent struct {
    18  	stop       chan bool
    19  	stopped    chan bool
    20  	socketDir  string
    21  	socketFile string
    22  	agent      agent.Agent
    23  	ln         net.Listener
    24  }
    25  
    26  // Create SSH agent, start it in background and returns control back to the main thread
    27  // You should stop the agent to cleanup files afterwards by calling `defer s.Stop()`
    28  func NewSshAgent(t testing.TestingT, socketDir string, socketFile string) (*SshAgent, error) {
    29  	var err error
    30  	s := &SshAgent{make(chan bool), make(chan bool), socketDir, socketFile, agent.NewKeyring(), nil}
    31  	s.ln, err = net.Listen("unix", s.socketFile)
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  	go s.run(t)
    36  	return s, nil
    37  }
    38  
    39  // expose socketFile variable
    40  func (s *SshAgent) SocketFile() string {
    41  	return s.socketFile
    42  }
    43  
    44  // SSH Agent listener and handler
    45  func (s *SshAgent) run(t testing.TestingT) {
    46  	defer close(s.stopped)
    47  	for {
    48  		select {
    49  		case <-s.stop:
    50  			return
    51  		default:
    52  			c, err := s.ln.Accept()
    53  			if err != nil {
    54  				select {
    55  				// When s.Stop() closes the listener, s.ln.Accept() returns an error that can be ignored
    56  				// since the agent is in stopping process
    57  				case <-s.stop:
    58  					return
    59  					// When s.ln.Accept() returns a legit error, we print it and continue accepting further requests
    60  				default:
    61  					logger.Logf(t, "could not accept connection to agent %v", err)
    62  					continue
    63  				}
    64  			} else {
    65  				defer c.Close()
    66  				go func(c io.ReadWriter) {
    67  					err := agent.ServeAgent(s.agent, c)
    68  					if err != nil {
    69  						logger.Logf(t, "could not serve ssh agent %v", err)
    70  					}
    71  				}(c)
    72  			}
    73  		}
    74  	}
    75  }
    76  
    77  // Stop and clean up SSH agent
    78  func (s *SshAgent) Stop() {
    79  	close(s.stop)
    80  	s.ln.Close()
    81  	<-s.stopped
    82  	os.RemoveAll(s.socketDir)
    83  }
    84  
    85  // Instantiates and returns an in-memory ssh agent with the given KeyPair already added
    86  // You should stop the agent to cleanup files afterwards by calling `defer sshAgent.Stop()`
    87  func SshAgentWithKeyPair(t testing.TestingT, keyPair *KeyPair) *SshAgent {
    88  	sshAgent, err := SshAgentWithKeyPairE(t, keyPair)
    89  
    90  	if err != nil {
    91  		t.Fatal(err)
    92  	}
    93  
    94  	return sshAgent
    95  }
    96  
    97  func SshAgentWithKeyPairE(t testing.TestingT, keyPair *KeyPair) (*SshAgent, error) {
    98  	sshAgent, err := SshAgentWithKeyPairsE(t, []*KeyPair{keyPair})
    99  	return sshAgent, err
   100  }
   101  
   102  func SshAgentWithKeyPairs(t testing.TestingT, keyPairs []*KeyPair) *SshAgent {
   103  	sshAgent, err := SshAgentWithKeyPairsE(t, keyPairs)
   104  
   105  	if err != nil {
   106  		t.Fatal(err)
   107  	}
   108  
   109  	return sshAgent
   110  }
   111  
   112  // Instantiates and returns an in-memory ssh agent with the given KeyPair(s) already added
   113  // You should stop the agent to cleanup files afterwards by calling `defer sshAgent.Stop()`
   114  func SshAgentWithKeyPairsE(t testing.TestingT, keyPairs []*KeyPair) (*SshAgent, error) {
   115  	logger.Logf(t, "Generating SSH Agent with given KeyPair(s)")
   116  
   117  	// Instantiate a temporary SSH agent
   118  	socketDir, err := ioutil.TempDir("", "ssh-agent-")
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	socketFile := filepath.Join(socketDir, "ssh_auth.sock")
   123  	sshAgent, err := NewSshAgent(t, socketDir, socketFile)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	// add given ssh keys to the newly created agent
   129  	for _, keyPair := range keyPairs {
   130  		// Create SSH key for the agent using the given SSH key pair(s)
   131  		block, _ := pem.Decode([]byte(keyPair.PrivateKey))
   132  		privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
   133  		if err != nil {
   134  			return nil, err
   135  		}
   136  		key := agent.AddedKey{PrivateKey: privateKey}
   137  		sshAgent.agent.Add(key)
   138  	}
   139  
   140  	return sshAgent, err
   141  }