github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/service/filemanager/default.go (about) 1 package filemanager 2 3 import ( 4 "context" 5 "os" 6 "path/filepath" 7 "strings" 8 "syscall" 9 10 "github.com/sagernet/sing/common/rw" 11 "github.com/sagernet/sing/service" 12 ) 13 14 var _ Manager = (*defaultManager)(nil) 15 16 type defaultManager struct { 17 basePath string 18 tempPath string 19 chown bool 20 userID int 21 groupID int 22 } 23 24 func WithDefault(ctx context.Context, basePath string, tempPath string, userID int, groupID int) context.Context { 25 chown := userID != os.Getuid() || groupID != os.Getgid() 26 if tempPath == "" { 27 tempPath = os.TempDir() 28 } 29 return service.ContextWith[Manager](ctx, &defaultManager{ 30 basePath: basePath, 31 tempPath: tempPath, 32 chown: chown, 33 userID: userID, 34 groupID: groupID, 35 }) 36 } 37 38 func (m *defaultManager) BasePath(name string) string { 39 if m.basePath == "" || strings.HasPrefix(name, "/") { 40 return name 41 } 42 return filepath.Join(m.basePath, name) 43 } 44 45 func (m *defaultManager) OpenFile(name string, flag int, perm os.FileMode) (*os.File, error) { 46 name = m.BasePath(name) 47 willCreate := flag&os.O_CREATE != 0 && !rw.FileExists(name) 48 file, err := os.OpenFile(name, flag, perm) 49 if err != nil { 50 return nil, err 51 } 52 if m.chown && willCreate { 53 err = file.Chown(m.userID, m.groupID) 54 if err != nil { 55 file.Close() 56 os.Remove(file.Name()) 57 return nil, err 58 } 59 } 60 return file, nil 61 } 62 63 func (m *defaultManager) Create(name string) (*os.File, error) { 64 name = m.BasePath(name) 65 file, err := os.Create(name) 66 if err != nil { 67 return nil, err 68 } 69 if m.chown { 70 err = file.Chown(m.userID, m.groupID) 71 if err != nil { 72 file.Close() 73 os.Remove(file.Name()) 74 return nil, err 75 } 76 } 77 return file, nil 78 } 79 80 func (m *defaultManager) CreateTemp(pattern string) (*os.File, error) { 81 file, err := os.CreateTemp(m.tempPath, pattern) 82 if err != nil { 83 return nil, err 84 } 85 if m.chown { 86 err = file.Chown(m.userID, m.groupID) 87 if err != nil { 88 file.Close() 89 os.Remove(file.Name()) 90 return nil, err 91 } 92 } 93 return file, nil 94 } 95 96 func (m *defaultManager) Chown(path string) error { 97 if m.chown { 98 return os.Chown(path, m.userID, m.groupID) 99 } 100 return nil 101 } 102 103 func (m *defaultManager) Mkdir(path string, perm os.FileMode) error { 104 path = m.BasePath(path) 105 err := os.Mkdir(path, perm) 106 if err != nil { 107 return err 108 } 109 if m.chown { 110 err = os.Chown(path, m.userID, m.groupID) 111 if err != nil { 112 os.Remove(path) 113 return err 114 } 115 } 116 return nil 117 } 118 119 func (m *defaultManager) MkdirAll(path string, perm os.FileMode) error { 120 path = m.BasePath(path) 121 return m.mkdirAll(path, perm) 122 } 123 124 func (m *defaultManager) mkdirAll(path string, perm os.FileMode) error { 125 dir, err := os.Stat(path) 126 if err == nil { 127 if dir.IsDir() { 128 return nil 129 } 130 return &os.PathError{Op: "mkdir", Path: path, Err: syscall.ENOTDIR} 131 } 132 133 i := len(path) 134 for i > 0 && os.IsPathSeparator(path[i-1]) { 135 i-- 136 } 137 138 j := i 139 for j > 0 && !os.IsPathSeparator(path[j-1]) { 140 j-- 141 } 142 143 if j > 1 { 144 err = m.MkdirAll(fixRootDirectory(path[:j-1]), perm) 145 if err != nil { 146 return err 147 } 148 } 149 150 err = os.Mkdir(path, perm) 151 if err != nil { 152 dir, err1 := os.Lstat(path) 153 if err1 == nil && dir.IsDir() { 154 return nil 155 } 156 return err 157 } 158 if m.chown { 159 err = os.Chown(path, m.userID, m.groupID) 160 if err != nil { 161 os.Remove(path) 162 return err 163 } 164 } 165 return nil 166 } 167 168 func (m *defaultManager) Remove(path string) error { 169 path = m.BasePath(path) 170 return os.Remove(path) 171 } 172 173 func (m *defaultManager) RemoveAll(path string) error { 174 path = m.BasePath(path) 175 return os.RemoveAll(path) 176 } 177 178 func fixRootDirectory(p string) string { 179 if len(p) == len(`\\?\c:`) { 180 if os.IsPathSeparator(p[0]) && os.IsPathSeparator(p[1]) && p[2] == '?' && os.IsPathSeparator(p[3]) && p[5] == ':' { 181 return p + `\` 182 } 183 } 184 return p 185 }