volcano.sh/volcano@v1.9.0/pkg/controllers/job/plugins/ssh/ssh.go (about)

     1  /*
     2  Copyright 2019 The Volcano Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package ssh
    18  
import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"flag"
	"fmt"
	"strings"

	"golang.org/x/crypto/ssh"
	v1 "k8s.io/api/core/v1"
	"k8s.io/klog/v2"

	batch "volcano.sh/apis/pkg/apis/batch/v1alpha1"
	"volcano.sh/apis/pkg/apis/helpers"
	jobhelpers "volcano.sh/volcano/pkg/controllers/job/helpers"
	pluginsinterface "volcano.sh/volcano/pkg/controllers/job/plugins/interface"
)
    36  
    37  type sshPlugin struct {
    38  	// Arguments given for the plugin
    39  	pluginArguments []string
    40  
    41  	client pluginsinterface.PluginClientset
    42  
    43  	// flag parse args
    44  	sshKeyFilePath string
    45  
    46  	// private key string
    47  	sshPrivateKey string
    48  
    49  	// public key string
    50  	sshPublicKey string
    51  }
    52  
    53  // New creates ssh plugin
    54  func New(client pluginsinterface.PluginClientset, arguments []string) pluginsinterface.PluginInterface {
    55  	p := sshPlugin{
    56  		pluginArguments: arguments,
    57  		client:          client,
    58  		sshKeyFilePath:  SSHAbsolutePath,
    59  	}
    60  
    61  	p.addFlags()
    62  
    63  	return &p
    64  }
    65  
    66  func (sp *sshPlugin) Name() string {
    67  	return "ssh"
    68  }
    69  
    70  func (sp *sshPlugin) OnPodCreate(pod *v1.Pod, job *batch.Job) error {
    71  	sp.mountRsaKey(pod, job)
    72  
    73  	return nil
    74  }
    75  
    76  func (sp *sshPlugin) OnJobAdd(job *batch.Job) error {
    77  	if job.Status.ControlledResources["plugin-"+sp.Name()] == sp.Name() {
    78  		return nil
    79  	}
    80  
    81  	var data map[string][]byte
    82  	var err error
    83  	if len(sp.sshPrivateKey) > 0 {
    84  		data, err = withUserProvidedRsaKey(job, sp.sshPrivateKey, sp.sshPublicKey)
    85  	} else {
    86  		data, err = generateRsaKey(job)
    87  	}
    88  	if err != nil {
    89  		return err
    90  	}
    91  
    92  	if err := helpers.CreateOrUpdateSecret(job, sp.client.KubeClients, data, sp.secretName(job)); err != nil {
    93  		return fmt.Errorf("create secret for job <%s/%s> with ssh plugin failed for %v",
    94  			job.Namespace, job.Name, err)
    95  	}
    96  
    97  	job.Status.ControlledResources["plugin-"+sp.Name()] = sp.Name()
    98  
    99  	return nil
   100  }
   101  
   102  func (sp *sshPlugin) OnJobDelete(job *batch.Job) error {
   103  	if job.Status.ControlledResources["plugin-"+sp.Name()] != sp.Name() {
   104  		return nil
   105  	}
   106  	if err := helpers.DeleteSecret(job, sp.client.KubeClients, sp.secretName(job)); err != nil {
   107  		return err
   108  	}
   109  	delete(job.Status.ControlledResources, "plugin-"+sp.Name())
   110  
   111  	return nil
   112  }
   113  
   114  // TODO: currently a container using a Secret as a subPath volume mount will not receive Secret updates.
   115  // we may not update the job secret due to the above reason now.
   116  // related issue: https://github.com/volcano-sh/volcano/issues/1420
   117  func (sp *sshPlugin) OnJobUpdate(job *batch.Job) error {
   118  	//data, err := generateRsaKey(job)
   119  	//if err != nil {
   120  	//	return err
   121  	//}
   122  	//
   123  	//if err := helpers.CreateOrUpdateSecret(job, sp.client.KubeClients, data, sp.secretName(job)); err != nil {
   124  	//	return fmt.Errorf("update secret for job <%s/%s> with ssh plugin failed for %v",
   125  	//		job.Namespace, job.Name, err)
   126  	//}
   127  
   128  	return nil
   129  }
   130  
   131  func (sp *sshPlugin) mountRsaKey(pod *v1.Pod, job *batch.Job) {
   132  	secretName := sp.secretName(job)
   133  
   134  	sshVolume := v1.Volume{
   135  		Name: secretName,
   136  	}
   137  
   138  	var mode int32 = 0600
   139  	sshVolume.Secret = &v1.SecretVolumeSource{
   140  		SecretName: secretName,
   141  		Items: []v1.KeyToPath{
   142  			{
   143  				Key:  SSHPrivateKey,
   144  				Path: SSHRelativePath + "/" + SSHPrivateKey,
   145  			},
   146  			{
   147  				Key:  SSHPublicKey,
   148  				Path: SSHRelativePath + "/" + SSHPublicKey,
   149  			},
   150  			{
   151  				Key:  SSHAuthorizedKeys,
   152  				Path: SSHRelativePath + "/" + SSHAuthorizedKeys,
   153  			},
   154  			{
   155  				Key:  SSHConfig,
   156  				Path: SSHRelativePath + "/" + SSHConfig,
   157  			},
   158  		},
   159  		DefaultMode: &mode,
   160  	}
   161  
   162  	if sp.sshKeyFilePath != SSHAbsolutePath {
   163  		var noRootMode int32 = 0644
   164  		sshVolume.Secret.DefaultMode = &noRootMode
   165  	}
   166  
   167  	pod.Spec.Volumes = append(pod.Spec.Volumes, sshVolume)
   168  
   169  	for i, c := range pod.Spec.Containers {
   170  		vm := v1.VolumeMount{
   171  			MountPath: sp.sshKeyFilePath,
   172  			SubPath:   SSHRelativePath,
   173  			Name:      secretName,
   174  		}
   175  
   176  		pod.Spec.Containers[i].VolumeMounts = append(c.VolumeMounts, vm)
   177  	}
   178  	for i, c := range pod.Spec.InitContainers {
   179  		vm := v1.VolumeMount{
   180  			MountPath: sp.sshKeyFilePath,
   181  			SubPath:   SSHRelativePath,
   182  			Name:      secretName,
   183  		}
   184  
   185  		pod.Spec.InitContainers[i].VolumeMounts = append(c.VolumeMounts, vm)
   186  	}
   187  }
   188  
   189  func generateRsaKey(job *batch.Job) (map[string][]byte, error) {
   190  	bitSize := 2048
   191  
   192  	privateKey, err := rsa.GenerateKey(rand.Reader, bitSize)
   193  	if err != nil {
   194  		klog.Errorf("rsa generateKey err: %v", err)
   195  		return nil, err
   196  	}
   197  
   198  	// id_rsa
   199  	privBlock := pem.Block{
   200  		Type:  "RSA PRIVATE KEY",
   201  		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
   202  	}
   203  	privateKeyBytes := pem.EncodeToMemory(&privBlock)
   204  
   205  	// id_rsa.pub
   206  	publicRsaKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
   207  	if err != nil {
   208  		klog.Errorf("ssh newPublicKey err: %v", err)
   209  		return nil, err
   210  	}
   211  	publicKeyBytes := ssh.MarshalAuthorizedKey(publicRsaKey)
   212  
   213  	data := make(map[string][]byte)
   214  	data[SSHPrivateKey] = privateKeyBytes
   215  	data[SSHPublicKey] = publicKeyBytes
   216  	data[SSHAuthorizedKeys] = publicKeyBytes
   217  	data[SSHConfig] = []byte(generateSSHConfig(job))
   218  
   219  	return data, nil
   220  }
   221  
   222  func withUserProvidedRsaKey(job *batch.Job, sshPrivateKey string, sshPublicKey string) (map[string][]byte, error) {
   223  	data := make(map[string][]byte)
   224  	data[SSHPrivateKey] = []byte(sshPrivateKey)
   225  	data[SSHPublicKey] = []byte(sshPublicKey)
   226  	data[SSHAuthorizedKeys] = []byte(sshPublicKey)
   227  	data[SSHConfig] = []byte(generateSSHConfig(job))
   228  
   229  	return data, nil
   230  }
   231  
   232  func (sp *sshPlugin) secretName(job *batch.Job) string {
   233  	return fmt.Sprintf("%s-%s", job.Name, sp.Name())
   234  }
   235  
   236  func (sp *sshPlugin) addFlags() {
   237  	flagSet := flag.NewFlagSet(sp.Name(), flag.ContinueOnError)
   238  	flagSet.StringVar(&sp.sshKeyFilePath, "ssh-key-file-path", sp.sshKeyFilePath, "The path used to store "+
   239  		"ssh private and public keys, it is `/root/.ssh` by default.")
   240  	flagSet.StringVar(&sp.sshPrivateKey, "ssh-private-key", sp.sshPrivateKey, "The input string of the private key")
   241  	flagSet.StringVar(&sp.sshPublicKey, "ssh-public-key", sp.sshPublicKey, "The input string of the public key")
   242  
   243  	if err := flagSet.Parse(sp.pluginArguments); err != nil {
   244  		klog.Errorf("plugin %s flagset parse failed, err: %v", sp.Name(), err)
   245  	}
   246  }
   247  
   248  func generateSSHConfig(job *batch.Job) string {
   249  	config := "StrictHostKeyChecking no\nUserKnownHostsFile /dev/null\n"
   250  
   251  	for _, ts := range job.Spec.Tasks {
   252  		for i := 0; i < int(ts.Replicas); i++ {
   253  			hostName := ts.Template.Spec.Hostname
   254  			subdomain := ts.Template.Spec.Subdomain
   255  			if len(hostName) == 0 {
   256  				hostName = jobhelpers.MakePodName(job.Name, ts.Name, i)
   257  			}
   258  			if len(subdomain) == 0 {
   259  				subdomain = job.Name
   260  			}
   261  
   262  			config += "Host " + hostName + "\n"
   263  			config += "  HostName " + hostName + "." + subdomain + "\n"
   264  			if len(ts.Template.Spec.Hostname) != 0 {
   265  				break
   266  			}
   267  		}
   268  	}
   269  
   270  	return config
   271  }