shader: Add XMAD multiplication folding optimization

This commit is contained in:
ReinUsesLisp 2021-02-16 19:50:23 -03:00 committed by ameerj
parent 4b438f94cf
commit 58914796c0

View File

@ -9,6 +9,7 @@
#include "common/bit_cast.h" #include "common/bit_cast.h"
#include "common/bit_util.h" #include "common/bit_util.h"
#include "shader_recompiler/exception.h" #include "shader_recompiler/exception.h"
#include "shader_recompiler/frontend/ir/ir_emitter.h"
#include "shader_recompiler/frontend/ir/microinstruction.h" #include "shader_recompiler/frontend/ir/microinstruction.h"
#include "shader_recompiler/ir_opt/passes.h" #include "shader_recompiler/ir_opt/passes.h"
@ -99,8 +100,71 @@ void FoldGetPred(IR::Inst& inst) {
} }
} }
/// Replaces the pattern generated by two XMAD multiplications
bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
/*
* We are looking for this pattern:
* %rhs_bfe = BitFieldUExtract %factor_a, #0, #16 (uses: 1)
* %rhs_mul = IMul32 %rhs_bfe, %factor_b (uses: 1)
* %lhs_bfe = BitFieldUExtract %factor_a, #16, #16 (uses: 1)
* %rhs_mul = IMul32 %lhs_bfe, %factor_b (uses: 1)
* %lhs_shl = ShiftLeftLogical32 %rhs_mul, #16 (uses: 1)
* %result = IAdd32 %lhs_shl, %rhs_mul (uses: 10)
*
* And replacing it with
* %result = IMul32 %factor_a, %factor_b
*
* This optimization has been proven safe by LLVM and MSVC.
*/
const IR::Value lhs_arg{inst.Arg(0)};
const IR::Value rhs_arg{inst.Arg(1)};
if (lhs_arg.IsImmediate() || rhs_arg.IsImmediate()) {
return false;
}
IR::Inst* const lhs_shl{lhs_arg.InstRecursive()};
if (lhs_shl->Opcode() != IR::Opcode::ShiftLeftLogical32 || lhs_shl->Arg(1) != IR::Value{16U}) {
return false;
}
if (lhs_shl->Arg(0).IsImmediate()) {
return false;
}
IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()};
IR::Inst* const rhs_mul{rhs_arg.InstRecursive()};
if (lhs_mul->Opcode() != IR::Opcode::IMul32 || rhs_mul->Opcode() != IR::Opcode::IMul32) {
return false;
}
if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) {
return false;
}
const IR::U32 factor_b{lhs_mul->Arg(1)};
if (lhs_mul->Arg(0).IsImmediate() || rhs_mul->Arg(0).IsImmediate()) {
return false;
}
IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()};
IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()};
if (lhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
return false;
}
if (rhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
return false;
}
if (lhs_bfe->Arg(1) != IR::Value{16U} || lhs_bfe->Arg(2) != IR::Value{16U}) {
return false;
}
if (rhs_bfe->Arg(1) != IR::Value{0U} || rhs_bfe->Arg(2) != IR::Value{16U}) {
return false;
}
if (lhs_bfe->Arg(0).Resolve() != rhs_bfe->Arg(0).Resolve()) {
return false;
}
const IR::U32 factor_a{lhs_bfe->Arg(0)};
IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
inst.ReplaceUsesWith(ir.IMul(factor_a, factor_b));
return true;
}
template <typename T> template <typename T>
void FoldAdd(IR::Inst& inst) { void FoldAdd(IR::Block& block, IR::Inst& inst) {
if (inst.HasAssociatedPseudoOperation()) { if (inst.HasAssociatedPseudoOperation()) {
return; return;
} }
@ -110,6 +174,12 @@ void FoldAdd(IR::Inst& inst) {
const IR::Value rhs{inst.Arg(1)}; const IR::Value rhs{inst.Arg(1)};
if (rhs.IsImmediate() && Arg<T>(rhs) == 0) { if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
inst.ReplaceUsesWith(inst.Arg(0)); inst.ReplaceUsesWith(inst.Arg(0));
return;
}
if constexpr (std::is_same_v<T, u32>) {
if (FoldXmadMultiply(block, inst)) {
return;
}
} }
} }
@ -244,14 +314,14 @@ void FoldBranchConditional(IR::Inst& inst) {
} }
} }
void ConstantPropagation(IR::Inst& inst) { void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
switch (inst.Opcode()) { switch (inst.Opcode()) {
case IR::Opcode::GetRegister: case IR::Opcode::GetRegister:
return FoldGetRegister(inst); return FoldGetRegister(inst);
case IR::Opcode::GetPred: case IR::Opcode::GetPred:
return FoldGetPred(inst); return FoldGetPred(inst);
case IR::Opcode::IAdd32: case IR::Opcode::IAdd32:
return FoldAdd<u32>(inst); return FoldAdd<u32>(block, inst);
case IR::Opcode::ISub32: case IR::Opcode::ISub32:
return FoldISub32(inst); return FoldISub32(inst);
case IR::Opcode::BitCastF32U32: case IR::Opcode::BitCastF32U32:
@ -259,7 +329,7 @@ void ConstantPropagation(IR::Inst& inst) {
case IR::Opcode::BitCastU32F32: case IR::Opcode::BitCastU32F32:
return FoldBitCast<u32, f32>(inst, IR::Opcode::BitCastF32U32); return FoldBitCast<u32, f32>(inst, IR::Opcode::BitCastF32U32);
case IR::Opcode::IAdd64: case IR::Opcode::IAdd64:
return FoldAdd<u64>(inst); return FoldAdd<u64>(block, inst);
case IR::Opcode::Select32: case IR::Opcode::Select32:
return FoldSelect<u32>(inst); return FoldSelect<u32>(inst);
case IR::Opcode::LogicalAnd: case IR::Opcode::LogicalAnd:
@ -292,7 +362,9 @@ void ConstantPropagation(IR::Inst& inst) {
} // Anonymous namespace } // Anonymous namespace
void ConstantPropagationPass(IR::Block& block) { void ConstantPropagationPass(IR::Block& block) {
std::ranges::for_each(block, ConstantPropagation); for (IR::Inst& inst : block) {
ConstantPropagation(block, inst);
}
} }
} // namespace Shader::Optimization } // namespace Shader::Optimization