github.com/zppinho/prow@v0.0.0-20240510014325-1738badeb017/cmd/webhook-server/clients.go (about)

     1  /*
     2  Copyright 2022 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package main
    18  
    19  import (
    20  	"context"
    21  	"encoding/json"
    22  	"errors"
    23  	"fmt"
    24  	"os"
    25  	"path/filepath"
    26  	"strings"
    27  
    28  	"sigs.k8s.io/prow/cmd/webhook-server/secretmanager"
    29  )
    30  
// GCPClient is a secret client that delegates its operations to a
// secretmanager.Client. The secretID it is constructed with is the identifier
// checkSecret searches for among the listed secrets when deciding whether the
// secret already exists.
type GCPClient struct {
	client   *secretmanager.Client
	secretID string
}
    35  
    36  func newGCPClient(client *secretmanager.Client, secretID string) *GCPClient {
    37  	return &GCPClient{
    38  		client:   client,
    39  		secretID: secretID,
    40  	}
    41  }
    42  
    43  func (g *GCPClient) CreateSecret(ctx context.Context, secretID string) error {
    44  	_, err := g.client.CreateSecret(ctx, secretID)
    45  	if err != nil {
    46  		return fmt.Errorf("could not create secret %v", err)
    47  	}
    48  	return nil
    49  }
    50  
    51  func (g *GCPClient) AddSecretVersion(ctx context.Context, secretName string, payload []byte) error {
    52  	if err := g.client.AddSecretVersion(ctx, secretName, payload); err != nil {
    53  		return fmt.Errorf("could not add secret data %v", err)
    54  	}
    55  	return nil
    56  }
    57  
    58  func (g *GCPClient) GetSecretValue(ctx context.Context, secretName string, versionName string) ([]byte, bool, error) {
    59  	err := g.checkSecret(ctx, secretName)
    60  	if err != nil && err == os.ErrNotExist {
    61  		return nil, false, nil
    62  	} else if err != nil {
    63  		return nil, false, err
    64  	}
    65  	payload, err := g.client.GetSecretValue(ctx, secretName, versionName)
    66  	if err != nil {
    67  		return nil, false, fmt.Errorf("error getting secret value %v", err)
    68  	}
    69  
    70  	return payload, true, nil
    71  }
    72  
    73  func (g *GCPClient) checkSecret(ctx context.Context, secretName string) error {
    74  	res, err := g.client.ListSecrets(ctx)
    75  	if err != nil {
    76  		return fmt.Errorf("could not make call to list secrets successfully %v", err)
    77  	}
    78  	for _, secret := range res {
    79  		if strings.Contains(secret.Name, g.secretID) {
    80  			return nil
    81  		}
    82  	}
    83  	return os.ErrNotExist
    84  }
    85  
// localFSClient is a secret client backed by files in a local directory.
// It exists for integration testing purposes only and is not to be used in
// prod.
//
// path is the directory the secret files are written to and read from;
// expiry and dns are the arguments handed to genSecretData when a new secret
// version is generated.
type localFSClient struct {
	path   string
	expiry int
	dns    []string
}
    92  
    93  func NewLocalFSClient(path string, expiry int, dns []string) *localFSClient {
    94  	return &localFSClient{
    95  		path:   path,
    96  		expiry: expiry,
    97  		dns:    dns,
    98  	}
    99  }
   100  
   101  func (l *localFSClient) CreateSecret(ctx context.Context, secretID string) error {
   102  	if _, err := os.Stat(l.path); errors.Is(err, os.ErrNotExist) {
   103  		err := os.Mkdir(l.path, 0755)
   104  		if err != nil {
   105  			return fmt.Errorf("unable to create secret dir %v", err)
   106  		}
   107  	}
   108  	return nil
   109  }
   110  
   111  func (l *localFSClient) AddSecretVersion(ctx context.Context, secretName string, payload []byte) error {
   112  	certFile := filepath.Join(l.path, certFile)
   113  	privKeyFile := filepath.Join(l.path, privKeyFile)
   114  	caBundleFile := filepath.Join(l.path, caBundleFile)
   115  
   116  	serverCertPerm, serverPrivKey, caPem, _, err := genSecretData(l.expiry, l.dns)
   117  	if err != nil {
   118  		return err
   119  	}
   120  	if err := os.WriteFile(certFile, []byte(serverCertPerm), 0666); err != nil {
   121  		return fmt.Errorf("could not write contents of cert file")
   122  	}
   123  	if err := os.WriteFile(privKeyFile, []byte(serverPrivKey), 0666); err != nil {
   124  		return fmt.Errorf("could not write contents of privkey file")
   125  	}
   126  	if err := os.WriteFile(caBundleFile, []byte(caPem), 0666); err != nil {
   127  		return fmt.Errorf("could not write contents of caBundle file")
   128  	}
   129  	return nil
   130  }
   131  
   132  func (l *localFSClient) GetSecretValue(ctx context.Context, secretName string, versionName string) ([]byte, bool, error) {
   133  	err := l.checkSecret(ctx, secretName)
   134  	if err != nil && err == os.ErrNotExist {
   135  		return nil, false, nil
   136  	} else if err != nil {
   137  		return nil, false, err
   138  	}
   139  	secretsMap := make(map[string]string)
   140  	files, err := os.ReadDir(l.path)
   141  	if err != nil {
   142  		return nil, false, fmt.Errorf("could not read file path")
   143  	}
   144  	for _, f := range files {
   145  		content, err := os.ReadFile(filepath.Join(l.path, f.Name()))
   146  		if err != nil {
   147  			return nil, false, fmt.Errorf("error reading file %v", err)
   148  		}
   149  		switch f.Name() {
   150  		case certFile:
   151  			secretsMap[certFile] = string(content)
   152  		case privKeyFile:
   153  			secretsMap[privKeyFile] = string(content)
   154  		case caBundleFile:
   155  			secretsMap[caBundleFile] = string(content)
   156  		}
   157  	}
   158  	res, err := json.Marshal(secretsMap)
   159  	if err != nil {
   160  		return nil, false, fmt.Errorf("could not marshal secrets data %v", err)
   161  	}
   162  	return res, true, nil
   163  }
   164  
   165  func (l *localFSClient) checkSecret(ctx context.Context, secretName string) error {
   166  	_, err := os.Stat(l.path)
   167  	if err != nil && os.IsNotExist(err) {
   168  		return os.ErrNotExist
   169  	} else if err != nil {
   170  		return err
   171  	}
   172  
   173  	files, err := os.ReadDir(l.path)
   174  	if err != nil {
   175  		return err
   176  	}
   177  	if len(files) < 2 {
   178  		return nil
   179  	}
   180  	for _, f := range files {
   181  		_, err := os.ReadFile(filepath.Join(l.path, f.Name()))
   182  		if err != nil {
   183  			return fmt.Errorf("error reading file %v", err)
   184  		}
   185  	}
   186  	return nil
   187  }