github.com/sagernet/sing@v0.2.6/service/filemanager/default.go (about)

     1  package filemanager
     2  
     3  import (
     4  	"context"
     5  	"os"
     6  	"path/filepath"
     7  	"strings"
     8  	"syscall"
     9  
    10  	"github.com/sagernet/sing/common/rw"
    11  	"github.com/sagernet/sing/service"
    12  )
    13  
    14  var _ Manager = (*defaultManager)(nil)
    15  
    16  type defaultManager struct {
    17  	basePath string
    18  	tempPath string
    19  	chown    bool
    20  	userID   int
    21  	groupID  int
    22  }
    23  
    24  func WithDefault(ctx context.Context, basePath string, tempPath string, userID int, groupID int) context.Context {
    25  	chown := userID != os.Getuid() || groupID != os.Getgid()
    26  	if tempPath == "" {
    27  		tempPath = os.TempDir()
    28  	}
    29  	return service.ContextWith[Manager](ctx, &defaultManager{
    30  		basePath: basePath,
    31  		tempPath: tempPath,
    32  		chown:    chown,
    33  		userID:   userID,
    34  		groupID:  groupID,
    35  	})
    36  }
    37  
    38  func (m *defaultManager) BasePath(name string) string {
    39  	if m.basePath == "" || strings.HasPrefix(name, "/") {
    40  		return name
    41  	}
    42  	return filepath.Join(m.basePath, name)
    43  }
    44  
    45  func (m *defaultManager) OpenFile(name string, flag int, perm os.FileMode) (*os.File, error) {
    46  	name = m.BasePath(name)
    47  	willCreate := flag&os.O_CREATE != 0 && !rw.FileExists(name)
    48  	file, err := os.OpenFile(name, flag, perm)
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  	if m.chown && willCreate {
    53  		err = file.Chown(m.userID, m.groupID)
    54  		if err != nil {
    55  			file.Close()
    56  			os.Remove(file.Name())
    57  			return nil, err
    58  		}
    59  	}
    60  	return file, nil
    61  }
    62  
    63  func (m *defaultManager) Create(name string) (*os.File, error) {
    64  	name = m.BasePath(name)
    65  	file, err := os.Create(name)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	if m.chown {
    70  		err = file.Chown(m.userID, m.groupID)
    71  		if err != nil {
    72  			file.Close()
    73  			os.Remove(file.Name())
    74  			return nil, err
    75  		}
    76  	}
    77  	return file, nil
    78  }
    79  
    80  func (m *defaultManager) CreateTemp(pattern string) (*os.File, error) {
    81  	file, err := os.CreateTemp(m.tempPath, pattern)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	if m.chown {
    86  		err = file.Chown(m.userID, m.groupID)
    87  		if err != nil {
    88  			file.Close()
    89  			os.Remove(file.Name())
    90  			return nil, err
    91  		}
    92  	}
    93  	return file, nil
    94  }
    95  
    96  func (m *defaultManager) Mkdir(path string, perm os.FileMode) error {
    97  	path = m.BasePath(path)
    98  	err := os.Mkdir(path, perm)
    99  	if err != nil {
   100  		return err
   101  	}
   102  	if m.chown {
   103  		err = os.Chown(path, m.userID, m.groupID)
   104  		if err != nil {
   105  			os.Remove(path)
   106  			return err
   107  		}
   108  	}
   109  	return nil
   110  }
   111  
   112  func (m *defaultManager) MkdirAll(path string, perm os.FileMode) error {
   113  	path = m.BasePath(path)
   114  	dir, err := os.Stat(path)
   115  	if err == nil {
   116  		if dir.IsDir() {
   117  			return nil
   118  		}
   119  		return &os.PathError{Op: "mkdir", Path: path, Err: syscall.ENOTDIR}
   120  	}
   121  
   122  	i := len(path)
   123  	for i > 0 && os.IsPathSeparator(path[i-1]) {
   124  		i--
   125  	}
   126  
   127  	j := i
   128  	for j > 0 && !os.IsPathSeparator(path[j-1]) {
   129  		j--
   130  	}
   131  
   132  	if j > 1 {
   133  		err = m.MkdirAll(fixRootDirectory(path[:j-1]), perm)
   134  		if err != nil {
   135  			return err
   136  		}
   137  	}
   138  
   139  	err = os.Mkdir(path, perm)
   140  	if err != nil {
   141  		dir, err1 := os.Lstat(path)
   142  		if err1 == nil && dir.IsDir() {
   143  			return nil
   144  		}
   145  		return err
   146  	}
   147  	if m.chown {
   148  		err = os.Chown(path, m.userID, m.groupID)
   149  		if err != nil {
   150  			os.Remove(path)
   151  			return err
   152  		}
   153  	}
   154  	return nil
   155  }
   156  
   157  func fixRootDirectory(p string) string {
   158  	if len(p) == len(`\\?\c:`) {
   159  		if os.IsPathSeparator(p[0]) && os.IsPathSeparator(p[1]) && p[2] == '?' && os.IsPathSeparator(p[3]) && p[5] == ':' {
   160  			return p + `\`
   161  		}
   162  	}
   163  	return p
   164  }