github.com/stolowski/snapd@v0.0.0-20210407085831-115137ce5a22/osutil/context.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package osutil
    21  
    22  import (
    23  	"context"
    24  	"io"
    25  	"os/exec"
    26  	"sync"
    27  	"sync/atomic"
    28  	"syscall"
    29  )
    30  
    31  // ContextWriter returns a discarding io.Writer which Write method
    32  // returns an error once the context is done.
    33  func ContextWriter(ctx context.Context) io.Writer {
    34  	return ctxWriter{ctx}
    35  }
    36  
    37  type ctxWriter struct {
    38  	ctx context.Context
    39  }
    40  
    41  func (w ctxWriter) Write(p []byte) (n int, err error) {
    42  	select {
    43  	case <-w.ctx.Done():
    44  		return 0, w.ctx.Err()
    45  	default:
    46  	}
    47  	return len(p), nil
    48  }
    49  
    50  // RunWithContext runs the given command, but kills it if the context
    51  // becomes done before the command finishes.
    52  func RunWithContext(ctx context.Context, cmd *exec.Cmd) error {
    53  	if err := ctx.Err(); err != nil {
    54  		return err
    55  	}
    56  
    57  	if err := cmd.Start(); err != nil {
    58  		return err
    59  	}
    60  
    61  	var ctxDone uint32
    62  	var wg sync.WaitGroup
    63  	waitDone := make(chan struct{})
    64  
    65  	wg.Add(1)
    66  	go func() {
    67  		select {
    68  		case <-ctx.Done():
    69  			atomic.StoreUint32(&ctxDone, 1)
    70  			cmd.Process.Kill()
    71  		case <-waitDone:
    72  		}
    73  		wg.Done()
    74  	}()
    75  
    76  	err := cmd.Wait()
    77  	close(waitDone)
    78  	wg.Wait()
    79  
    80  	if atomic.LoadUint32(&ctxDone) != 0 {
    81  		// do one last check to make sure the error from Wait is what we expect from Kill
    82  		if err, ok := err.(*exec.ExitError); ok {
    83  			if ws, ok := err.ProcessState.Sys().(syscall.WaitStatus); ok && ws.Signal() == syscall.SIGKILL {
    84  				return ctx.Err()
    85  			}
    86  		}
    87  	}
    88  	return err
    89  }