github.com/mfpierre/corectl@v0.5.6/helpers.go (about)

     1  // Copyright 2015 - António Meireles  <antonio.meireles@reformi.st>
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  
    16  package main
    17  
    18  import (
    19  	"bytes"
    20  	"crypto/rand"
    21  	"crypto/rsa"
    22  	"crypto/sha512"
    23  	"crypto/x509"
    24  	"encoding/binary"
    25  	"encoding/hex"
    26  	"encoding/pem"
    27  	"fmt"
    28  	"io"
    29  	"io/ioutil"
    30  	"log"
    31  	"net"
    32  	"net/http"
    33  	"os"
    34  	"os/exec"
    35  	"os/user"
    36  	"path/filepath"
    37  	"regexp"
    38  	"strconv"
    39  	"strings"
    40  	"sync"
    41  
    42  	"github.com/blang/semver"
    43  	"github.com/rakyll/pb"
    44  	"github.com/spf13/viper"
    45  	// until github.com/mitchellh/go-ps consumes it
    46  	"github.com/yeonsh/go-ps"
    47  	"golang.org/x/crypto/openpgp"
    48  	"golang.org/x/crypto/openpgp/clearsign"
    49  	"golang.org/x/crypto/ssh"
    50  )
    51  
    52  // (recursively) fix permissions on path
    53  func normalizeOnDiskPermissions(path string) (err error) {
    54  	if !engine.hasPowers {
    55  		return
    56  	}
    57  	u, _ := strconv.Atoi(engine.uid)
    58  	g, _ := strconv.Atoi(engine.gid)
    59  
    60  	action := func(p string, _ os.FileInfo, _ error) error {
    61  		return os.Chown(p, u, g)
    62  	}
    63  	return filepath.Walk(path, action)
    64  }
    65  
    66  func pSlice(plain []string) []string {
    67  	var sliced []string
    68  	for _, x := range plain {
    69  		strip := strings.Replace(strings.Replace(x, "]", "", -1), "[", "", -1)
    70  		for _, y := range strings.Split(strip, ",") {
    71  			sliced = append(sliced, y)
    72  		}
    73  	}
    74  	return sliced
    75  }
    76  
    77  func downloadAndVerify(channel,
    78  	version string) (l map[string]string, err error) {
    79  	var (
    80  		prefix = "coreos_production_pxe"
    81  		root   = fmt.Sprintf("http://%s.release.core-os.net/amd64-usr/%s/",
    82  			channel, version)
    83  		files = []string{fmt.Sprintf("%s.vmlinuz", prefix),
    84  			fmt.Sprintf("%s_image.cpio.gz", prefix)}
    85  		signature = fmt.Sprintf("%s%s%s",
    86  			root, prefix, "_image.cpio.gz.DIGESTS.asc")
    87  		token                                     []string
    88  		tmpDir, digestTxt, fileName, bzHashSHA512 string
    89  		output                                    *os.File
    90  		digestRaw, longIDdecoded                  []byte
    91  		r, digest                                 *http.Response
    92  		longIDdecodedInt                          uint64
    93  		keyring                                   openpgp.EntityList
    94  		check                                     *openpgp.Entity
    95  		messageClear                              *clearsign.Block
    96  		messageClearRdr                           *bytes.Reader
    97  		re                                        = regexp.MustCompile(
    98  			`(?m)(?P<method>(SHA1|SHA512)) HASH(?:\r?)\n(?P<hash>` +
    99  				`.[^\s]*)\s*(?P<file>[\w\d_\.]*)`)
   100  		keymap   = make(map[string]int)
   101  		location = make(map[string]string)
   102  	)
   103  
   104  	log.Printf("downloading and verifying %s/%v\n", channel, version)
   105  	for _, target := range files {
   106  		url := fmt.Sprintf("%s%s", root, target)
   107  
   108  		if tmpDir, err = ioutil.TempDir(engine.tmpDir, "coreos"); err != nil {
   109  			return
   110  		}
   111  		defer func() {
   112  			if err != nil {
   113  				if e := os.RemoveAll(tmpDir); e != nil {
   114  					log.Println(e)
   115  				}
   116  			}
   117  		}()
   118  		token = strings.Split(url, "/")
   119  		fileName = token[len(token)-1]
   120  		pack := filepath.Join(tmpDir, "/", fileName)
   121  		if _, err = http.Head(url); err != nil {
   122  			return
   123  		}
   124  		if digest, err = http.Get(signature); err != nil {
   125  			return
   126  		}
   127  		defer digest.Body.Close()
   128  		switch digest.StatusCode {
   129  		case http.StatusOK, http.StatusNoContent:
   130  		default:
   131  			return l, fmt.Errorf("failed fetching %s: HTTP status: %s",
   132  				signature, digest.Status)
   133  		}
   134  		if digestRaw, err = ioutil.ReadAll(digest.Body); err != nil {
   135  			return
   136  		}
   137  		if longIDdecoded, err = hex.DecodeString(GPGLongID); err != nil {
   138  			return
   139  		}
   140  		longIDdecodedInt = binary.BigEndian.Uint64(longIDdecoded)
   141  		if engine.debug {
   142  			fmt.Printf("Trusted hex key id %s is decimal %d\n",
   143  				GPGLongID, longIDdecoded)
   144  		}
   145  		if keyring, err = openpgp.ReadArmoredKeyRing(
   146  			bytes.NewBufferString(GPGKey)); err != nil {
   147  			return
   148  		}
   149  		messageClear, _ = clearsign.Decode(digestRaw)
   150  		digestTxt = string(messageClear.Bytes)
   151  		messageClearRdr = bytes.NewReader(messageClear.Bytes)
   152  		if check, err =
   153  			openpgp.CheckDetachedSignature(keyring, messageClearRdr,
   154  				messageClear.ArmoredSignature.Body); err != nil {
   155  			return l, fmt.Errorf("Signature check for DIGESTS failed.")
   156  		}
   157  		if check.PrimaryKey.KeyId == longIDdecodedInt {
   158  			if engine.debug {
   159  				fmt.Printf("Trusted key id %d matches keyid %d\n",
   160  					longIDdecodedInt, longIDdecodedInt)
   161  			}
   162  		}
   163  		if engine.debug {
   164  			fmt.Printf("DIGESTS signature OK. ")
   165  		}
   166  
   167  		for index, name := range re.SubexpNames() {
   168  			keymap[name] = index
   169  		}
   170  
   171  		matches := re.FindAllStringSubmatch(digestTxt, -1)
   172  
   173  		for _, match := range matches {
   174  			if match[keymap["file"]] == fileName {
   175  				if match[keymap["method"]] == "SHA512" {
   176  					bzHashSHA512 = match[keymap["hash"]]
   177  				}
   178  			}
   179  		}
   180  
   181  		sha512h := sha512.New()
   182  
   183  		if r, err = http.Get(url); err != nil {
   184  			return
   185  		}
   186  		defer r.Body.Close()
   187  		switch r.StatusCode {
   188  		case http.StatusOK, http.StatusNoContent:
   189  		default:
   190  			return l, fmt.Errorf("failed fetching %s: HTTP status: %s",
   191  				signature, r.Status)
   192  		}
   193  		bar := pb.New(int(r.ContentLength)).SetUnits(pb.U_BYTES)
   194  		bar.Start()
   195  
   196  		if output, err = os.Create(pack); err != nil {
   197  			return
   198  		}
   199  		defer output.Close()
   200  
   201  		writer := io.MultiWriter(sha512h, bar, output)
   202  		io.Copy(writer, r.Body)
   203  		bar.Finish()
   204  		if hex.EncodeToString(sha512h.Sum([]byte{})) != bzHashSHA512 {
   205  			return l, fmt.Errorf("SHA512 hash verification failed for %s",
   206  				fileName)
   207  		}
   208  		log.Printf("SHA512 hash for %s OK\n", fileName)
   209  
   210  		location[fileName] = pack
   211  	}
   212  	return location, err
   213  }
   214  
   215  // sshKeyGen creates a one-time ssh public and private key pair
   216  func sshKeyGen() (a string, b string, err error) {
   217  	var (
   218  		public ssh.PublicKey
   219  		secret *rsa.PrivateKey
   220  	)
   221  
   222  	if secret, err = rsa.GenerateKey(rand.Reader, 2014); err != nil {
   223  		return
   224  	}
   225  
   226  	secretDer := x509.MarshalPKCS1PrivateKey(secret)
   227  	secretBlk := pem.Block{
   228  		Type: "RSA PRIVATE KEY", Headers: nil, Bytes: secretDer,
   229  	}
   230  	if public, err = ssh.NewPublicKey(&secret.PublicKey); err != nil {
   231  		return
   232  	}
   233  
   234  	return string(pem.EncodeToMemory(&secretBlk)),
   235  		string(ssh.MarshalAuthorizedKey(public)), err
   236  }
   237  
   238  func (session *sessionContext) init() (err error) {
   239  	var (
   240  		caller              *user.User
   241  		usr                 string
   242  		netMask, netAddress []byte
   243  		cmdL                = []string{
   244  			"defaults", "read",
   245  			"/Library/Preferences/SystemConfiguration/com.apple.vmnet.plist",
   246  		}
   247  	)
   248  	// viper & cobra
   249  	session.rawArgs = viper.New()
   250  	session.rawArgs.SetEnvPrefix("COREOS")
   251  	session.rawArgs.AutomaticEnv()
   252  	session.rawArgs.BindPFlags(RootCmd.PersistentFlags())
   253  	session.debug = session.rawArgs.GetBool("debug")
   254  
   255  	if uid := os.Geteuid(); uid == 0 {
   256  		if usr = os.Getenv("SUDO_USER"); usr == "" {
   257  			return fmt.Errorf("Do not run this as 'root' user," +
   258  				"but as a regular user via 'sudo'")
   259  		}
   260  		if caller, err = user.Lookup(usr); err != nil {
   261  			return
   262  		}
   263  		session.hasPowers = true
   264  	} else {
   265  		session.hasPowers = false
   266  		if caller, err = user.Current(); err != nil {
   267  			return
   268  		}
   269  	}
   270  
   271  	if netAddress, err = exec.Command(cmdL[0],
   272  		append(cmdL[1:], "Shared_Net_Address")...).Output(); err != nil {
   273  		return
   274  	}
   275  
   276  	if netMask, err = exec.Command(cmdL[0],
   277  		append(cmdL[1:], "Shared_Net_Mask")...).Output(); err != nil {
   278  		return
   279  	}
   280  
   281  	session.address = strings.TrimSpace(string(netAddress))
   282  	session.netmask = strings.TrimSpace(string(netMask))
   283  	session.network = net.ParseIP(session.address).Mask(net.IPMask(net.ParseIP(
   284  		session.netmask).To4())).String()
   285  
   286  	session.configDir = filepath.Join(caller.HomeDir, "/.coreos/")
   287  	session.imageDir = filepath.Join(session.configDir, "/images/")
   288  	session.runDir = filepath.Join(session.configDir, "/running/")
   289  	session.tmpDir = filepath.Join(session.configDir, "/tmp/")
   290  
   291  	session.uid, session.gid = caller.Uid, caller.Gid
   292  	session.homedir = caller.HomeDir
   293  
   294  	if session.pwd, err = os.Getwd(); err != nil {
   295  		return
   296  	}
   297  
   298  	for _, i := range DefaultChannels {
   299  		if err =
   300  			os.MkdirAll(filepath.Join(session.imageDir, i), 0755); err != nil {
   301  			return
   302  		}
   303  	}
   304  
   305  	if err = os.MkdirAll(session.runDir, 0755); err != nil {
   306  		return
   307  	}
   308  	if err = os.MkdirAll(session.tmpDir, 0755); err != nil {
   309  		return
   310  	}
   311  	return normalizeOnDiskPermissions(session.configDir)
   312  }
   313  
   314  func (session *sessionContext) allowedToRun() (err error) {
   315  	if !session.hasPowers {
   316  		return fmt.Errorf("not enough previleges to start or forcefully " +
   317  			"halt VMs. use 'sudo'")
   318  	}
   319  	return
   320  }
   321  
   322  func normalizeChannelName(channel string) string {
   323  	for _, b := range DefaultChannels {
   324  		if b == channel {
   325  			return channel
   326  		}
   327  	}
   328  	log.Printf("'%s' is not a recognized CoreOS image channel. %s",
   329  		channel, "Using default ('alpha').")
   330  	return "alpha"
   331  }
   332  
   333  func normalizeVersion(version string) string {
   334  	if version == "latest" {
   335  		return version
   336  	}
   337  	if _, err := semver.Parse(version); err != nil {
   338  		log.Printf("'%s' is not in a recognizable CoreOS version format. %s",
   339  			version, "Using default ('latest') instead")
   340  		return "latest"
   341  	}
   342  	return version
   343  }
   344  
   345  func (vm *VMInfo) isActive() bool {
   346  	if p, _ := ps.FindProcess(vm.Pid); p == nil ||
   347  		!strings.HasSuffix(p.Executable(), "corectl") {
   348  		return false
   349  	}
   350  	return true
   351  }
   352  
   353  func (vm *VMInfo) metadataService() (endpoint string, err error) {
   354  	var (
   355  		free         net.Listener
   356  		foundGuestIP sync.Once
   357  		mux, root    = http.NewServeMux(), "/" + vm.Name
   358  		rIP          = func(s string) string { return strings.Split(s, ":")[0] }
   359  		netcfg       = net.IPNet{
   360  			IP:   net.ParseIP(engine.address),
   361  			Mask: net.IPMask(net.ParseIP(engine.netmask).To4()),
   362  		}
   363  		isAllowed = func(origin string, w http.ResponseWriter) bool {
   364  			if netcfg.Contains(net.ParseIP(origin)) {
   365  				w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   366  				w.WriteHeader(http.StatusOK)
   367  				return true
   368  			}
   369  			w.WriteHeader(http.StatusPreconditionFailed)
   370  			w.Write(nil)
   371  			return false
   372  		}
   373  	)
   374  
   375  	if free, err = net.Listen("tcp", "127.0.0.1:0"); err != nil {
   376  		return
   377  	}
   378  
   379  	if vm.CloudConfig != "" && vm.CClocation == Local {
   380  		var txt []byte
   381  		if txt, err = ioutil.ReadFile(vm.CloudConfig); err != nil {
   382  			return
   383  		}
   384  
   385  		mux.HandleFunc(root+"/cloud-config",
   386  			func(w http.ResponseWriter, r *http.Request) {
   387  				if isAllowed(rIP(r.RemoteAddr), w) {
   388  					w.Write(txt)
   389  					foundGuestIP.Do(func() {
   390  						vm.publicIP <- rIP(r.RemoteAddr)
   391  					})
   392  				}
   393  			})
   394  	}
   395  
   396  	mux.HandleFunc(root+"/sshKey",
   397  		func(w http.ResponseWriter, r *http.Request) {
   398  			if isAllowed(rIP(r.RemoteAddr), w) {
   399  				w.Write([]byte(vm.InternalSSHauthKey))
   400  				if !(vm.CloudConfig != "" && vm.CClocation == Local) {
   401  					foundGuestIP.Do(func() {
   402  						vm.publicIP <- rIP(r.RemoteAddr)
   403  					})
   404  				}
   405  			}
   406  		})
   407  	mux.HandleFunc(root+"/hostname",
   408  		func(w http.ResponseWriter, r *http.Request) {
   409  			if isAllowed(rIP(r.RemoteAddr), w) {
   410  				w.Write([]byte(vm.Name))
   411  			}
   412  		})
   413  	mux.HandleFunc(root+"/homedir",
   414  		func(w http.ResponseWriter, r *http.Request) {
   415  			if isAllowed(rIP(r.RemoteAddr), w) {
   416  				w.Write([]byte(engine.homedir))
   417  			}
   418  		})
   419  	mux.HandleFunc(root+"/nfs",
   420  		func(w http.ResponseWriter, r *http.Request) {
   421  			if isAllowed(rIP(r.RemoteAddr), w) {
   422  				w.Write([]byte(engine.address))
   423  			}
   424  		})
   425  
   426  	srv := &http.Server{
   427  		Addr:    fmt.Sprintf(":%v", free.Addr().(*net.TCPAddr).Port),
   428  		Handler: mux,
   429  	}
   430  	go func() {
   431  		defer free.Close()
   432  		srv.ListenAndServe()
   433  	}()
   434  
   435  	return fmt.Sprintf("http://%v%v%v", engine.address, srv.Addr, root), err
   436  }