gitee.com/mysnapcore/mysnapd@v0.1.0/cmd/snap-bootstrap/triggerwatch/triggerwatch.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package triggerwatch
    21  
    22  import (
    23  	"errors"
    24  	"fmt"
    25  	"os"
    26  	"os/signal"
    27  	"syscall"
    28  	"time"
    29  
    30  	"gitee.com/mysnapcore/mysnapd/logger"
    31  	"gitee.com/mysnapcore/mysnapd/osutil/udev/netlink"
    32  )
    33  
    34  type triggerProvider interface {
    35  	Open(filter triggerEventFilter, node string) (triggerDevice, error)
    36  	FindMatchingDevices(filter triggerEventFilter) ([]triggerDevice, error)
    37  }
    38  
    39  type triggerDevice interface {
    40  	WaitForTrigger(chan keyEvent)
    41  	String() string
    42  	Close()
    43  }
    44  
    45  type ueventConnection interface {
    46  	Connect(mode netlink.Mode) error
    47  	Close() error
    48  	Monitor(queue chan netlink.UEvent, errors chan error, matcher netlink.Matcher) func(time.Duration) bool
    49  }
    50  
    51  var (
    52  	// trigger mechanism
    53  	trigger       triggerProvider
    54  	getUEventConn = func() ueventConnection {
    55  		return &netlink.UEventConn{}
    56  	}
    57  
    58  	// wait for '1' to be pressed
    59  	triggerFilter = triggerEventFilter{Key: "KEY_1"}
    60  
    61  	ErrTriggerNotDetected     = errors.New("trigger not detected")
    62  	ErrNoMatchingInputDevices = errors.New("no matching input devices")
    63  )
    64  
    65  // Wait waits for a trigger on the available trigger devices for a given amount
    66  // of time. Returns nil if one was detected, ErrTriggerNotDetected if timeout
    67  // was hit, or other non-nil error.
    68  func Wait(timeout time.Duration, deviceTimeout time.Duration) error {
    69  	sigs := make(chan os.Signal, 1)
    70  	signal.Notify(sigs, syscall.SIGUSR1)
    71  	conn := getUEventConn()
    72  	if err := conn.Connect(netlink.UdevEvent); err != nil {
    73  		logger.Panicf("Unable to connect to Netlink Kobject UEvent socket")
    74  	}
    75  	defer conn.Close()
    76  
    77  	add := "add"
    78  	matcher := &netlink.RuleDefinitions{
    79  		Rules: []netlink.RuleDefinition{
    80  			{
    81  				Action: &add,
    82  				Env: map[string]string{
    83  					"SUBSYSTEM":         "input",
    84  					"ID_INPUT_KEYBOARD": "1",
    85  					"DEVNAME":           ".*",
    86  				},
    87  			},
    88  		},
    89  	}
    90  
    91  	ueventQueue := make(chan netlink.UEvent)
    92  	ueventErrors := make(chan error)
    93  	conn.Monitor(ueventQueue, ueventErrors, matcher)
    94  
    95  	if trigger == nil {
    96  		logger.Panicf("trigger is unset")
    97  	}
    98  
    99  	devices, err := trigger.FindMatchingDevices(triggerFilter)
   100  	if err != nil {
   101  		return fmt.Errorf("cannot list trigger devices: %v", err)
   102  	}
   103  
   104  	if devices == nil {
   105  		devices = make([]triggerDevice, 0)
   106  	}
   107  
   108  	logger.Noticef("waiting for trigger key: %v", triggerFilter.Key)
   109  
   110  	detectKeyCh := make(chan keyEvent, len(devices))
   111  	for _, dev := range devices {
   112  		go dev.WaitForTrigger(detectKeyCh)
   113  		defer dev.Close()
   114  	}
   115  	foundDevice := len(devices) != 0
   116  
   117  	start := time.Now()
   118  	for {
   119  		timePassed := time.Now().Sub(start)
   120  		relTimeout := timeout - timePassed
   121  		relDeviceTimeout := deviceTimeout - timePassed
   122  		select {
   123  		case kev := <-detectKeyCh:
   124  			if kev.Err != nil {
   125  				return kev.Err
   126  			}
   127  			// channel got closed without an error
   128  			logger.Noticef("%s: + got trigger key %v", kev.Dev, triggerFilter.Key)
   129  			return nil
   130  		case <-time.After(relTimeout):
   131  			return ErrTriggerNotDetected
   132  		case <-time.After(relDeviceTimeout):
   133  			if !foundDevice {
   134  				return ErrNoMatchingInputDevices
   135  			}
   136  		case uevent := <-ueventQueue:
   137  			dev, err := trigger.Open(triggerFilter, uevent.Env["DEVNAME"])
   138  			if err != nil {
   139  				logger.Noticef("ignoring device %s that cannot be opened: %v", uevent.Env["DEVNAME"], err)
   140  			} else if dev != nil {
   141  				foundDevice = true
   142  				defer dev.Close()
   143  				go dev.WaitForTrigger(detectKeyCh)
   144  			}
   145  		case <-sigs:
   146  			logger.Noticef("Switching root")
   147  			if err := syscall.Chdir("/sysroot"); err != nil {
   148  				return fmt.Errorf("Cannot change directory: %w", err)
   149  			}
   150  			if err := syscall.Chroot("/sysroot"); err != nil {
   151  				return fmt.Errorf("Cannot change root: %w", err)
   152  			}
   153  			if err := syscall.Chdir("/"); err != nil {
   154  				return fmt.Errorf("Cannot change directory: %w", err)
   155  			}
   156  		}
   157  	}
   158  }