gitlab.com/pidrakin/dotfiles-cli@v1.7.5/ssh_ultimate/authentication.go (about)

     1  package ssh_ultimate
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"github.com/kevinburke/ssh_config"
     7  	"github.com/mitchellh/go-homedir"
     8  	log "github.com/sirupsen/logrus"
     9  	"gitlab.com/pidrakin/dotfiles-cli/common/logging"
    10  	"gitlab.com/pidrakin/go/interactive"
    11  	"golang.org/x/crypto/ssh"
    12  	"golang.org/x/crypto/ssh/agent"
    13  	"io"
    14  	"net"
    15  	"os"
    16  	"strings"
    17  )
    18  
    19  func parseIdentityFile(sshConfig *ssh_config.Config, host string, identityFile string) (string, error) {
    20  	if identityFile != "" {
    21  		return identityFile, nil
    22  	}
    23  
    24  	if sshConfig != nil {
    25  		identityFile, err := sshConfig.Get(host, "IdentityFile")
    26  		if err != nil {
    27  			return "", err
    28  		}
    29  		if identityFile != "" {
    30  			return identityFile, nil
    31  		}
    32  	}
    33  
    34  	return "", nil
    35  }
    36  
    37  func parseIdentityAgent(sshConfig *ssh_config.Config, host string, identityAgent string) (string, error) {
    38  	var err error
    39  	if sshConfig != nil {
    40  		if identityAgent, err = sshConfig.Get(host, "IdentityAgent"); err != nil {
    41  			return "", err
    42  		}
    43  	}
    44  
    45  	return identityAgent, nil
    46  }
    47  
    48  func passwordAuthMethod(reader io.Reader, writer io.Writer, id string) ssh.AuthMethod {
    49  	return ssh.PasswordCallback(func() (string, error) {
    50  		response := os.Getenv("SSHPASS")
    51  		if response == "" {
    52  			_, err := fmt.Fprintf(writer, "[%s] Password required: ", id)
    53  			if err != nil {
    54  				return "", err
    55  			}
    56  			response, err = interactive.PromptLine(reader, true)
    57  			if err != nil {
    58  				return "", err
    59  			}
    60  			response = strings.TrimSuffix(response, "\n")
    61  			_, err = fmt.Fprintf(writer, "\n")
    62  			if err != nil {
    63  				return "", err
    64  			}
    65  		}
    66  		return response, nil
    67  	})
    68  }
    69  
    70  var agentClient agent.Agent
    71  
    72  func sshAgentAuthMethod(socket string) (ssh.AuthMethod, error) {
    73  	if agentClient == nil {
    74  		socket, err := homedir.Expand(socket)
    75  		if err != nil {
    76  			return nil, fmt.Errorf("[dotfiles rollout] failed to open SSH_AUTH_SOCK: %v", err)
    77  		}
    78  		conn, err := net.Dial("unix", socket)
    79  		if err != nil {
    80  			return nil, fmt.Errorf("[dotfiles rollout] failed to open SSH_AUTH_SOCK: %v", err)
    81  		}
    82  
    83  		agentClient = agent.NewClient(conn)
    84  	}
    85  	return ssh.PublicKeysCallback(agentClient.Signers), nil
    86  }
    87  
    88  func privateKeyAuthMethod(errWriter io.Writer, file string, keyPass string) (ssh.AuthMethod, error) {
    89  	defer logging.SetLogOutput(errWriter)()
    90  	var key ssh.Signer
    91  	var b []byte
    92  	normalizedPath, err := homedir.Expand(file)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	b, err = os.ReadFile(normalizedPath)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	if keyPass == "" {
   101  		key, err = ssh.ParsePrivateKey(b)
   102  	} else {
   103  		key, err = ssh.ParsePrivateKeyWithPassphrase(b, []byte(keyPass))
   104  	}
   105  
   106  	if err != nil {
   107  		var passphraseMissingError *ssh.PassphraseMissingError
   108  		if errors.As(err, &passphraseMissingError) {
   109  			log.Errorf("[dotfiles rollout] ssh identity file needs passphrase")
   110  		}
   111  		return nil, err
   112  	}
   113  	return ssh.PublicKeys(key), nil
   114  }
   115  
   116  func parseAuthMethods(reader io.Reader, writer io.Writer, errWriter io.Writer, sshConfig *ssh_config.Config, host string, identityFilePath string, identityKeyPass string, id string) ([]ssh.AuthMethod, error) {
   117  	var authMethods []ssh.AuthMethod
   118  
   119  	var err error
   120  
   121  	identityFilePath, err = parseIdentityFile(sshConfig, host, identityFilePath)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	if identityFilePath != "" {
   127  		authMethod, privErr := privateKeyAuthMethod(errWriter, identityFilePath, identityKeyPass)
   128  		if privErr != nil {
   129  			return nil, privErr
   130  		}
   131  		authMethods = append(authMethods, authMethod)
   132  	}
   133  
   134  	// TODO: not yet used; don't know how
   135  	//var identitiesOnly string
   136  	//if sshConfig != nil {
   137  	//	identitiesOnly, err = sshConfig.Get(host, "IdentitiesOnly")
   138  	//	if err != nil {
   139  	//		return nil, err
   140  	//	}
   141  	//}
   142  
   143  	identityAgent, err := parseIdentityAgent(sshConfig, host, os.Getenv("SSH_AUTH_SOCK"))
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  
   148  	if agentClient != nil || identityAgent != "" {
   149  		authMethod, agentErr := sshAgentAuthMethod(identityAgent)
   150  		if agentErr == nil {
   151  			authMethods = append(authMethods, authMethod)
   152  		}
   153  	}
   154  
   155  	var passwordAuthentication string
   156  	if sshConfig != nil {
   157  		passwordAuthentication, err = sshConfig.Get(host, "PasswordAuthentication")
   158  		if err != nil {
   159  			return nil, err
   160  		}
   161  	}
   162  
   163  	if passwordAuthentication != "no" {
   164  		authMethods = append(authMethods, passwordAuthMethod(reader, writer, id))
   165  	}
   166  
   167  	return authMethods, nil
   168  }