github.com/grailbio/base@v0.0.11/cmdutil/ticket_flags.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package cmdutil
     6  
     7  import (
     8  	"bufio"
     9  	"flag"
    10  	"os"
    11  	"path/filepath"
    12  	"strings"
    13  	"time"
    14  )
    15  
    16  // TicketFlags represents an implementation of flag.Value that can be used
    17  // to specify tickets either in comma separated form or by repeating the
    18  // same flag. That is, either:
    19  //   --flag=t1,t2,t3
    20  // and/or
    21  //   --flag=t1 --flag=t2 --flag=t3
    22  // and/or
    23  //   --flag=t1,t2 --flag=t3
    24  type TicketFlags struct {
    25  	set                bool
    26  	dedup              map[string]bool
    27  	fs                 *flag.FlagSet
    28  	ticketFlag, rcFlag string
    29  	Tickets            []string
    30  	TicketRCFile       string
    31  	ticketRCFlag       stringFlag
    32  	Timeout            time.Duration
    33  }
    34  
    35  // wrapper to catch explicit setting of a flag.
    36  type stringFlag struct {
    37  	set  bool
    38  	name string
    39  	val  *string
    40  }
    41  
    42  // Set implements flag.Value.
    43  func (sf *stringFlag) Set(v string) error {
    44  	sf.set = true
    45  	*sf.val = v
    46  	return nil
    47  }
    48  
    49  // String implements flag.Value.
    50  func (sf *stringFlag) String() string {
    51  	if sf.val == nil {
    52  		// called via flag.isZeroValue.
    53  		return ""
    54  	}
    55  	return *sf.val
    56  }
    57  
    58  // Set implements flag.Value.
    59  func (tf *TicketFlags) Set(v string) error {
    60  	if !tf.set {
    61  		// Clear any defaults if setting for the first time.
    62  		tf.Tickets = nil
    63  		tf.dedup = map[string]bool{}
    64  	}
    65  	for _, ps := range strings.Split(v, ",") {
    66  		if ps == "" {
    67  			continue
    68  		}
    69  		if !tf.dedup[ps] {
    70  			tf.Tickets = append(tf.Tickets, ps)
    71  		}
    72  		tf.dedup[ps] = true
    73  	}
    74  	tf.set = true
    75  	return nil
    76  }
    77  
    78  // setDefaults sets default ticket paths for the flag. These values are cleared
    79  // the first time the flag is explicitly parsed in the flag set.
    80  func (tf *TicketFlags) setDefaults(tickets []string) {
    81  	tf.Tickets = tickets
    82  	tf.dedup = map[string]bool{}
    83  	for _, t := range tickets {
    84  		tf.dedup[t] = true
    85  	}
    86  	tf.fs.Lookup(tf.ticketFlag).DefValue = strings.Join(tickets, ",")
    87  }
    88  
    89  // String implements flag.Value.
    90  func (tf *TicketFlags) String() string {
    91  	return strings.Join(tf.Tickets, ",")
    92  }
    93  
    94  // ReadEnvOrFile will attempt to obtain values for the tickets to use from
    95  // the environment or from a file if none have been explicitly set on the
    96  // command line. If no flags were specified it will read the environment
    97  // variable GRAIL_TICKETS and if that's empty it will attempt to read the
    98  // file specified by the ticketrc flag (or it's default value).
    99  func (tf *TicketFlags) ReadEnvOrFile() error {
   100  	if tf.set {
   101  		return nil
   102  	}
   103  	if te := os.Getenv("GRAIL_TICKETS"); len(te) > 0 {
   104  		return tf.Set(te)
   105  	}
   106  	f, err := os.Open(tf.TicketRCFile)
   107  	if err != nil {
   108  		// It's ok for the rc file to not exist if it hasn't been set.
   109  		if tf.ticketRCFlag.set && !os.IsExist(err) {
   110  			return err
   111  		}
   112  		return nil
   113  	}
   114  	defer f.Close()
   115  	sc := bufio.NewScanner(f)
   116  	for sc.Scan() {
   117  		if l := strings.TrimSpace(sc.Text()); len(l) > 0 {
   118  			_ = tf.Set(l)
   119  		}
   120  	}
   121  	return sc.Err()
   122  }
   123  
   124  // RegisterTicketFlags registers the ticket related flags with the
   125  // supplied FlagSet. The flags are:
   126  // --<prefix>ticket
   127  // --<prefix>ticket-timeout
   128  // --<prefix>ticketrc
   129  func RegisterTicketFlags(fs *flag.FlagSet, prefix string, defaultTickets []string, flags *TicketFlags) {
   130  	flags.fs = fs
   131  	desc := "Comma separated list of GRAIL security tickets, and/or the flag may be repeated"
   132  	fs.Var(flags, prefix+"ticket", desc)
   133  	fs.DurationVar(&flags.Timeout, prefix+"ticket-timeout", time.Minute, "specifies the timeout duration for obtaining any single GRAIL security ticket")
   134  	flags.ticketRCFlag.name = prefix + "ticketrc"
   135  	flags.ticketRCFlag.val = &flags.TicketRCFile
   136  	flags.TicketRCFile = filepath.Join(os.Getenv("HOME"), ".ticketrc")
   137  	fs.Var(&flags.ticketRCFlag, flags.ticketRCFlag.name, "a file containing the tickets to use")
   138  	fs.Lookup(prefix + "ticketrc").DefValue = "$HOME/.ticketrc"
   139  	flags.ticketFlag = prefix + "ticket"
   140  	flags.rcFlag = prefix + "ticketrc"
   141  	flags.setDefaults(defaultTickets)
   142  }