github.com/coreos/mantle@v0.13.0/update/updater.go (about)

     1  // Copyright 2016 CoreOS, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package update
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/sha256"
    20  	"fmt"
    21  	"io"
    22  	"os"
    23  
    24  	"github.com/coreos/pkg/capnslog"
    25  	"github.com/golang/protobuf/proto"
    26  
    27  	"github.com/coreos/mantle/update/metadata"
    28  )
    29  
    30  var (
    31  	plog = capnslog.NewPackageLogger("github.com/coreos/mantle", "update")
    32  )
    33  
    34  type Updater struct {
    35  	SrcPartition string
    36  	DstPartition string
    37  
    38  	payload *Payload
    39  }
    40  
    41  func (u *Updater) OpenPayload(file string) error {
    42  	plog.Infof("Loading payload from %s", file)
    43  
    44  	f, err := os.Open(file)
    45  	if err != nil {
    46  		return err
    47  	}
    48  
    49  	return u.UsePayload(f)
    50  }
    51  
    52  func (u *Updater) UsePayload(r io.Reader) (err error) {
    53  	u.payload, err = NewPayloadFrom(r)
    54  	return err
    55  }
    56  
    57  func (u *Updater) Update() error {
    58  	for _, proc := range u.payload.Procedures() {
    59  		var err error
    60  		switch proc.GetType() {
    61  		case installProcedure_partition:
    62  			err = u.UpdatePartition(proc)
    63  		case metadata.InstallProcedure_KERNEL:
    64  			err = u.UpdateKernel(proc)
    65  		default:
    66  			err = fmt.Errorf("unknown procedure type %s", proc.GetType())
    67  		}
    68  		if err != nil {
    69  			return err
    70  		}
    71  	}
    72  	return u.payload.VerifySignature()
    73  }
    74  
    75  func (u *Updater) UpdatePartition(proc *metadata.InstallProcedure) error {
    76  	return u.updateCommon(proc, "partition", u.SrcPartition, u.DstPartition)
    77  }
    78  
    79  func (u *Updater) UpdateKernel(proc *metadata.InstallProcedure) error {
    80  	return fmt.Errorf("KERNEL")
    81  }
    82  
    83  func (u *Updater) updateCommon(proc *metadata.InstallProcedure, procName, srcPath, dstPath string) (err error) {
    84  	var srcFile, dstFile *os.File
    85  	if proc.OldInfo.GetSize() != 0 && len(proc.OldInfo.Hash) != 0 {
    86  		if srcFile, err = os.Open(srcPath); err != nil {
    87  			return err
    88  		}
    89  		defer srcFile.Close()
    90  
    91  		if err = VerifyInfo(srcFile, proc.OldInfo); err != nil {
    92  			return err
    93  		}
    94  	}
    95  
    96  	dstFile, err = os.OpenFile(u.DstPartition, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
    97  	if err != nil {
    98  		return err
    99  	}
   100  	defer dstFile.Close()
   101  
   102  	progress := 0
   103  	for _, op := range u.payload.Operations(proc) {
   104  		progress++
   105  		plog.Infof("%s operation %d", procName, progress)
   106  		if err := op.Apply(dstFile, srcFile); err != nil {
   107  			return fmt.Errorf("%s operation %d: %v\n%s",
   108  				procName, progress, err,
   109  				proto.MarshalTextString(op.Operation))
   110  		}
   111  	}
   112  
   113  	return VerifyInfo(dstFile, proc.NewInfo)
   114  }
   115  
   116  func VerifyInfo(file *os.File, info *metadata.InstallInfo) error {
   117  	if _, err := file.Seek(0, os.SEEK_SET); err != nil {
   118  		return err
   119  	}
   120  
   121  	sha := sha256.New()
   122  	if n, err := io.CopyN(sha, file, int64(info.GetSize())); err == io.EOF {
   123  		return fmt.Errorf("%s: expected %d bytes but read %d bytes",
   124  			file.Name(), info.GetSize(), n)
   125  	} else if err != nil {
   126  		return err
   127  	}
   128  
   129  	sum := sha.Sum(nil)
   130  	if !bytes.Equal(info.Hash, sum) {
   131  		return fmt.Errorf("%s: expected hash %x but got %x",
   132  			file.Name(), info.Hash, sum)
   133  	}
   134  
   135  	return nil
   136  }