github.com/grailbio/base@v0.0.11/cmd/grail-ticket/main.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  // The following enables go generate to generate the doc.go file.
     6  //go:generate go run v.io/x/lib/cmdline/gendoc "--build-cmd=go install" --copyright-notice= . -help
     7  
     8  package main
     9  
    10  import (
    11  	"fmt"
    12  	"io/ioutil"
    13  	"log"
    14  	"os"
    15  	"os/exec"
    16  	"syscall"
    17  	"time"
    18  
    19  	_ "github.com/grailbio/base/cmdutil/interactive"
    20  	"github.com/grailbio/base/security/ticket"
    21  	_ "github.com/grailbio/v23/factories/grail"
    22  	"v.io/v23/context"
    23  	"v.io/v23/vdl"
    24  	"v.io/x/lib/cmdline"
    25  	"v.io/x/ref/lib/v23cmd"
    26  	"v.io/x/ref/lib/vdl/codegen/json"
    27  )
    28  
    29  var (
    30  	timeoutFlag       time.Duration
    31  	authorityCertFlag string
    32  	certFlag          string
    33  	keyFlag           string
    34  	rationaleFlag     string
    35  	jsonOnlyFlag      bool
    36  	listFlag          bool
    37  )
    38  
    39  func newCmdRoot() *cmdline.Command {
    40  	root := &cmdline.Command{
    41  		Runner: v23cmd.RunnerFunc(run),
    42  		Name:   "grail-ticket",
    43  		Short:  "Retrieve a ticket from a ticket-server",
    44  		Long: `
    45  Command grail-ticket retrieves a ticket from a ticket-server. A ticket is
    46  identified using a Vanadium name.
    47  
    48  Examples:
    49  
    50    grail-ticket tickets/eng/dev/aws
    51    grail-ticket /127.0.0.1:8000/eng/dev/aws
    52  
    53  Note that tickets can be enumerated using the 'namespace' Vanadium tool:
    54  
    55    namespace glob tickets/...
    56    namespace glob tickets/eng/...
    57    namespace glob /127.0.0.1:8000/...
    58    namespace glob /127.0.0.1:8000/eng/...
    59  `,
    60  		ArgsName: "<ticket>",
    61  		LookPath: false,
    62  	}
    63  	root.Flags.DurationVar(&timeoutFlag, "timeout", 90*time.Second, "Timeout for the requests to the ticket-server")
    64  	root.Flags.BoolVar(&jsonOnlyFlag, "json-only", false, "Force a JSON output even for the tickets that have special handling")
    65  	root.Flags.BoolVar(&listFlag, "list", false, "List accessible tickets")
    66  	root.Flags.StringVar(&authorityCertFlag, "authority-cert", "", "PEM file to store the CA cert for a TLS-based ticket")
    67  	root.Flags.StringVar(&certFlag, "cert", "", "PEM file to store the cert for a TLS-based ticket")
    68  	root.Flags.StringVar(&keyFlag, "key", "", "PEM file to store the private key for a TLS-based ticket")
    69  	root.Flags.StringVar(&rationaleFlag, "rationale", "", "Rationale for accessing ticket")
    70  	return root
    71  }
    72  
    73  func saveCredentials(creds ticket.TlsCredentials) error {
    74  	if err := ioutil.WriteFile(authorityCertFlag, []byte(creds.AuthorityCert), 0644); err != nil {
    75  		return err
    76  	}
    77  	if err := ioutil.WriteFile(certFlag, []byte(creds.Cert), 0644); err != nil {
    78  		return err
    79  	}
    80  	return ioutil.WriteFile(keyFlag, []byte(creds.Key), 0600)
    81  }
    82  
    83  func run(ctx *context.T, env *cmdline.Env, args []string) error {
    84  	if len(args) == 0 {
    85  		return env.UsageErrorf("At least one arguments (<ticket>) is required.")
    86  	}
    87  
    88  	ticketPath := args[0]
    89  	if listFlag {
    90  		fmt.Println("Listing all accessible tickets (this may take up to 90 seconds)...")
    91  		client := ticket.ListServiceClient(ticketPath + "/list")
    92  		tickets, err := client.List(ctx)
    93  		if err != nil {
    94  			return err
    95  		}
    96  		for _, t := range tickets {
    97  			fmt.Println(t)
    98  		}
    99  		return nil
   100  	}
   101  
   102  	client := ticket.TicketServiceClient(ticketPath)
   103  	ctx, cancel := context.WithTimeout(ctx, timeoutFlag)
   104  	defer cancel()
   105  
   106  	var t ticket.Ticket
   107  	var err error
   108  	if rationaleFlag != "" {
   109  		t, err = client.GetWithArgs(ctx, map[string]string{
   110  			ticket.ControlRationale.String(): rationaleFlag,
   111  		})
   112  	} else {
   113  		t, err = client.Get(ctx)
   114  	}
   115  	if err != nil {
   116  		return err
   117  	}
   118  
   119  	if jsonOnlyFlag {
   120  		jsonOutput := json.Const(vdl.ValueOf(t.Interface()), "", nil)
   121  		fmt.Println(jsonOutput)
   122  		return nil
   123  	}
   124  
   125  	if t.Index() == (ticket.TicketGenericTicket{}).Index() {
   126  		fmt.Print(string((t.Interface().(ticket.GenericTicket)).Data))
   127  		return nil
   128  	}
   129  
   130  	if len(authorityCertFlag)+len(certFlag)+len(keyFlag) > 0 {
   131  		if len(authorityCertFlag)*len(certFlag)*len(keyFlag) == 0 {
   132  			return fmt.Errorf("-authority-cert=%q, -cert=%q, -key=%q flags need to be all empty or all non-empty", authorityCertFlag, certFlag, keyFlag)
   133  		}
   134  
   135  		switch t.Index() {
   136  		case (ticket.TicketDockerTicket{}).Index():
   137  			return saveCredentials(t.(ticket.TicketDockerTicket).Value.Credentials)
   138  		case (ticket.TicketDockerServerTicket{}).Index():
   139  			return saveCredentials(t.(ticket.TicketDockerServerTicket).Value.Credentials)
   140  		case (ticket.TicketDockerClientTicket{}).Index():
   141  			return saveCredentials(t.(ticket.TicketDockerClientTicket).Value.Credentials)
   142  		case (ticket.TicketTlsServerTicket{}).Index():
   143  			return saveCredentials(t.(ticket.TicketTlsServerTicket).Value.Credentials)
   144  		case (ticket.TicketTlsClientTicket{}).Index():
   145  			return saveCredentials(t.(ticket.TicketTlsClientTicket).Value.Credentials)
   146  		}
   147  	}
   148  
   149  	if t.Index() == (ticket.TicketAwsTicket{}).Index() && len(args) > 1 {
   150  		creds := t.(ticket.TicketAwsTicket).Value.AwsCredentials
   151  		awsEnv := map[string]string{
   152  			"AWS_ACCESS_KEY_ID":     creds.AccessKeyId,
   153  			"AWS_SECRET_ACCESS_KEY": creds.SecretAccessKey,
   154  			"AWS_SESSION_TOKEN":     creds.SessionToken,
   155  		}
   156  
   157  		args = args[1:]
   158  		path, err := exec.LookPath(args[0])
   159  		if err != nil {
   160  			log.Fatal(err)
   161  		}
   162  		for k := range awsEnv {
   163  			os.Unsetenv(k)
   164  		}
   165  		env := os.Environ()
   166  		for k, v := range awsEnv {
   167  			env = append(env, fmt.Sprintf("%s=%s", k, v))
   168  		}
   169  
   170  		// run runs a program with certain arguments and certain environment
   171  		// variables. This function never returns. The arguments list contains
   172  		// the name of the program.
   173  		return syscall.Exec(path, args, env)
   174  	}
   175  
   176  	jsonOutput := json.Const(vdl.ValueOf(t.Interface()), "", nil)
   177  	fmt.Println(jsonOutput)
   178  	return nil
   179  }
   180  
   181  func main() {
   182  	cmdline.HideGlobalFlagsExcept()
   183  	cmdline.Main(newCmdRoot())
   184  }