shader: Fix memory barriers

This commit is contained in:
ReinUsesLisp 2021-04-17 03:21:03 -03:00 committed by ameerj
parent c9e4609d87
commit 0a0818c025
8 changed files with 30 additions and 62 deletions

View File

@ -29,9 +29,8 @@ void EmitReturn(EmitContext& ctx);
void EmitUnreachable(EmitContext& ctx); void EmitUnreachable(EmitContext& ctx);
void EmitDemoteToHelperInvocation(EmitContext& ctx, Id continue_label); void EmitDemoteToHelperInvocation(EmitContext& ctx, Id continue_label);
void EmitBarrier(EmitContext& ctx); void EmitBarrier(EmitContext& ctx);
void EmitMemoryBarrierWorkgroupLevel(EmitContext& ctx); void EmitWorkgroupMemoryBarrier(EmitContext& ctx);
void EmitMemoryBarrierDeviceLevel(EmitContext& ctx); void EmitDeviceMemoryBarrier(EmitContext& ctx);
void EmitMemoryBarrierSystemLevel(EmitContext& ctx);
void EmitPrologue(EmitContext& ctx); void EmitPrologue(EmitContext& ctx);
void EmitEpilogue(EmitContext& ctx); void EmitEpilogue(EmitContext& ctx);
void EmitEmitVertex(EmitContext& ctx, const IR::Value& stream); void EmitEmitVertex(EmitContext& ctx, const IR::Value& stream);

View File

@ -7,7 +7,7 @@
namespace Shader::Backend::SPIRV { namespace Shader::Backend::SPIRV {
namespace { namespace {
void EmitMemoryBarrierImpl(EmitContext& ctx, spv::Scope scope) { void MemoryBarrier(EmitContext& ctx, spv::Scope scope) {
const auto semantics{ const auto semantics{
spv::MemorySemanticsMask::AcquireRelease | spv::MemorySemanticsMask::UniformMemory | spv::MemorySemanticsMask::AcquireRelease | spv::MemorySemanticsMask::UniformMemory |
spv::MemorySemanticsMask::WorkgroupMemory | spv::MemorySemanticsMask::AtomicCounterMemory | spv::MemorySemanticsMask::WorkgroupMemory | spv::MemorySemanticsMask::AtomicCounterMemory |
@ -27,16 +27,12 @@ void EmitBarrier(EmitContext& ctx) {
ctx.Constant(ctx.U32[1], static_cast<u32>(memory_semantics))); ctx.Constant(ctx.U32[1], static_cast<u32>(memory_semantics)));
} }
void EmitMemoryBarrierWorkgroupLevel(EmitContext& ctx) { void EmitWorkgroupMemoryBarrier(EmitContext& ctx) {
EmitMemoryBarrierImpl(ctx, spv::Scope::Workgroup); MemoryBarrier(ctx, spv::Scope::Workgroup);
} }
void EmitMemoryBarrierDeviceLevel(EmitContext& ctx) { void EmitDeviceMemoryBarrier(EmitContext& ctx) {
EmitMemoryBarrierImpl(ctx, spv::Scope::Device); MemoryBarrier(ctx, spv::Scope::Device);
}
void EmitMemoryBarrierSystemLevel(EmitContext& ctx) {
EmitMemoryBarrierImpl(ctx, spv::Scope::CrossDevice);
} }
} // namespace Shader::Backend::SPIRV } // namespace Shader::Backend::SPIRV

View File

@ -86,20 +86,12 @@ void IREmitter::Barrier() {
Inst(Opcode::Barrier); Inst(Opcode::Barrier);
} }
void IREmitter::MemoryBarrier(MemoryScope scope) { void IREmitter::WorkgroupMemoryBarrier() {
switch (scope) { Inst(Opcode::WorkgroupMemoryBarrier);
case MemoryScope::Workgroup:
Inst(Opcode::MemoryBarrierWorkgroupLevel);
break;
case MemoryScope::Device:
Inst(Opcode::MemoryBarrierDeviceLevel);
break;
case MemoryScope::System:
Inst(Opcode::MemoryBarrierSystemLevel);
break;
default:
throw InvalidArgument("Invalid memory scope {}", scope);
} }
void IREmitter::DeviceMemoryBarrier() {
Inst(Opcode::DeviceMemoryBarrier);
} }
void IREmitter::Return() { void IREmitter::Return() {

View File

@ -144,8 +144,9 @@ public:
[[nodiscard]] Value Select(const U1& condition, const Value& true_value, [[nodiscard]] Value Select(const U1& condition, const Value& true_value,
const Value& false_value); const Value& false_value);
[[nodiscard]] void Barrier(); void Barrier();
[[nodiscard]] void MemoryBarrier(MemoryScope scope); void WorkgroupMemoryBarrier();
void DeviceMemoryBarrier();
template <typename Dest, typename Source> template <typename Dest, typename Source>
[[nodiscard]] Dest BitCast(const Source& value); [[nodiscard]] Dest BitCast(const Source& value);

View File

@ -64,9 +64,8 @@ bool Inst::MayHaveSideEffects() const noexcept {
case Opcode::Unreachable: case Opcode::Unreachable:
case Opcode::DemoteToHelperInvocation: case Opcode::DemoteToHelperInvocation:
case Opcode::Barrier: case Opcode::Barrier:
case Opcode::MemoryBarrierWorkgroupLevel: case Opcode::WorkgroupMemoryBarrier:
case Opcode::MemoryBarrierDeviceLevel: case Opcode::DeviceMemoryBarrier:
case Opcode::MemoryBarrierSystemLevel:
case Opcode::Prologue: case Opcode::Prologue:
case Opcode::Epilogue: case Opcode::Epilogue:
case Opcode::EmitVertex: case Opcode::EmitVertex:

View File

@ -25,14 +25,6 @@ enum class FpRounding : u8 {
RZ, // Round towards zero RZ, // Round towards zero
}; };
enum class MemoryScope : u32 {
DontCare,
Warp,
Workgroup,
Device,
System,
};
struct FpControl { struct FpControl {
bool no_contraction{false}; bool no_contraction{false};
FpRounding rounding{FpRounding::DontCare}; FpRounding rounding{FpRounding::DontCare};

View File

@ -18,9 +18,8 @@ OPCODE(DemoteToHelperInvocation, Void, Labe
// Barriers // Barriers
OPCODE(Barrier, Void, ) OPCODE(Barrier, Void, )
OPCODE(MemoryBarrierWorkgroupLevel, Void, ) OPCODE(WorkgroupMemoryBarrier, Void, )
OPCODE(MemoryBarrierDeviceLevel, Void, ) OPCODE(DeviceMemoryBarrier, Void, )
OPCODE(MemoryBarrierSystemLevel, Void, )
// Special operations // Special operations
OPCODE(Prologue, Void, ) OPCODE(Prologue, Void, )

View File

@ -12,34 +12,24 @@ namespace Shader::Maxwell {
namespace { namespace {
// Seems to be in CUDA terminology. // Seems to be in CUDA terminology.
enum class LocalScope : u64 { enum class LocalScope : u64 {
CTG = 0, CTA,
GL = 1, GL,
SYS = 2, SYS,
VC = 3, VC,
}; };
IR::MemoryScope LocalScopeToMemoryScope(LocalScope scope) {
switch (scope) {
case LocalScope::CTG:
return IR::MemoryScope::Workgroup;
case LocalScope::GL:
return IR::MemoryScope::Device;
case LocalScope::SYS:
return IR::MemoryScope::System;
default:
throw NotImplementedException("Unimplemented Local Scope {}", scope);
}
}
} // Anonymous namespace } // Anonymous namespace
void TranslatorVisitor::MEMBAR(u64 inst) { void TranslatorVisitor::MEMBAR(u64 inst) {
union { union {
u64 raw; u64 raw;
BitField<8, 2, LocalScope> scope; BitField<8, 2, LocalScope> scope;
} membar{inst}; } const membar{inst};
ir.MemoryBarrier(LocalScopeToMemoryScope(membar.scope)); if (membar.scope == LocalScope::CTA) {
ir.WorkgroupMemoryBarrier();
} else {
ir.DeviceMemoryBarrier();
}
} }
void TranslatorVisitor::DEPBAR() { void TranslatorVisitor::DEPBAR() {