Switch to a safer technique for obtaining the working directory on Windows

Instead of looping 8 times to work around the TOCTOU issue with sizing the current directory buffer, instead keep doubling the buffer up until the 32767 character limit until the result fits. This ensures we always get a working directory if GetWorkingDirectoryW didn't return some other error, rather than returning nil in the case of a race condition.
This commit is contained in:
Jake Petroules 2025-04-30 23:54:28 -07:00
parent 7ae91602d7
commit f77a636a01
2 changed files with 40 additions and 15 deletions

View File

@ -489,22 +489,13 @@ extension _FileManagerImpl {
var currentDirectoryPath: String? { var currentDirectoryPath: String? {
#if os(Windows) #if os(Windows)
var dwLength: DWORD = GetCurrentDirectoryW(0, nil) // Make an initial call to GetCurrentDirectoryW to get a buffer size estimate.
guard dwLength > 0 else { return nil } // This is solely to minimize the number of allocations and number of bytes allocated versus starting with a hardcoded value like MAX_PATH.
// We should NOT early-return if this returns 0, in order to avoid TOCTOU issues.
for _ in 0 ... 8 { let dwSize = GetCurrentDirectoryW(0, nil)
if let szCurrentDirectory = withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(dwLength), { return try? FillNullTerminatedWideStringBuffer(initialSize: dwSize >= 0 ? dwSize : DWORD(MAX_PATH), maxSize: DWORD(Int16.max)) {
let dwResult: DWORD = GetCurrentDirectoryW(dwLength, $0.baseAddress) GetCurrentDirectoryW(DWORD($0.count), $0.baseAddress)
if dwResult == dwLength - 1 {
return String(decodingCString: $0.baseAddress!, as: UTF16.self)
}
dwLength = dwResult
return nil
}) {
return szCurrentDirectory
}
} }
return nil
#else #else
withUnsafeTemporaryAllocation(of: CChar.self, capacity: FileManager.MAX_PATH_SIZE) { buffer in withUnsafeTemporaryAllocation(of: CChar.self, capacity: FileManager.MAX_PATH_SIZE) { buffer in
guard getcwd(buffer.baseAddress!, FileManager.MAX_PATH_SIZE) != nil else { guard getcwd(buffer.baseAddress!, FileManager.MAX_PATH_SIZE) != nil else {

View File

@ -81,6 +81,10 @@ package var ERROR_FILENAME_EXCED_RANGE: DWORD {
DWORD(WinSDK.ERROR_FILENAME_EXCED_RANGE) DWORD(WinSDK.ERROR_FILENAME_EXCED_RANGE)
} }
package var ERROR_INSUFFICIENT_BUFFER: DWORD {
DWORD(WinSDK.ERROR_INSUFFICIENT_BUFFER)
}
package var ERROR_INVALID_ACCESS: DWORD { package var ERROR_INVALID_ACCESS: DWORD {
DWORD(WinSDK.ERROR_INVALID_ACCESS) DWORD(WinSDK.ERROR_INVALID_ACCESS)
} }
@ -288,4 +292,34 @@ internal func WIN32_FROM_HRESULT(_ hr: HRESULT) -> DWORD {
return DWORD(hr) return DWORD(hr)
} }
/// Calls a Win32 API function that fills a (potentially long path) null-terminated string buffer by continually attempting to allocate more memory up until the true max path is reached.
/// This is especially useful for protecting against race conditions like with GetCurrentDirectoryW where the measured length may no longer be valid on subsequent calls.
/// - parameter initialSize: Initial size of the buffer (including the null terminator) to allocate to hold the returned string.
/// - parameter maxSize: Maximum size of the buffer (including the null terminator) to allocate to hold the returned string.
/// - parameter body: Closure to call the Win32 API function to populate the provided buffer.
/// Should return the number of UTF-16 code units (not including the null terminator) copied, 0 to indicate an error.
/// If the buffer is not of sufficient size, should return a value greater than or equal to the size of the buffer.
internal func FillNullTerminatedWideStringBuffer(initialSize: DWORD, maxSize: DWORD, _ body: (UnsafeMutableBufferPointer<WCHAR>) throws -> DWORD) throws -> String {
var bufferCount = max(1, min(initialSize, maxSize))
while bufferCount <= maxSize {
if let result = try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(bufferCount), { buffer in
let count = try body(buffer)
switch count {
case 0:
throw Win32Error(GetLastError())
case 1..<DWORD(buffer.count):
let result = String(decodingCString: buffer.baseAddress!, as: UTF16.self)
assert(result.utf16.count == count, "Parsed UTF-16 count \(result.utf16.count) != reported UTF-16 count \(count)")
return result
default:
bufferCount *= 2
return nil
}
}) {
return result
}
}
throw Win32Error(ERROR_INSUFFICIENT_BUFFER)
}
#endif #endif