#include "IrValueLocationTracking.h"
#include "Luau/IrUtils.h"
namespace Luau
{
namespace CodeGen
{
/// Binds the tracker to `function`. Initially no VM register is linked to any
/// instruction: every slot of `vmRegValue` holds `kInvalidInstIdx`.
IrValueLocationTracking::IrValueLocationTracking(IrFunction& function)
    : function(function)
{
    for (auto& regValue : vmRegValue)
        regValue = kInvalidInstIdx;
}
/// Installs the callback that invalidateRestoreOp invokes for instructions flagged
/// `needsReload` before their tracked location is dropped. `context` is handed back
/// to the callback unchanged as its first argument.
void IrValueLocationTracking::setRestoreCallback(void* context, void (*callback)(void* context, IrInst& inst))
{
    restoreCallback = callback;
    restoreCallbackCtx = context;
}
/// Returns true for the commands this tracker treats as rematerializable from their
/// first argument: UINT_TO_NUM and INT_TO_NUM. For such commands, a location holding the
/// result is also recorded, tagged with the command, as a restore location for the argument.
bool IrValueLocationTracking::canBeRematerialized(IrCmd cmd)
{
    switch (cmd)
    {
    case IrCmd::UINT_TO_NUM:
    case IrCmd::INT_TO_NUM:
        return true;
    default:
        return false;
    }
}
/// Returns true when `inst` is a rematerializable conversion (see canBeRematerialized)
/// whose operand A is an instruction, and that argument's `lastUse` is some instruction
/// other than `inst` itself. In that case a location holding `inst`'s result is worth
/// recording even if `inst` itself is no longer needed, because it can serve the argument.
bool IrValueLocationTracking::canRematerializeArguments(IrInst& inst)
{
    if (!canBeRematerialized(inst.cmd) || OP_A(inst).kind != IrOpKind::Inst)
        return false;

    const IrInst& argument = function.instOp(OP_A(inst));
    return argument.lastUse != function.getInstIndex(inst);
}
/// Drops tracked restore locations for VM registers that `inst` is going to overwrite.
/// NOTE(review): by its name this runs just before `inst` is lowered, paired with
/// afterInstLowering which records new locations — confirm with the lowering driver.
///
/// Every command that may reference a VmReg operand must be listed explicitly below:
/// the default branch asserts that no operand of an unlisted command is a VmReg.
/// A count of -1 passed to invalidateRestoreVmRegs means "from the start register up
/// to the highest register tracked so far".
void IrValueLocationTracking::beforeInstLowering(IrInst& inst)
{
switch (inst.cmd)
{
// Tag-only write: skipValueInvalidation=true lets invalidateRestoreOp keep values of
// Double/Pointer/Int kind tracked, because their payload is not touched by this store.
case IrCmd::STORE_TAG:
invalidateRestoreOp(OP_A(inst), true);
break;
// Extra-field write: no tag-only shortcut, the tracked value is fully invalidated.
case IrCmd::STORE_EXTRA:
invalidateRestoreOp(OP_A(inst), false);
break;
// Full stores to the register in operand A.
case IrCmd::STORE_POINTER:
case IrCmd::STORE_DOUBLE:
case IrCmd::STORE_INT:
case IrCmd::STORE_VECTOR:
case IrCmd::STORE_TVALUE:
case IrCmd::STORE_SPLIT_TVALUE:
invalidateRestoreOp(OP_A(inst), false);
break;
// Everything from register A upward.
case IrCmd::ADJUST_STACK_TO_REG:
invalidateRestoreVmRegs(vmRegOp(OP_A(inst)), -1);
break;
// Results start at register B; the count comes from integer operand D.
case IrCmd::FASTCALL:
invalidateRestoreVmRegs(vmRegOp(OP_B(inst)), function.intOp(OP_D(inst)));
break;
// A result count of -1 is skipped here.
// NOTE(review): presumably variable result counts are covered by a following
// ADJUST_STACK_TO_REG — confirm.
case IrCmd::INVOKE_FASTCALL:
if (int count = function.intOp(OP_G(inst)); count != -1)
invalidateRestoreVmRegs(vmRegOp(OP_B(inst)), count);
break;
// Single result register in operand A.
case IrCmd::DO_ARITH:
case IrCmd::DO_LEN:
case IrCmd::GET_TABLE:
case IrCmd::GET_CACHED_IMPORT:
invalidateRestoreOp(OP_A(inst), false);
break;
// Registers starting at A; the count comes from unsigned operand B.
case IrCmd::CONCAT:
invalidateRestoreVmRegs(vmRegOp(OP_A(inst)), function.uintOp(OP_B(inst)));
break;
// No VM register is invalidated for GET_UPVALUE.
// NOTE(review): presumably its result is not written to a VM register in this IR
// version — confirm against the command's operand layout.
case IrCmd::GET_UPVALUE:
break;
// Everything from register A upward, regardless of the requested result count.
case IrCmd::CALL:
invalidateRestoreVmRegs(vmRegOp(OP_A(inst)), -1);
break;
// Everything from A+2 upward; registers A and A+1 stay tracked.
case IrCmd::FORGLOOP:
case IrCmd::FORGLOOP_FALLBACK:
invalidateRestoreVmRegs(vmRegOp(OP_A(inst)) + 2, -1);
break;
// Single result register in operand B.
case IrCmd::FALLBACK_GETGLOBAL:
case IrCmd::FALLBACK_GETTABLEKS:
invalidateRestoreOp(OP_B(inst), false);
break;
case IrCmd::FALLBACK_NAMECALL:
invalidateRestoreVmRegs(vmRegOp(OP_B(inst)), 2);
break;
// Count from integer operand C; -1 invalidates everything from B upward.
case IrCmd::FALLBACK_GETVARARGS:
invalidateRestoreVmRegs(vmRegOp(OP_B(inst)), function.intOp(OP_C(inst)));
break;
case IrCmd::FALLBACK_DUPCLOSURE:
invalidateRestoreOp(OP_B(inst), false);
break;
case IrCmd::FALLBACK_FORGPREP:
invalidateRestoreVmRegs(vmRegOp(OP_B(inst)), 3);
break;
// Listed explicitly so the default assertion does not fire; nothing is invalidated.
// NOTE(review): presumably these only read VM registers — re-check when adding commands.
case IrCmd::LOAD_TAG:
case IrCmd::LOAD_POINTER:
case IrCmd::LOAD_DOUBLE:
case IrCmd::LOAD_INT:
case IrCmd::LOAD_FLOAT:
case IrCmd::LOAD_TVALUE:
case IrCmd::CMP_ANY:
case IrCmd::CMP_TAG:
case IrCmd::JUMP_IF_TRUTHY:
case IrCmd::JUMP_IF_FALSY:
case IrCmd::JUMP_EQ_TAG:
case IrCmd::SET_TABLE:
case IrCmd::SET_UPVALUE:
case IrCmd::INTERRUPT:
case IrCmd::BARRIER_OBJ:
case IrCmd::BARRIER_TABLE_FORWARD:
case IrCmd::CLOSE_UPVALS:
case IrCmd::CAPTURE:
case IrCmd::SETLIST:
case IrCmd::RETURN:
case IrCmd::FORGPREP_XNEXT_FALLBACK:
case IrCmd::FALLBACK_SETGLOBAL:
case IrCmd::FALLBACK_SETTABLEKS:
case IrCmd::FALLBACK_PREPVARARGS:
case IrCmd::ADJUST_STACK_TO_TOP:
case IrCmd::GET_TYPEOF:
case IrCmd::NEWCLOSURE:
case IrCmd::FINDUPVAL:
break;
// Checks and numeric ops: nothing is invalidated.
// NOTE(review): presumably these can carry VmReg operands only as memory reads
// (e.g. after operand folding) — confirm.
case IrCmd::CHECK_TAG:
case IrCmd::CHECK_TRUTHY:
case IrCmd::ADD_NUM:
case IrCmd::SUB_NUM:
case IrCmd::MUL_NUM:
case IrCmd::DIV_NUM:
case IrCmd::IDIV_NUM:
case IrCmd::MOD_NUM:
case IrCmd::MIN_NUM:
case IrCmd::MAX_NUM:
case IrCmd::JUMP_CMP_NUM:
case IrCmd::FLOOR_NUM:
case IrCmd::CEIL_NUM:
case IrCmd::ROUND_NUM:
case IrCmd::SQRT_NUM:
case IrCmd::ABS_NUM:
break;
// Any command not listed above must not reference VM registers at all.
default:
for (auto& op : inst.ops)
CODEGEN_ASSERT(op.kind != IrOpKind::VmReg);
break;
}
}
/// Records restore locations established by `inst`, located at index `instIdx`.
///
/// - Loads (LOAD_TAG/POINTER/DOUBLE/INT/TVALUE): anything previously tracked for the VM
///   register in operand A is invalidated. The load's result is then linked to operand A.
/// - Stores (STORE_POINTER/DOUBLE/INT/TVALUE via operand B, STORE_SPLIT_TVALUE via
///   operand C): when the stored value is an instruction, the destination in operand A
///   becomes its restore location. This happens unless this store is that instruction's
///   last use and its argument cannot be rematerialized from it.
/// All other commands leave tracking unchanged.
void IrValueLocationTracking::afterInstLowering(IrInst& inst, uint32_t instIdx)
{
    // Shared handling for the operand carrying the stored value
    auto trackStoredValue = [&](IrOp storedValue) {
        if (storedValue.kind != IrOpKind::Inst)
            return;

        IrInst& source = function.instOp(storedValue);

        // Recording is pointless if nothing will ever reload the value (or its argument) from here
        if (source.lastUse != instIdx || canRematerializeArguments(source))
            recordRestoreOp(storedValue.index, OP_A(inst));
    };

    switch (inst.cmd)
    {
    case IrCmd::LOAD_TAG:
    case IrCmd::LOAD_POINTER:
    case IrCmd::LOAD_DOUBLE:
    case IrCmd::LOAD_INT:
    case IrCmd::LOAD_TVALUE:
        if (OP_A(inst).kind == IrOpKind::VmReg)
            invalidateRestoreOp(OP_A(inst), false);

        recordRestoreOp(instIdx, OP_A(inst));
        break;

    case IrCmd::STORE_POINTER:
    case IrCmd::STORE_DOUBLE:
    case IrCmd::STORE_INT:
    case IrCmd::STORE_TVALUE:
        trackStoredValue(OP_B(inst));
        break;

    case IrCmd::STORE_SPLIT_TVALUE:
        trackStoredValue(OP_C(inst));
        break;

    default:
        break;
    }
}
/// Records `location` as a place the value of instruction `instIdx` can be reloaded from.
///
/// VmConst locations are recorded unconditionally. For VmReg locations, this function:
/// - raises `maxReg` if needed;
/// - links the register to the instruction in `vmRegValue`;
/// - records the restore location, unless the register is set in `function.cfg.captured.regs`.
/// For UINT_TO_NUM/INT_TO_NUM results whose operand A is an instruction, that argument is
/// also given this location, tagged with the conversion command (again only when the
/// register is not captured). Any other operand kind is ignored.
void IrValueLocationTracking::recordRestoreOp(uint32_t instIdx, IrOp location)
{
    IrInst& inst = function.instructions[instIdx];

    if (location.kind == IrOpKind::VmConst)
    {
        function.recordRestoreLocation(instIdx, {location, getCmdValueKind(inst.cmd)});
        return;
    }

    if (location.kind != IrOpKind::VmReg)
        return;

    const int reg = vmRegOp(location);

    if (maxReg < reg)
        maxReg = reg;

    // Captured registers are still linked to the instruction but never become restore sources
    const bool recordable = !function.cfg.captured.regs.test(reg);

    if (recordable)
        function.recordRestoreLocation(instIdx, {location, getCmdValueKind(inst.cmd), IrCmd::NOP});

    vmRegValue[reg] = instIdx;

    const bool hasRematerializableArg = canBeRematerialized(inst.cmd) && OP_A(inst).kind == IrOpKind::Inst;

    if (hasRematerializableArg && recordable)
        function.recordRestoreLocation(OP_A(inst).index, {location, getCmdValueKind(inst.cmd), inst.cmd});
}
/// Breaks the link between a VM register `location` and the instruction value tracked in it.
///
/// If the register is linked to an instruction:
/// - With `skipValueInvalidation`, values of Double/Pointer/Int kind are left untouched
///   (only the tag of the register is changing).
/// - Otherwise, an instruction flagged `needsReload` is handed to the restore callback.
///   Its restore location is cleared only if it currently points at `location`, and the
///   register is unlinked.
/// - For a rematerializable conversion (UINT_TO_NUM/INT_TO_NUM) of another instruction,
///   the argument is restored and cleared only if the argument's *own* restore location
///   is `location`. An argument whose location has since moved elsewhere keeps it.
/// Invalidating a VmConst location is a programming error (constants are immutable).
void IrValueLocationTracking::invalidateRestoreOp(IrOp location, bool skipValueInvalidation)
{
    if (location.kind == IrOpKind::VmReg)
    {
        uint32_t& instIdx = vmRegValue[vmRegOp(location)];

        if (instIdx != kInvalidInstIdx)
        {
            IrInst& inst = function.instructions[instIdx];

            // A tag-only update leaves the payload of these value kinds intact
            if (skipValueInvalidation)
            {
                switch (getCmdValueKind(inst.cmd))
                {
                case IrValueKind::Double:
                case IrValueKind::Pointer:
                case IrValueKind::Int:
                    return;
                default:
                    break;
                }
            }

            if (inst.needsReload)
                restoreCallback(restoreCallbackCtx, inst);

            // Only forget the location if it is the one being overwritten
            ValueRestoreLocation currRestoreLocation = function.findRestoreLocation(instIdx, false);
            if (location == currRestoreLocation.op)
                function.recordRestoreLocation(instIdx, {});

            // Register loses its link with the instruction
            instIdx = kInvalidInstIdx;

            if (canBeRematerialized(inst.cmd) && OP_A(inst).kind == IrOpKind::Inst)
            {
                uint32_t depInstIdx = OP_A(inst).index;

                // The argument may have been given a different location since the conversion was
                // recorded (e.g. a direct store elsewhere). Compare against the argument's own
                // location, not the converted value's, so a still-valid location is not discarded.
                ValueRestoreLocation depRestoreLocation = function.findRestoreLocation(depInstIdx, false);

                if (location == depRestoreLocation.op)
                {
                    IrInst& depInst = function.instructions[depInstIdx];

                    if (depInst.needsReload)
                        restoreCallback(restoreCallbackCtx, depInst);

                    function.recordRestoreLocation(depInstIdx, {});
                }
            }
        }
    }
    else if (location.kind == IrOpKind::VmConst)
    {
        CODEGEN_ASSERT(!"VM constants are immutable");
    }
}
/// Invalidates tracked values for VM registers `start` through `start + count`, inclusive.
/// When `count` is -1, the range runs through register 255 instead. The upper bound is
/// clamped to `maxReg`, which recordRestoreOp raises, so untouched registers are skipped.
/// NOTE(review): because the bound is inclusive, one register beyond `count` is also
/// invalidated. This only drops extra restore locations; confirm with callers before tightening.
void IrValueLocationTracking::invalidateRestoreVmRegs(int start, int count)
{
    const int requestedEnd = (count == -1) ? 255 : start + count;
    const int end = requestedEnd > maxReg ? int(maxReg) : requestedEnd;

    for (int reg = start; reg <= end; ++reg)
        invalidateRestoreOp(IrOp{IrOpKind::VmReg, uint8_t(reg)}, false);
}
}
}