github.com/freiheit-com/kuberpult@v1.24.2-0.20240328135542-315d5630abe6/services/cd-service/pkg/repository/testssh/server.go (about)

     1  /*This file is part of kuberpult.
     2  
     3  Kuberpult is free software: you can redistribute it and/or modify
     4  it under the terms of the Expat(MIT) License as published by
     5  the Free Software Foundation.
     6  
     7  Kuberpult is distributed in the hope that it will be useful,
     8  but WITHOUT ANY WARRANTY; without even the implied warranty of
     9  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    10  MIT License for more details.
    11  
    12  You should have received a copy of the MIT License
    13  along with kuberpult. If not, see <https://directory.fsf.org/wiki/License:Expat>.
    14  
    15  Copyright 2023 freiheit.com*/
    16  
    17  // nolint: errcheck
    18  package testssh
    19  
    20  import (
    21  	"crypto/ed25519"
    22  	"encoding/pem"
    23  	"fmt"
    24  	"io"
    25  	"net"
    26  	"os"
    27  	"os/exec"
    28  	"path/filepath"
    29  	"time"
    30  
    31  	"github.com/mattn/go-shellwords"
    32  	"github.com/mikesmitty/edkey"
    33  	"golang.org/x/crypto/ssh"
    34  	"golang.org/x/crypto/ssh/knownhosts"
    35  )
    36  
    37  type TestServer struct {
    38  	Port       int
    39  	KnownHosts string
    40  	ClientKey  string
    41  	Url        string
    42  	Pushes     uint
    43  	l          net.Listener
    44  	execDelay  time.Duration
    45  }
    46  
    47  type envReq struct {
    48  	Env   string
    49  	Value string
    50  }
    51  
    52  type execReq struct {
    53  	Command string
    54  }
    55  
    56  type exitReq struct {
    57  	Status uint32
    58  }
    59  
    60  func New(workdir string) *TestServer {
    61  	//exhaustruct:ignore
    62  	ts := TestServer{}
    63  	// Allocate a new listening port
    64  	//exhaustruct:ignore
    65  	ts.l, _ = net.ListenTCP("tcp", &net.TCPAddr{})
    66  	ts.Port = ts.l.Addr().(*net.TCPAddr).Port
    67  
    68  	// Setup a private key for the server and write a known hosts file
    69  	_, servPriv, _ := ed25519.GenerateKey(nil)
    70  	ps, _ := ssh.NewSignerFromSigner(servPriv)
    71  	kh := knownhosts.Line([]string{"127.0.0.1"}, ps.PublicKey())
    72  	ts.KnownHosts = filepath.Join(workdir, "known_hosts")
    73  	os.WriteFile(ts.KnownHosts, []byte(kh), 0644)
    74  	//exhaustruct:ignore
    75  	sc := &ssh.ServerConfig{
    76  		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
    77  			//exhaustruct:ignore
    78  			return &ssh.Permissions{}, nil
    79  		},
    80  	}
    81  	sc.AddHostKey(ps)
    82  
    83  	// Setup a private key for the client and write it to an openssh compatible file
    84  	_, clientPriv, _ := ed25519.GenerateKey(nil)
    85  	ts.ClientKey = filepath.Join(workdir, "id_ed25519")
    86  	key := pem.EncodeToMemory(&pem.Block{
    87  		Headers: nil,
    88  		Type:    "OPENSSH PRIVATE KEY",
    89  		Bytes:   edkey.MarshalED25519PrivateKey(clientPriv),
    90  	})
    91  	os.WriteFile(ts.ClientKey, key, 0600)
    92  
    93  	ts.Url = fmt.Sprintf("ssh://git@127.0.0.1:%d/.", ts.Port)
    94  
    95  	go func() {
    96  		for {
    97  			con, err := ts.l.Accept()
    98  			if err != nil {
    99  				fmt.Printf("testssh: err %q\n", err)
   100  				return
   101  			}
   102  			go ts.handleConn(con, workdir, sc)
   103  		}
   104  	}()
   105  	return &ts
   106  }
   107  
   108  func (ts *TestServer) handleConn(con net.Conn, workdir string, sc *ssh.ServerConfig) {
   109  	defer con.Close()
   110  	sCon, chans, reqs, err := ssh.NewServerConn(con, sc)
   111  	if err != nil {
   112  		fmt.Printf("testssh: err %q\n", err)
   113  		return
   114  	}
   115  	defer sCon.Close()
   116  	go ssh.DiscardRequests(reqs)
   117  	for newch := range chans {
   118  		if newch.ChannelType() != "session" {
   119  			newch.Reject(ssh.UnknownChannelType, "only channel type session is allowed")
   120  		}
   121  		ch, reqs, err := newch.Accept()
   122  		if err != nil {
   123  			fmt.Printf("testssh: accept err %q\n", err)
   124  			return
   125  		}
   126  		env := []string{}
   127  		for req := range reqs {
   128  			switch req.Type {
   129  			case "env":
   130  				var payload envReq
   131  				ssh.Unmarshal(req.Payload, &payload)
   132  				env = append(env, fmt.Sprintf("%s=%s", payload.Env, payload.Value))
   133  			case "exec":
   134  				var payload execReq
   135  				ssh.Unmarshal(req.Payload, &payload)
   136  				args, _ := shellwords.Parse(payload.Command)
   137  				if args[0] != "git-upload-pack" && args[0] != "git-receive-pack" {
   138  					fmt.Printf("testssh: illegal command: %q\n", args[0])
   139  					req.Reply(false, nil)
   140  					ch.Close()
   141  					return
   142  				}
   143  				if args[0] == "git-receive-pack" {
   144  					ts.Pushes = ts.Pushes + 1
   145  				}
   146  				args[1] = filepath.Join(workdir, args[1])
   147  				cmd := exec.Command(args[0], args[1:]...)
   148  				cmd.Env = env
   149  				stdin, _ := cmd.StdinPipe()
   150  				stdout, _ := cmd.StdoutPipe()
   151  				stderr, _ := cmd.StderrPipe()
   152  				go io.Copy(stdin, ch)
   153  				time.Sleep(ts.execDelay)
   154  				cmd.Start()
   155  				req.Reply(true, nil)
   156  				_, _ = io.Copy(ch, stdout)
   157  				_, _ = io.Copy(ch.Stderr(), stderr)
   158  				err = cmd.Wait()
   159  				if err != nil {
   160  					fmt.Printf("testssh: run err %q\n", err)
   161  				}
   162  				ch.SendRequest("exit-status", false, ssh.Marshal(&exitReq{Status: uint32(cmd.ProcessState.ExitCode())}))
   163  				ch.Close()
   164  			default:
   165  				fmt.Printf("testssh: illegal req: %q\n", req.Type)
   166  				ch.Close()
   167  			}
   168  
   169  		}
   170  	}
   171  }
   172  
   173  func (ts *TestServer) DelayExecs(dr time.Duration) {
   174  	ts.execDelay = dr
   175  }
   176  
   177  func (ts *TestServer) Close() {
   178  	ts.l.Close()
   179  }