github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/security/ticket/ticket.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package ticket
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"os"
    11  	"reflect"
    12  	"strings"
    13  
    14  	"github.com/aws/aws-sdk-go/aws/session"
    15  	"github.com/Schaudge/grailbase/common/log"
    16  	"github.com/Schaudge/grailbase/security/keycrypt"
    17  	"v.io/v23/context"
    18  	"v.io/v23/security"
    19  )
    20  
    21  // TicketContext wraps the information that needs to be carried around between
    22  // the various ticket functions.
    23  type TicketContext struct {
    24  	ctx             *context.T // Vanadium context; mergeOrDie passes it to log.Error
    25  	session         *session.Session // AWS SDK session
    26  	remoteBlessings security.Blessings // Vanadium blessings (presumably those of the remote caller; confirm against callers)
    27  }
    28  
    29  // NewTicketContext allows creating a TicketContext without unncessary exporting
    30  // its fields.
    31  func NewTicketContext(ctx *context.T, session *session.Session, remoteBlessings security.Blessings) *TicketContext {
    32  	return &TicketContext{
    33  		ctx:             ctx,
    34  		session:         session,
    35  		remoteBlessings: remoteBlessings,
    36  	}
    37  }
    38  
    39  // Builder is the interface for building a Ticket from a TicketContext and a list of Parameters.
    40  type Builder interface {
    41  	Build(ctx *TicketContext, parameters []Parameter) (Ticket, error) // returns the built Ticket, or an error
    42  }
    43  
    44  var ( // Compile-time checks that each listed ticket type implements Builder.
    45  	_ Builder = (*TicketAwsTicket)(nil)
    46  	_ Builder = (*TicketS3Ticket)(nil)
    47  	_ Builder = (*TicketSshCertificateTicket)(nil)
    48  	_ Builder = (*TicketEcrTicket)(nil)
    49  	_ Builder = (*TicketTlsServerTicket)(nil)
    50  	_ Builder = (*TicketTlsClientTicket)(nil)
    51  	_ Builder = (*TicketDockerTicket)(nil)
    52  	_ Builder = (*TicketDockerServerTicket)(nil)
    53  	_ Builder = (*TicketDockerClientTicket)(nil)
    54  	_ Builder = (*TicketB2Ticket)(nil)
    55  	_ Builder = (*TicketVanadiumTicket)(nil)
    56  	_ Builder = (*TicketGenericTicket)(nil)
    57  )
    58  
    59  // Build builds a Ticket by running all the builders.
    60  func (t TicketAwsTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) {
    61  	r := TicketAwsTicket{}
    62  	var err error
    63  	if t.Value.AwsAssumeRoleBuilder != nil {
    64  		r, err = t.Value.AwsAssumeRoleBuilder.newAwsTicket(ctx)
    65  		if err != nil {
    66  			return r, err
    67  		}
    68  		t.Value.AwsAssumeRoleBuilder = nil
    69  	} else if t.Value.AwsSessionBuilder != nil {
    70  		err = t.Value.AwsSessionBuilder.AwsCredentials.kmsInterpolate()
    71  		if err != nil {
    72  			return t, err
    73  		}
    74  
    75  		r, err = t.Value.AwsSessionBuilder.newAwsTicket(ctx)
    76  		if err != nil {
    77  			return r, err
    78  		}
    79  		t.Value.AwsSessionBuilder = nil
    80  	}
    81  	r = *mergeOrDie(ctx, &r, &t).(*TicketAwsTicket)
    82  	err = r.Value.AwsCredentials.kmsInterpolate()
    83  
    84  	return r, err
    85  }
    86  
    87  func (t *AwsCredentials) kmsInterpolate() (err error) {
    88  	t.SecretAccessKey, err = kmsInterpolationString(t.SecretAccessKey)
    89  	return err
    90  }
    91  
    92  // Build builds a Ticket by running all the builders.
    93  func (t TicketS3Ticket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) {
    94  	r := TicketS3Ticket{}
    95  	var err error
    96  	if t.Value.AwsAssumeRoleBuilder != nil {
    97  		r, err = t.Value.AwsAssumeRoleBuilder.newS3Ticket(ctx)
    98  		if err != nil {
    99  			return r, err
   100  		}
   101  		t.Value.AwsAssumeRoleBuilder = nil
   102  	} else if t.Value.AwsSessionBuilder != nil {
   103  		err = t.Value.AwsSessionBuilder.AwsCredentials.kmsInterpolate()
   104  		if err != nil {
   105  			return t, err
   106  		}
   107  
   108  		r, err = t.Value.AwsSessionBuilder.newS3Ticket(ctx)
   109  		if err != nil {
   110  			return r, err
   111  		}
   112  		t.Value.AwsSessionBuilder = nil
   113  	}
   114  	r = *mergeOrDie(ctx, &r, &t).(*TicketS3Ticket)
   115  	err = r.Value.AwsCredentials.kmsInterpolate()
   116  	return r, err
   117  }
   118  
   119  // Build builds a Ticket by running all the builders.
   120  func (t TicketSshCertificateTicket) Build(ctx *TicketContext, parameters []Parameter) (Ticket, error) {
   121  	rCompute := TicketSshCertificateTicket{}
   122  
   123  	// Populate the ComputeInstances first as input to the SSH CertBuilder
   124  	if t.Value.AwsComputeInstancesBuilder != nil {
   125  		var instanceBuilder = t.Value.AwsComputeInstancesBuilder
   126  		if instanceBuilder.AwsAccountLookupRole != "" {
   127  			instances, err := AwsEc2InstanceLookup(ctx, instanceBuilder)
   128  			if err != nil {
   129  				return nil, err
   130  			}
   131  			rCompute.Value.ComputeInstances = instances
   132  		} else {
   133  			return rCompute, fmt.Errorf("AwsAccountLookupRole required for AwsComputeInstancesBuilder.")
   134  		}
   135  	}
   136  
   137  	rSsh := TicketSshCertificateTicket{}
   138  	if t.Value.SshCertAuthorityBuilder != nil {
   139  
   140  		// Set the PublicKey parameter on the builder from the input parameters
   141  		// NOTE: If multiple publicKeys are provided as input, use the last one
   142  		for _, param := range parameters {
   143  			if param.Key == "PublicKey" {
   144  				t.Value.SshCertAuthorityBuilder.PublicKey = param.Value
   145  			}
   146  		}
   147  
   148  		var err error
   149  		rSsh, err = t.Value.SshCertAuthorityBuilder.newSshCertificateTicket(ctx)
   150  		if err != nil {
   151  			return rSsh, err
   152  		}
   153  		t.Value.SshCertAuthorityBuilder = nil
   154  	}
   155  
   156  	r := *mergeOrDie(ctx, &rCompute, &rSsh).(*TicketSshCertificateTicket)
   157  	return *mergeOrDie(ctx, &r, &t).(*TicketSshCertificateTicket), nil
   158  }
   159  
   160  // Build builds a Ticket by running all the builders.
   161  func (t TicketEcrTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) {
   162  	r := TicketEcrTicket{}
   163  	if t.Value.AwsAssumeRoleBuilder != nil {
   164  		var err error
   165  		r, err = t.Value.AwsAssumeRoleBuilder.newEcrTicket(ctx)
   166  		if err != nil {
   167  			return r, err
   168  		}
   169  		t.Value.AwsAssumeRoleBuilder = nil
   170  	}
   171  	return *mergeOrDie(ctx, &r, &t).(*TicketEcrTicket), nil
   172  }
   173  
   174  // Build builds a Ticket by running all the builders.
   175  func (t TicketTlsServerTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) {
   176  	r := TicketTlsServerTicket{}
   177  	if t.Value.TlsCertAuthorityBuilder != nil {
   178  		var err error
   179  		r, err = t.Value.TlsCertAuthorityBuilder.newTlsServerTicket(ctx)
   180  		if err != nil {
   181  			return r, err
   182  		}
   183  		t.Value.TlsCertAuthorityBuilder = nil
   184  	}
   185  	return *mergeOrDie(ctx, &r, &t).(*TicketTlsServerTicket), nil
   186  }
   187  
   188  // Build builds a Ticket by running all the builders.
   189  func (t TicketTlsClientTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) {
   190  	r := TicketTlsClientTicket{}
   191  	if t.Value.TlsCertAuthorityBuilder != nil {
   192  		var err error
   193  		r, err = t.Value.TlsCertAuthorityBuilder.newTlsClientTicket(ctx)
   194  		if err != nil {
   195  			return r, err
   196  		}
   197  		t.Value.TlsCertAuthorityBuilder = nil
   198  	}
   199  	return *mergeOrDie(ctx, &r, &t).(*TicketTlsClientTicket), nil
   200  }
   201  
   202  // Build builds a Ticket by running all the builders.
   203  func (t TicketDockerTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) {
   204  	r := TicketDockerTicket{}
   205  	if t.Value.TlsCertAuthorityBuilder != nil {
   206  		var err error
   207  		r, err = t.Value.TlsCertAuthorityBuilder.newDockerTicket(ctx)
   208  		if err != nil {
   209  			return r, err
   210  		}
   211  		t.Value.TlsCertAuthorityBuilder = nil
   212  	}
   213  	return *mergeOrDie(ctx, &r, &t).(*TicketDockerTicket), nil
   214  }
   215  
   216  // Build builds a Ticket by running all the builders.
   217  func (t TicketDockerServerTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) {
   218  	r := TicketDockerServerTicket{}
   219  	if t.Value.TlsCertAuthorityBuilder != nil {
   220  		var err error
   221  		r, err = t.Value.TlsCertAuthorityBuilder.newDockerServerTicket(ctx)
   222  		if err != nil {
   223  			return r, err
   224  		}
   225  		t.Value.TlsCertAuthorityBuilder = nil
   226  	}
   227  	return *mergeOrDie(ctx, &r, &t).(*TicketDockerServerTicket), nil
   228  }
   229  
   230  // Build builds a Ticket by running all the builders.
   231  func (t TicketDockerClientTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) {
   232  	r := TicketDockerClientTicket{}
   233  	if t.Value.TlsCertAuthorityBuilder != nil {
   234  		var err error
   235  		r, err = t.Value.TlsCertAuthorityBuilder.newDockerClientTicket(ctx)
   236  		if err != nil {
   237  			return r, err
   238  		}
   239  		t.Value.TlsCertAuthorityBuilder = nil
   240  	}
   241  	return *mergeOrDie(ctx, &r, &t).(*TicketDockerClientTicket), nil
   242  }
   243  
   244  // Build builds a Ticket by running all the builders.
   245  func (t TicketB2Ticket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) {
   246  	r := TicketB2Ticket{}
   247  	if t.Value.B2AccountAuthorizationBuilder != nil {
   248  		var err error
   249  		r, err = t.Value.B2AccountAuthorizationBuilder.newB2Ticket(ctx)
   250  		if err != nil {
   251  			return r, err
   252  		}
   253  		t.Value.B2AccountAuthorizationBuilder = nil
   254  	}
   255  	return *mergeOrDie(ctx, &r, &t).(*TicketB2Ticket), nil
   256  }
   257  
   258  // Build builds a Ticket by running all the builders.
   259  func (t TicketVanadiumTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) {
   260  	r := TicketVanadiumTicket{}
   261  	if t.Value.VanadiumBuilder != nil {
   262  		var err error
   263  		r, err = t.Value.VanadiumBuilder.newVanadiumTicket(ctx)
   264  		if err != nil {
   265  			return r, err
   266  		}
   267  		t.Value.VanadiumBuilder = nil
   268  	}
   269  	return *mergeOrDie(ctx, &r, &t).(*TicketVanadiumTicket), nil
   270  }
   271  
   272  // Build builds a Ticket.
   273  func (t TicketGenericTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) {
   274  	r := TicketGenericTicket{}
   275  	r = *mergeOrDie(ctx, &r, &t).(*TicketGenericTicket)
   276  	var err error
   277  	r.Value.Data, err = kmsInterpolationBytes(r.Value.Data)
   278  	return r, err
   279  }
   280  
   281  // merge i2 in i1 by overwriting in i1 all the non-zero fields in i2. The i1
   282  // and i2 needs to be references to the same type. Only simple types (bool,
   283  // numeric, string) and string are supported.
   284  func mergeOrDie(ctx *TicketContext, i1, i2 interface{}) interface{} {
   285  	if reflect.DeepEqual(i1, i2) {
   286  		return i1
   287  	}
   288  	v1, v2 := reflect.ValueOf(i1).Elem(), reflect.ValueOf(i2).Elem()
   289  	k1, k2 := v1.Kind(), v2.Kind()
   290  	if k1 != k2 {
   291  		log.Error(ctx.ctx, "different types in merge: %+v (%s) vs %v (%s)", v1, v1.Kind(), v2, v2.Kind())
   292  		os.Exit(255)
   293  	}
   294  	switch k1 {
   295  	case reflect.Struct:
   296  		for i := 0; i < v1.NumField(); i++ {
   297  			f1, f2 := v1.Field(i), v2.Field(i)
   298  			if !f1.CanSet() {
   299  				continue
   300  			}
   301  			v := mergeOrDie(ctx, f1.Addr().Interface(), f2.Addr().Interface())
   302  			f1.Set(reflect.Indirect(reflect.ValueOf(v)))
   303  		}
   304  	case reflect.Map:
   305  		// TODO(razvanm): figure out why the default doesn't work.
   306  		if v2.Len() > 0 {
   307  			v1.Set(v2)
   308  		}
   309  	default:
   310  		zero := reflect.Zero(v2.Type()).Interface()
   311  		if !reflect.DeepEqual(v2.Interface(), zero) {
   312  			v1.Set(v2)
   313  		}
   314  	}
   315  	return i1
   316  }
   317  
   318  // kmsInterpolation takes a string and, if the values is a 'kms://' URL it
   319  // returns the corresponding keycrypt values.
   320  func kmsInterpolationString(s string) (string, error) {
   321  	if !strings.HasPrefix(s, "kms://") {
   322  		return s, nil
   323  	}
   324  
   325  	secret, err := keycrypt.Lookup(s)
   326  	if err != nil {
   327  		return "", fmt.Errorf("keycrypt.Lookup(%q): %v", s, err)
   328  	}
   329  
   330  	secretBytes, err := secret.Get()
   331  	if err != nil {
   332  		return "", fmt.Errorf("Secret.Get(%q): %v", s, err)
   333  	}
   334  
   335  	return string(secretBytes), nil
   336  }
   337  
   338  func kmsInterpolationBytes(b []byte) ([]byte, error) {
   339  	if !bytes.HasPrefix(b, []byte("kms://")) {
   340  		return b, nil
   341  	}
   342  
   343  	secret, err := keycrypt.Lookup(string(b))
   344  	if err != nil {
   345  		return nil, fmt.Errorf("keycrypt.Lookup(%q): %v", b, err)
   346  	}
   347  
   348  	secretBytes, err := secret.Get()
   349  	if err != nil {
   350  		return nil, fmt.Errorf("Secret.Get(%q): %v", b, err)
   351  	}
   352  
   353  	return secretBytes, nil
   354  }