Merge pull request #4849 from ReinUsesLisp/fix-fiber-test
tests: Fix data race in fibers test
This commit is contained in:
commit
1fd22823bc
@ -6,18 +6,40 @@
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <stdexcept>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <catch2/catch.hpp>
|
||||
#include <math.h>
|
||||
|
||||
#include "common/common_types.h"
|
||||
#include "common/fiber.h"
|
||||
#include "common/spin_lock.h"
|
||||
|
||||
namespace Common {
|
||||
|
||||
class ThreadIds {
|
||||
public:
|
||||
void Register(u32 id) {
|
||||
const auto thread_id = std::this_thread::get_id();
|
||||
std::scoped_lock lock{mutex};
|
||||
if (ids.contains(thread_id)) {
|
||||
throw std::logic_error{"Registering the same thread twice"};
|
||||
}
|
||||
ids.emplace(thread_id, id);
|
||||
}
|
||||
|
||||
[[nodiscard]] u32 Get() const {
|
||||
std::scoped_lock lock{mutex};
|
||||
return ids.at(std::this_thread::get_id());
|
||||
}
|
||||
|
||||
private:
|
||||
mutable std::mutex mutex;
|
||||
std::unordered_map<std::thread::id, u32> ids;
|
||||
};
|
||||
|
||||
class TestControl1 {
|
||||
public:
|
||||
TestControl1() = default;
|
||||
@ -26,7 +48,7 @@ public:
|
||||
|
||||
void ExecuteThread(u32 id);
|
||||
|
||||
std::unordered_map<std::thread::id, u32> ids;
|
||||
ThreadIds thread_ids;
|
||||
std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
|
||||
std::vector<std::shared_ptr<Common::Fiber>> work_fibers;
|
||||
std::vector<u32> items;
|
||||
@ -39,8 +61,7 @@ static void WorkControl1(void* control) {
|
||||
}
|
||||
|
||||
void TestControl1::DoWork() {
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
u32 id = ids[this_id];
|
||||
const u32 id = thread_ids.Get();
|
||||
u32 value = items[id];
|
||||
for (u32 i = 0; i < id; i++) {
|
||||
value++;
|
||||
@ -50,8 +71,7 @@ void TestControl1::DoWork() {
|
||||
}
|
||||
|
||||
void TestControl1::ExecuteThread(u32 id) {
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
ids[this_id] = id;
|
||||
thread_ids.Register(id);
|
||||
auto thread_fiber = Fiber::ThreadToFiber();
|
||||
thread_fibers[id] = thread_fiber;
|
||||
work_fibers[id] = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl1}, this);
|
||||
@ -98,8 +118,7 @@ public:
|
||||
value1 += i;
|
||||
}
|
||||
Fiber::YieldTo(fiber1, fiber3);
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
u32 id = ids[this_id];
|
||||
const u32 id = thread_ids.Get();
|
||||
assert1 = id == 1;
|
||||
value2 += 5000;
|
||||
Fiber::YieldTo(fiber1, thread_fibers[id]);
|
||||
@ -115,8 +134,7 @@ public:
|
||||
}
|
||||
|
||||
void DoWork3() {
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
u32 id = ids[this_id];
|
||||
const u32 id = thread_ids.Get();
|
||||
assert2 = id == 0;
|
||||
value1 += 1000;
|
||||
Fiber::YieldTo(fiber3, thread_fibers[id]);
|
||||
@ -125,14 +143,12 @@ public:
|
||||
void ExecuteThread(u32 id);
|
||||
|
||||
void CallFiber1() {
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
u32 id = ids[this_id];
|
||||
const u32 id = thread_ids.Get();
|
||||
Fiber::YieldTo(thread_fibers[id], fiber1);
|
||||
}
|
||||
|
||||
void CallFiber2() {
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
u32 id = ids[this_id];
|
||||
const u32 id = thread_ids.Get();
|
||||
Fiber::YieldTo(thread_fibers[id], fiber2);
|
||||
}
|
||||
|
||||
@ -145,7 +161,7 @@ public:
|
||||
u32 value2{};
|
||||
std::atomic<bool> trap{true};
|
||||
std::atomic<bool> trap2{true};
|
||||
std::unordered_map<std::thread::id, u32> ids;
|
||||
ThreadIds thread_ids;
|
||||
std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
|
||||
std::shared_ptr<Common::Fiber> fiber1;
|
||||
std::shared_ptr<Common::Fiber> fiber2;
|
||||
@ -168,15 +184,13 @@ static void WorkControl2_3(void* control) {
|
||||
}
|
||||
|
||||
void TestControl2::ExecuteThread(u32 id) {
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
ids[this_id] = id;
|
||||
thread_ids.Register(id);
|
||||
auto thread_fiber = Fiber::ThreadToFiber();
|
||||
thread_fibers[id] = thread_fiber;
|
||||
}
|
||||
|
||||
void TestControl2::Exit() {
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
u32 id = ids[this_id];
|
||||
const u32 id = thread_ids.Get();
|
||||
thread_fibers[id]->Exit();
|
||||
}
|
||||
|
||||
@ -228,24 +242,21 @@ public:
|
||||
void DoWork1() {
|
||||
value1 += 1;
|
||||
Fiber::YieldTo(fiber1, fiber2);
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
u32 id = ids[this_id];
|
||||
const u32 id = thread_ids.Get();
|
||||
value3 += 1;
|
||||
Fiber::YieldTo(fiber1, thread_fibers[id]);
|
||||
}
|
||||
|
||||
void DoWork2() {
|
||||
value2 += 1;
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
u32 id = ids[this_id];
|
||||
const u32 id = thread_ids.Get();
|
||||
Fiber::YieldTo(fiber2, thread_fibers[id]);
|
||||
}
|
||||
|
||||
void ExecuteThread(u32 id);
|
||||
|
||||
void CallFiber1() {
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
u32 id = ids[this_id];
|
||||
const u32 id = thread_ids.Get();
|
||||
Fiber::YieldTo(thread_fibers[id], fiber1);
|
||||
}
|
||||
|
||||
@ -254,7 +265,7 @@ public:
|
||||
u32 value1{};
|
||||
u32 value2{};
|
||||
u32 value3{};
|
||||
std::unordered_map<std::thread::id, u32> ids;
|
||||
ThreadIds thread_ids;
|
||||
std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
|
||||
std::shared_ptr<Common::Fiber> fiber1;
|
||||
std::shared_ptr<Common::Fiber> fiber2;
|
||||
@ -271,15 +282,13 @@ static void WorkControl3_2(void* control) {
|
||||
}
|
||||
|
||||
void TestControl3::ExecuteThread(u32 id) {
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
ids[this_id] = id;
|
||||
thread_ids.Register(id);
|
||||
auto thread_fiber = Fiber::ThreadToFiber();
|
||||
thread_fibers[id] = thread_fiber;
|
||||
}
|
||||
|
||||
void TestControl3::Exit() {
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
u32 id = ids[this_id];
|
||||
const u32 id = thread_ids.Get();
|
||||
thread_fibers[id]->Exit();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user