github.com/apptainer/singularity@v3.1.1+incompatible/internal/pkg/test/privilege.go (about)

     1  // Copyright (c) 2018, Sylabs Inc. All rights reserved.
     2  // This software is licensed under a 3-clause BSD license. Please consult the
     3  // LICENSE.md file distributed with the sources of this project regarding your
     4  // rights to use or distribute this software.
     5  
     6  package test
     7  
     8  import (
     9  	"bufio"
    10  	"fmt"
    11  	"log"
    12  	"os"
    13  	"os/user"
    14  	"runtime"
    15  	"strconv"
    16  	"syscall"
    17  	"testing"
    18  )
    19  
    20  var origUID, origGID, unprivUID, unprivGID int
    21  var origHome, unprivHome string
    22  
    23  const (
    24  	// CacheDirPriv is the directory the cachedir gets set to when running privileged
    25  	CacheDirPriv = "/tmp/WithPrivilege"
    26  	// CacheDirUnpriv is the directory the cachedir gets set to when running unprivileged
    27  	CacheDirUnpriv = "/tmp/WithoutPrivilege"
    28  )
    29  
    30  // EnsurePrivilege ensures elevated privileges are available during a test.
    31  func EnsurePrivilege(t *testing.T) {
    32  	uid := os.Getuid()
    33  	if uid != 0 {
    34  		t.Fatal("test must be run with privilege")
    35  	}
    36  }
    37  
    38  // DropPrivilege drops privilege. Use this at the start of a test that does
    39  // not require elevated privileges. A matching call to ResetPrivilege must
    40  // occur before the test completes (a defer statement is recommended.)
    41  func DropPrivilege(t *testing.T) {
    42  
    43  	// setresuid/setresgid modifies the current thread only. To ensure our new
    44  	// uid/gid sticks, we need to lock ourselves to the current OS thread.
    45  	runtime.LockOSThread()
    46  
    47  	if os.Getgid() == 0 {
    48  		if err := syscall.Setresgid(unprivGID, unprivGID, origGID); err != nil {
    49  			t.Fatalf("failed to set group identity: %v", err)
    50  		}
    51  	}
    52  	if os.Getuid() == 0 {
    53  		if err := syscall.Setresuid(unprivUID, unprivUID, origUID); err != nil {
    54  			t.Fatalf("failed to set user identity: %v", err)
    55  		}
    56  
    57  		if err := os.Setenv("HOME", unprivHome); err != nil {
    58  			t.Fatalf("failed to set HOME environment variable: %v", err)
    59  		}
    60  	}
    61  
    62  	// set SINGULARITY_CACHEDIR
    63  	os.Setenv("SINGULARITY_CACHEDIR", CacheDirUnpriv)
    64  }
    65  
    66  // ResetPrivilege returns effective privilege to the original user.
    67  func ResetPrivilege(t *testing.T) {
    68  	if err := syscall.Setresuid(origUID, origUID, unprivUID); err != nil {
    69  		t.Fatalf("failed to reset user identity: %v", err)
    70  	}
    71  	if err := syscall.Setresgid(origGID, origGID, unprivGID); err != nil {
    72  		t.Fatalf("failed to reset group identity: %v", err)
    73  	}
    74  	if err := os.Setenv("HOME", origHome); err != nil {
    75  		t.Fatalf("failed to reset HOME environment variable: %v", err)
    76  	}
    77  
    78  	runtime.UnlockOSThread()
    79  
    80  	// set SINGULARITY_CACHEDIR
    81  	os.Setenv("SINGULARITY_CACHEDIR", CacheDirPriv)
    82  }
    83  
    84  // WithPrivilege wraps the supplied test function with calls to ensure
    85  // the test is run with elevated privileges.
    86  func WithPrivilege(f func(t *testing.T)) func(t *testing.T) {
    87  	return func(t *testing.T) {
    88  		t.Helper()
    89  
    90  		// set SINGULARITY_CACHEDIR
    91  		os.Setenv("SINGULARITY_CACHEDIR", CacheDirPriv)
    92  
    93  		EnsurePrivilege(t)
    94  
    95  		f(t)
    96  	}
    97  }
    98  
    99  // WithoutPrivilege wraps the supplied test function with calls to ensure
   100  // the test is run without elevated privileges.
   101  func WithoutPrivilege(f func(t *testing.T)) func(t *testing.T) {
   102  	return func(t *testing.T) {
   103  		t.Helper()
   104  
   105  		DropPrivilege(t)
   106  		defer ResetPrivilege(t)
   107  
   108  		// set SINGULARITY_CACHEDIR
   109  		os.Setenv("SINGULARITY_CACHEDIR", CacheDirUnpriv)
   110  
   111  		f(t)
   112  	}
   113  }
   114  
   115  // getProcInfo returns the parent PID, UID, and GID associated with the
   116  // supplied PID. Calls os.Exit on error.
   117  func getProcInfo(pid int) (ppid int, uid int, gid int) {
   118  	f, err := os.Open(fmt.Sprintf("/proc/%v/status", pid))
   119  	if err != nil {
   120  		log.Fatalf("failed to open /proc/%v/status", pid)
   121  	}
   122  
   123  	for s := bufio.NewScanner(f); s.Scan(); {
   124  		var temp int
   125  		if n, _ := fmt.Sscanf(s.Text(), "PPid:\t%d", &temp); n == 1 {
   126  			ppid = temp
   127  		}
   128  		if n, _ := fmt.Sscanf(s.Text(), "Uid:\t%d", &temp); n == 1 {
   129  			uid = temp
   130  		}
   131  		if n, _ := fmt.Sscanf(s.Text(), "Gid:\t%d", &temp); n == 1 {
   132  			gid = temp
   133  		}
   134  	}
   135  	return ppid, uid, gid
   136  }
   137  
   138  // getUnprivIDs searches recursively up the process parent chain to find a
   139  // process with a non-root UID, then returns the UID and GID of that process.
   140  // Calls os.Exit on error, or if no non-root process is found.
   141  func getUnprivIDs(pid int) (uid int, gid int) {
   142  	if 1 == pid {
   143  		log.Fatal("no unprivileged process found")
   144  	}
   145  
   146  	ppid, uid, gid := getProcInfo(pid)
   147  	if uid != 0 {
   148  		return uid, gid
   149  	}
   150  	return getUnprivIDs(ppid)
   151  }
   152  
   153  func init() {
   154  	origUID = os.Getuid()
   155  	origGID = os.Getgid()
   156  	origUser, err := user.LookupId(strconv.Itoa(origUID))
   157  
   158  	if err != nil {
   159  		log.Fatalf("err: %s", err)
   160  	}
   161  
   162  	origHome = origUser.HomeDir
   163  
   164  	unprivUID, unprivGID = getUnprivIDs(os.Getpid())
   165  	unprivUser, err := user.LookupId(strconv.Itoa(unprivUID))
   166  
   167  	if err != nil {
   168  		log.Fatalf("err: %s", err)
   169  	}
   170  
   171  	unprivHome = unprivUser.HomeDir
   172  }