shader_cache: Fix use-after-free and orphan invalidation cache entries

This fixes some cases where entries could have been removed multiple
times reading freed memory. To address this issue this commit removes
duplicates from entries marked for removal and sorts out the removal
process to fix another use-after-free situation.

Another issue fixed in this commit is orphan invalidation cache entries.
Previously only the entries that were invalidated in the current
operations had its entries removed. This led to more use-after-free
situations when these entries were actually invalidated but referenced
an object that didn't exist.
This commit is contained in:
ReinUsesLisp 2020-06-28 04:58:58 -03:00
parent 0b954a3305
commit f6cb128eac
1 changed files with 43 additions and 31 deletions

View File

@ -20,6 +20,7 @@ namespace VideoCommon {
template <class T> template <class T>
class ShaderCache { class ShaderCache {
static constexpr u64 PAGE_BITS = 14; static constexpr u64 PAGE_BITS = 14;
static constexpr u64 PAGE_SIZE = u64(1) << PAGE_BITS;
struct Entry { struct Entry {
VAddr addr_start; VAddr addr_start;
@ -87,8 +88,8 @@ protected:
const VAddr addr_end = addr + size; const VAddr addr_end = addr + size;
Entry* const entry = NewEntry(addr, addr_end, data.get()); Entry* const entry = NewEntry(addr, addr_end, data.get());
const u64 page_end = addr_end >> PAGE_BITS; const u64 page_end = (addr_end + PAGE_SIZE - 1) >> PAGE_BITS;
for (u64 page = addr >> PAGE_BITS; page <= page_end; ++page) { for (u64 page = addr >> PAGE_BITS; page < page_end; ++page) {
invalidation_cache[page].push_back(entry); invalidation_cache[page].push_back(entry);
} }
@ -108,20 +109,13 @@ private:
/// @pre invalidation_mutex is locked /// @pre invalidation_mutex is locked
void InvalidatePagesInRegion(VAddr addr, std::size_t size) { void InvalidatePagesInRegion(VAddr addr, std::size_t size) {
const VAddr addr_end = addr + size; const VAddr addr_end = addr + size;
const u64 page_end = addr_end >> PAGE_BITS; const u64 page_end = (addr_end + PAGE_SIZE - 1) >> PAGE_BITS;
for (u64 page = addr >> PAGE_BITS; page <= page_end; ++page) { for (u64 page = addr >> PAGE_BITS; page < page_end; ++page) {
const auto it = invalidation_cache.find(page); auto it = invalidation_cache.find(page);
if (it == invalidation_cache.end()) { if (it == invalidation_cache.end()) {
continue; continue;
} }
InvalidatePageEntries(it->second, addr, addr_end);
std::vector<Entry*>& entries = it->second;
InvalidatePageEntries(entries, addr, addr_end);
// If there's nothing else in this page, remove it to avoid overpopulating the hash map.
if (entries.empty()) {
invalidation_cache.erase(it);
}
} }
} }
@ -131,15 +125,22 @@ private:
if (marked_for_removal.empty()) { if (marked_for_removal.empty()) {
return; return;
} }
std::scoped_lock lock{lookup_mutex}; // Remove duplicates
std::sort(marked_for_removal.begin(), marked_for_removal.end());
marked_for_removal.erase(std::unique(marked_for_removal.begin(), marked_for_removal.end()),
marked_for_removal.end());
std::vector<T*> removed_shaders; std::vector<T*> removed_shaders;
removed_shaders.reserve(marked_for_removal.size()); removed_shaders.reserve(marked_for_removal.size());
std::scoped_lock lock{lookup_mutex};
for (Entry* const entry : marked_for_removal) { for (Entry* const entry : marked_for_removal) {
if (lookup_cache.erase(entry->addr_start) > 0) { removed_shaders.push_back(entry->data);
removed_shaders.push_back(entry->data);
} const auto it = lookup_cache.find(entry->addr_start);
ASSERT(it != lookup_cache.end());
lookup_cache.erase(it);
} }
marked_for_removal.clear(); marked_for_removal.clear();
@ -154,17 +155,33 @@ private:
/// @param addr_end Non-inclusive end address of the invalidation /// @param addr_end Non-inclusive end address of the invalidation
/// @pre invalidation_mutex is locked /// @pre invalidation_mutex is locked
void InvalidatePageEntries(std::vector<Entry*>& entries, VAddr addr, VAddr addr_end) { void InvalidatePageEntries(std::vector<Entry*>& entries, VAddr addr, VAddr addr_end) {
auto it = entries.begin(); std::size_t index = 0;
while (it != entries.end()) { while (index < entries.size()) {
Entry* const entry = *it; Entry* const entry = entries[index];
if (!entry->Overlaps(addr, addr_end)) { if (!entry->Overlaps(addr, addr_end)) {
++it; ++index;
continue; continue;
} }
UnmarkMemory(entry);
marked_for_removal.push_back(entry);
it = entries.erase(it); UnmarkMemory(entry);
RemoveEntryFromInvalidationCache(entry);
marked_for_removal.push_back(entry);
}
}
/// @brief Removes all references to an entry in the invalidation cache
/// @param entry Entry to remove from the invalidation cache
/// @pre invalidation_mutex is locked
void RemoveEntryFromInvalidationCache(const Entry* entry) {
const u64 page_end = (entry->addr_end + PAGE_SIZE - 1) >> PAGE_BITS;
for (u64 page = entry->addr_start >> PAGE_BITS; page < page_end; ++page) {
const auto entries_it = invalidation_cache.find(page);
ASSERT(entries_it != invalidation_cache.end());
std::vector<Entry*>& entries = entries_it->second;
const auto entry_it = std::find(entries.begin(), entries.end(), entry);
ASSERT(entry_it != entries.end());
entries.erase(entry_it);
} }
} }
@ -182,16 +199,11 @@ private:
} }
/// @brief Removes a vector of shaders from a list /// @brief Removes a vector of shaders from a list
/// @param removed_shaders Shaders to be removed from the storage, it can contain duplicates /// @param removed_shaders Shaders to be removed from the storage
/// @pre invalidation_mutex is locked /// @pre invalidation_mutex is locked
/// @pre lookup_mutex is locked /// @pre lookup_mutex is locked
void RemoveShadersFromStorage(std::vector<T*> removed_shaders) { void RemoveShadersFromStorage(std::vector<T*> removed_shaders) {
// Remove duplicates // Notify removals
std::sort(removed_shaders.begin(), removed_shaders.end());
removed_shaders.erase(std::unique(removed_shaders.begin(), removed_shaders.end()),
removed_shaders.end());
// Now that there are no duplicates, we can notify removals
for (T* const shader : removed_shaders) { for (T* const shader : removed_shaders) {
OnShaderRemoval(shader); OnShaderRemoval(shader);
} }