/**
* @file src/bin2llvmir/optimizations/x87_fpu/x87_fpu.cpp
* @brief x87 FPU analysis - replace fpu stack operations with FPU registers.
* @copyright (c) 2017 Avast Software, licensed under the MIT license
*/

#include <llvm/IR/CFG.h>
#include <llvm/IR/InstIterator.h>
#include <llvm/IR/Operator.h>

#include "retdec/bin2llvmir/optimizations/x87_fpu/x87_fpu.h"
#include "retdec/utils/string.h"
#include "retdec/bin2llvmir/providers/asm_instruction.h"
#include "retdec/bin2llvmir/utils/debug.h"
#define debug_enabled false
#include "retdec/bin2llvmir/utils/ir_modifier.h"
#include "retdec/bin2llvmir/utils/llvm.h"

using namespace llvm;
using namespace retdec::bin2llvmir::llvm_utils;

namespace retdec {
namespace bin2llvmir {

/// Pass identifier handed to the ModulePass base class (see the constructor).
char X87FpuAnalysis::ID = 0;

/// Registers the pass with LLVM's legacy pass registry as "x87-fpu".
static RegisterPass<X87FpuAnalysis> X(
		"x87-fpu",
		"x87 fpu register analysis",
		false, // CFGOnly = false: the pass creates/erases instructions, not only CFG
		false // is_analysis = false: it is a transformation despite its name
);

/// Construct the pass; the ModulePass base is keyed by the static @c ID.
X87FpuAnalysis::X87FpuAnalysis() : ModulePass(ID)
{
}

/**
 * Legacy pass-manager entry point.
 * Looks up the configuration and ABI registered for @a m through their
 * providers and then runs the analysis.
 * @return Whatever run() reports.
 */
bool X87FpuAnalysis::runOnModule(llvm::Module& m)
{
	Module* mod = &m;
	_module = mod;
	_config = ConfigProvider::getConfig(mod);
	_abi = AbiProvider::getAbi(mod);
	return run();
}

/**
 * Entry point for callers that supply the configuration and ABI directly
 * instead of obtaining them from the providers.
 * @param m Module to process.
 * @param c Configuration (may be null, in which case run() does nothing).
 * @param a ABI (may be null, in which case run() does nothing).
 * @return Whatever run() reports.
 */
bool X87FpuAnalysis::runOnModuleCustom(
		llvm::Module& m,
		Config* c,
		Abi* a)
{
	_abi = a;
	_config = c;
	_module = &m;
	return run();
}

bool X87FpuAnalysis::run()
{
	if (_config == nullptr || _abi == nullptr)
	{
		return false;
	}
	if (!_abi->isX86())
	{
		return false;
	}

	top = _abi->getRegister(X87_REG_TOP);
	if (top == nullptr)
	{
		return false;
	}

	bool changed = false;
	for (Function& f : *_module)
	{
		LOG << f.getName().str() << std::endl;

		retdec::utils::NonIterableSet<BasicBlock*> seenBbs;
		std::map<Value*, int> topVals;

		for (auto& bb : f)
		{
			int topVal = 8;
			changed |= analyzeBb(seenBbs, topVals, &bb, topVal);
		}
	}

	removeAllFpuTopOperations();

	return changed;
}

/**
 * Propagate the x87 FPU stack top value through the CFG breadth-first,
 * starting at @a bb, and replace x87 load/store pseudo function calls with
 * direct loads/stores of the corresponding ST(i) registers.
 *
 * @param seenBbs Blocks already processed. It is shared across calls for one
 *        function, so every block is processed at most once.
 * @param topVals Maps values derived from the FPU top register (loads of it
 *        and add/sub of constants on such values) to the top index they
 *        represent.
 * @param bb      Block to start from.
 * @param topVal  FPU top value on entry to @a bb (8 is treated as an empty
 *        stack).
 * @return @c true if the IR was modified.
 */
bool X87FpuAnalysis::analyzeBb(
		retdec::utils::NonIterableSet<llvm::BasicBlock*>& seenBbs,
		std::map<llvm::Value*, int>& topVals,
		llvm::BasicBlock* bb,
		int topVal)
{
	// Work list of (block, FPU top value on entry to that block).
	std::queue<std::pair<llvm::BasicBlock*, int>> queue;
	queue.push({bb, topVal});
	bool changed = false;
	while (!queue.empty())
	{
		auto pair = queue.front();
		auto currentBb = pair.first;
		topVal = pair.second;
		queue.pop();
		LOG << "\t" << currentBb->getName().str() << std::endl;

		// A block can be queued several times (several predecessors, loop
		// back edges). Process it only once, but keep draining the queue:
		// bailing out here would abandon the remaining queued blocks and
		// report "unchanged" even if earlier blocks were already rewritten.
		if (seenBbs.has(currentBb))
		{
			LOG << "\t\t" << "already seen" << std::endl;
			continue;
		}
		seenBbs.insert(currentBb);

		auto it = currentBb->begin();
		while (it != currentBb->end())
		{
			// Advance first: the current instruction may be erased below.
			Instruction *i = &(*it);
			++it;

			auto *l = dyn_cast<LoadInst>(i);
			auto *s = dyn_cast<StoreInst>(i);
			auto *add = dyn_cast<AddOperator>(i);
			auto *sub = dyn_cast<SubOperator>(i);
			auto *callStore = _config->isLlvmX87StorePseudoFunctionCall(i);
			auto *callLoad = _config->isLlvmX87LoadPseudoFunctionCall(i);

			if (l && l->getPointerOperand() == top)
			{
				// Load of the top register yields the currently tracked value.
				topVals[i] = topVal;

				LOG << "\t\t" << AsmInstruction(i).getAddress()
					<< " @ " << std::dec << topVal << std::endl;
			}
			else if (s
					&& s->getPointerOperand() == top
					&& topVals.find(s->getValueOperand()) != topVals.end())
			{
				// Store of a known value to the top register: it becomes the
				// tracked value for the rest of this block and its successors.
				auto fIt = topVals.find(s->getValueOperand());
				topVal = fIt->second;

				LOG << "\t\t" << AsmInstruction(i).getAddress()
					<< " @ " << std::dec << fIt->second << std::endl;
			}
			else if (add
					&& topVals.find(add->getOperand(0)) != topVals.end()
					&& isa<ConstantInt>(add->getOperand(1)))
			{
				auto fIt = topVals.find(add->getOperand(0));
				auto *ci = cast<ConstantInt>(add->getOperand(1));
				// Constants are i3, so 7 can be represented as -1, we need to either
				// use zext here (potentially dangerous if instructions were already
				// modified and there are true negative values), or compute values
				// in i3 arithmetics.
				int tmp = fIt->second + ci->getZExtValue();
				if (tmp > 8)
				{
					LOG << "\t\t\t" << "overflow fix " << tmp << " -> " << 8
						<< std::endl;
					tmp = 8;
				}
				topVals[i] = tmp;

				LOG << "\t\t" << AsmInstruction(i).getAddress() << std::dec
					<< " @ " << fIt->second << " + " << ci->getZExtValue()
					<< " = " << tmp << std::endl;
			}
			else if (sub
					&& topVals.find(sub->getOperand(0)) != topVals.end()
					&& isa<ConstantInt>(sub->getOperand(1)))
			{
				auto fIt = topVals.find(sub->getOperand(0));
				auto *ci = cast<ConstantInt>(sub->getOperand(1));
				// Constants are i3, so 7 can be represented as -1, we need to either
				// use zext here (potentially dangerous if instructions were already
				// modified and there are true negative values), or compute values
				// in i3 arithmetics.
				int tmp = fIt->second - ci->getZExtValue();
				if (tmp < 0)
				{
					LOG << "\t\t\t" << "undeflow fix " << tmp << " -> " << 7
						<< std::endl;
					tmp = 7;
				}
				topVals[i] = tmp;

				LOG << "\t\t" << AsmInstruction(i).getAddress() << std::dec
					<< " @ " << fIt->second << " - " << ci->getZExtValue() << " = "
					<< tmp << std::endl;
			}
			else if (callStore
					&& topVals.find(callStore->getArgOperand(0)) != topVals.end())
			{
				auto fIt = topVals.find(callStore->getArgOperand(0));
				auto tmp = fIt->second;

				uint32_t regBase = X86_REG_ST0;
				// Storing value to an empty stack -> suspicious.
				if (tmp == 8)
				{
					tmp = 7;
					topVal = 7;
				}
				int regNum = tmp % 8;
				auto *reg = _abi->getRegister(regBase + regNum);

				LOG << "\t\t\t" << "store -- " << reg->getName().str() << std::endl;

				new StoreInst(callStore->getArgOperand(1), reg, callStore);
				_toRemove.insert(callStore->getArgOperand(0));
				// We need to remove this right away.
				// It does not work if we store it to _toRemove set.
				callStore->eraseFromParent();
				changed = true;
			}
			else if (callLoad
					&& topVals.find(callLoad->getArgOperand(0)) != topVals.end())
			{
				auto fIt = topVals.find(callLoad->getArgOperand(0));
				auto tmp = fIt->second;

				uint32_t regBase = X86_REG_ST0;
				// Loading value from an empty stack -> value may have been placed
				// there without us knowing, e.g. return value of some other
				// function.
				if (tmp == 8)
				{
					tmp = 7;
					topVal = 7;
				}
				int regNum = tmp % 8;
				auto *reg = _abi->getRegister(regBase + regNum);

				LOG << "\t\t\t" << "load -- " << reg->getName().str() << std::endl;

				auto *lTmp = new LoadInst(reg, "", callLoad);
				auto *conv = IrModifier::convertValueToType(lTmp, callLoad->getType(), callLoad);

				callLoad->replaceAllUsesWith(conv);
				// We need to remove this right away.
				// It does not work if we store it to _toRemove set.
				callLoad->eraseFromParent();
				changed = true;
			}
			else if (callStore || callLoad)
			{
				LOG << "\t\t" << AsmInstruction(i).getAddress() << " @ "
					<< llvmObjToString(i) << std::endl;
				assert(false && "some other pattern");
				// Earlier iterations may already have rewritten the IR, so
				// report that rather than unconditionally claiming "unchanged".
				return changed;
			}
		}

		// Successors inherit the top value at the end of this block.
		for (auto succIt = succ_begin(currentBb), e = succ_end(currentBb); succIt != e; ++succIt)
		{
			auto *succ = *succIt;
			queue.push({succ, topVal});
		}
	}
	return changed;
}

void X87FpuAnalysis::removeAllFpuTopOperations()
{
	// std::unordered_set<llvm::Value*> toRemove;
	for (Function& f : *_module)
	for (auto it = inst_begin(&f), eIt = inst_end(&f); it != eIt; ++it)
	{
		Instruction* i = &*it;
		if (auto* l = dyn_cast<LoadInst>(i); l && l->getPointerOperand() == top)
		{
			_toRemove.insert(i);
		}
		if (auto* s = dyn_cast<StoreInst>(i); s && s->getPointerOperand() == top)
		{
			_toRemove.insert(i);
		}
	}
	IrModifier::eraseUnusedInstructionsRecursive(_toRemove);
}

} // namespace bin2llvmir
} // namespace retdec
