/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * File:   targetinterpretation.cpp
 * Author: pgess <v.melnychenko@xreate.org>
 *
 * Created on June 29, 2016, 6:45 PM
 */

/**
 * \file    targetinterpretation.h
 * \brief   Interpretation support. See more details on [Interpretation](/d/concepts/interpretation/)
 */

#include "compilation/targetinterpretation.h"
#include "pass/interpretationpass.h"
#include "analysis/typeinference.h"
#include "llvmlayer.h"
#include "compilation/decorators.h"
#include "compilation/i12ninst.h"
#include "compilation/intrinsics.h"

#include <boost/scoped_ptr.hpp>

#include <algorithm>
#include <csignal>
#include <iostream>
#include <iterator>

using namespace std;
using namespace xreate::compilation;

namespace xreate{
namespace interpretation{

// Canonical results of interpreted boolean operators (EQU, NE): booleans are
// encoded as numbers, 1 for true and 0 for false.
const Expression EXPRESSION_FALSE(Atom<Number_t>(0));
const Expression EXPRESSION_TRUE(Atom<Number_t>(1));

/**
 * Picks the branch of an `if` whose condition is known at interpretation time.
 *
 * \return the first block (then-branch) when the condition interprets to
 *         EXPRESSION_TRUE, otherwise the last block (else-branch).
 */
CodeScope*
InterpretationScope::processOperatorIf(const Expression& expression) {
    const bool flagConditionHolds = (process(expression.getOperands()[0]) == EXPRESSION_TRUE);

    return flagConditionHolds ? expression.blocks.front() : expression.blocks.back();
}

/**
 * Selects the branch of a `switch` whose condition is known at interpretation time.
 *
 * Layout of `expression.operands`: [0] is the switch condition, [1] is optionally a
 * CASE_DEFAULT branch, and the cases follow. For each case, its first block computes
 * the case key and its last block is the case body.
 *
 * \return body of the first case whose key equals the interpreted condition, otherwise
 *         the default branch's block; asserts if neither exists (`nullptr` with NDEBUG).
 */
CodeScope*
InterpretationScope::processOperatorSwitch(const Expression& expression) {
    const Expression& exprCondition = process(expression.operands.at(0));

    // Guard the lookup so that a switch without any branches does not index past the end.
    const bool flagHasDefault = expression.operands.size() > 1
            && expression.operands[1].op == Operator::CASE_DEFAULT;

    //TODO check that one and only one case variant is appropriate
    for (size_t i = flagHasDefault ? 2 : 1, size = expression.operands.size(); i < size; ++i) {
        const Expression& exprCase = process(expression.operands[i]);
        const CodeScope* scopeCaseKey = exprCase.blocks.front();

        if (function->getScope(scopeCaseKey)->processScope() == exprCondition) {
            return exprCase.blocks.back();
        }
    }

    if (flagHasDefault) {
        const Expression& exprCaseDefault = expression.operands[1];
        return exprCaseDefault.blocks.front();
    }

    assert(false && "Switch has no appropriate variant");
    return nullptr;
}

/**
 * Selects the case of a variant `switch` whose condition is known at interpretation time.
 *
 * `expression.operands[0]` is the condition, which must interpret to a VARIANT expression.
 * `operands[k + 1]` holds the variant tag handled by case `k`, whose body is `blocks[k]`.
 * If the variant carries a payload, the payload is bound in the case body under the
 * switch's alias (`bindings.front()`) as a LIST. Its field names are taken from the
 * matching alternative of the condition's type.
 *
 * \return body of the matching case (asserts that one exists).
 */
CodeScope*
InterpretationScope::processOperatorSwitchVariant(const Expression& expression) {
  const ExpandedType& conditionT = function->__pass->man->root->getType(expression.operands.at(0));
  const Expression& conditionE = process(expression.operands.at(0));
  assert(conditionE.op == Operator::VARIANT);
  const string& aliasS = expression.bindings.front();

  // Tag of the interpreted variant; the double -> int -> unsigned chain is intentional
  // and matches how the tag has always been decoded here.
  const unsigned caseExpectedId = static_cast<unsigned>(static_cast<int>(conditionE.getValueDouble()));

  // Case tags start right after the condition operand. std::next avoids incrementing a
  // temporary iterator, which is ill-formed when the iterator is a raw pointer.
  const auto itCasesBegin = std::next(expression.operands.begin());
  const auto itFoundValue = std::find_if(itCasesBegin, expression.operands.end(),
      [caseExpectedId](const Expression& caseActualE) {
        return static_cast<unsigned>(caseActualE.getValueDouble()) == caseExpectedId;
      });
  assert(itFoundValue != expression.operands.end());

  // The k-th case tag corresponds to the k-th block.
  const auto caseScopeId = std::distance(itCasesBegin, itFoundValue);
  auto caseScopeRef = expression.blocks.begin();
  std::advance(caseScopeRef, caseScopeId);
  InterpretationScope* scopeI12n = function->getScope(*caseScopeRef);

  if (!conditionE.operands.empty()) {
    // Expose the variant payload as a record named by the selected alternative's fields.
    Expression valueE(Operator::LIST, {});
    valueE.operands = conditionE.operands;
    valueE.bindings = conditionT->__operands.at(caseExpectedId).fields;

    scopeI12n->overrideBindings({
        {valueE, aliasS}
    });
  }

  return *caseScopeRef;
}

/**
 * Compiles an expression that carries a *late* interpretation operator, i.e. one that
 * combines compile-time interpretation with LLVM code generation.
 *
 * \param op          late operator attached to `expression`
 * \param expression  expression to compile
 * \param context     brute compilation context (current function and scope)
 * \param hintAlias   name hint for the produced value (not used by these operators)
 * \return generated LLVM value; `nullptr` for the not yet implemented
 *         SWITCH_LATE / QUERY_LATE operators
 */
llvm::Value*
InterpretationScope::processLate(const InterpretationOperator& op, const Expression& expression, const Context& context, const std::string& hintAlias) {
    switch(op) {
    case IF_INTERPRET_CONDITION:
    {
        // The condition is interpreted; only the selected branch gets compiled.
        CodeScope* scopeResult = processOperatorIf(expression);
        return context.function->getBruteScope(scopeResult)->compile();
    }

    case SWITCH_INTERPRET_CONDITION:
    {
        // The matching case is chosen during interpretation; only its body gets compiled.
        CodeScope* scopeResult = processOperatorSwitch(expression);
        return context.function->getBruteScope(scopeResult)->compile();
    }

    case SWITCH_VARIANT:
    {
        CodeScope* scopeResult = processOperatorSwitchVariant(expression);
        const Expression& condCrudeE = expression.operands.at(0);
        const Expression& condE = process(condCrudeE);

        const string& identCondition = expression.bindings.front();
        auto scopeCompilation = Decorators<CachedScopeDecoratorTag>::getInterface(
            context.function->getBruteScope(scopeResult));

        if (!condE.operands.empty()) {
            // Override the alias in the brute scope with the variant's payload.
            Symbol symbCondition{ScopedSymbol{scopeResult->__identifiers.at(identCondition), versions::VERSION_NONE}, scopeResult};
            scopeCompilation->overrideDeclarations({
                {symbCondition, Expression(condE.operands.at(0))}
            });

            // Give the alias the type of the selected variant alternative.
            const ExpandedType& typeVariant = function->__pass->man->root->getType(condCrudeE);
            const int conditionIndex = static_cast<int>(condE.getValueDouble());
            ScopedSymbol symbolInternal = scopeResult->findSymbolByAlias(identCondition);
            scopeResult->__declarations[symbolInternal].bindType(typeVariant->__operands.at(conditionIndex));
        }

        return context.function->getBruteScope(scopeResult)->compile();
    }

    case SWITCH_LATE:
    {
      return nullptr;
//        latereasoning::LateReasoningCompiler compiler(dynamic_cast<InterpretationFunction*>(this->function), context);
//        return compiler.processSwitchLateStatement(expression, "");
    }

    case FOLD_INTERPRET_INPUT:
    {
      // The input container is interpreted, so the fold is unrolled: the body is compiled
      // once per element, and the accumulator value is threaded through the iterations.
      const Expression& containerE = process(expression.getOperands().at(0));
      const TypeAnnotation& accumT = expression.type;
      assert(containerE.op == Operator::LIST);
      CodeScope* bodyScope = expression.blocks.front();
      const string& elAlias = expression.bindings[0];
      Symbol elS{ScopedSymbol{bodyScope->__identifiers.at(elAlias), versions::VERSION_NONE}, bodyScope};
      const std::string& accumAlias = expression.bindings[1];
      llvm::Value* accumRaw = context.scope->process(expression.getOperands().at(1), accumAlias, accumT);

      InterpretationScope* bodyI12n = function->getScope(bodyScope);
      auto bodyBrute = Decorators<CachedScopeDecoratorTag>::getInterface(context.function->getBruteScope(bodyScope));

      for (const Expression& elE : containerE.getOperands()) {
        bodyI12n->overrideBindings({
            {elE, elAlias}
        });
        bodyBrute->overrideDeclarations({
            {elS, elE}
        }); //resets bodyBrute
        bodyBrute->bindArg(accumRaw, string(accumAlias));
        accumRaw = bodyBrute->compile();
      }

      return accumRaw;
    }

//    case FOLD_INF_INTERPRET_INOUT:
//    {
//    }

    //TODO refactor as InterpretationCallStatement class
    case CALL_INTERPRET_PARTIAL:
    {
        // Partial application: INTR_ONLY arguments are interpreted and become part of the
        // specialization signature; the remaining arguments are compiled and passed at runtime.
        const std::string &calleeName = expression.getValueString();
        IBruteScope* scopeUnitSelf = context.scope;
        ManagedFnPtr callee = this->function->__pass->man->root->findFunction(calleeName);
        const I12nFunctionSpec& calleeData = FunctionInterpretationHelper::getSignature(callee);
        std::vector<llvm::Value *> argsActual;
        PIFnSignature sig;
        sig.declaration = callee;

        for (size_t no = 0, size = expression.operands.size(); no < size; ++no) {
            // Named argE so that it does not shadow the `op` parameter.
            const Expression& argE = expression.operands[no];

            if (calleeData.signature.at(no) == INTR_ONLY) {
                sig.bindings.push_back(process(argE));
                continue;
            }

            argsActual.push_back(scopeUnitSelf->process(argE));
        }

        TargetInterpretation* man = dynamic_cast<TargetInterpretation*>(this->function->__pass);
        assert(man && "Partial interpretation requires the TargetInterpretation pass");
        PIFunction* pifunction = man->getFunction(move(sig));

        llvm::Function* raw = pifunction->compile();
        // The invocation helper is only used locally, so it does not need a heap allocation.
        BruteFnInvocation statement(raw, man->pass->man->llvm);
        return statement(move(argsActual));
    }

    case QUERY_LATE:
    {
      return nullptr;
//      return IntrinsicQueryInstruction(
//          dynamic_cast<InterpretationFunction*>(this->function))
//          .processLate(expression, context);
    }

    default: break;
    }

    assert(false && "Unknown late interpretation operator");
    return nullptr;
}

/**
 * Compiles `expression` within this interpretation scope.
 *
 * If the expression has no late interpretation operator attached, it is fully
 * interpreted and the resulting expression is handed to the brute scope of
 * `context`. Otherwise compilation is delegated to processLate().
 */
llvm::Value*
InterpretationScope::compile(const Expression& expression, const Context& context, const std::string& hintAlias) {
    const InterpretationData data = Attachments::get<InterpretationData>(expression);

    if (data.op == InterpretationOperator::NONE) {
        const Expression interpretedE = process(expression);
        return context.scope->process(interpretedE, hintAlias);
    }

    return processLate(data.op, expression, context, hintAlias);
}

/**
 * Interprets `expression` at compile time and returns the resulting expression.
 *
 * Behaviour by expression kind:
 * - Number and string literals are returned as they are.
 * - Identifiers are resolved through processSymbol().
 * - The supported compound operators are evaluated.
 * - Compound expressions with any other operator are returned unchanged.
 *
 * Boolean results are encoded as EXPRESSION_TRUE / EXPRESSION_FALSE.
 */
Expression
InterpretationScope::process(const Expression& expression) {
#ifndef NDEBUG
    // Debugging aid: an expression tagged `bpoint` raises SIGINT, which stops an attached debugger.
    if (expression.tags.count("bpoint")) {
        std::raise(SIGINT);
    }
#endif

    PassManager* man = function->__pass->man;

    switch (expression.__state) {
    case Expression::INVALID:
        assert(false);
        // NOTE: with NDEBUG an invalid expression falls through and is returned unchanged.

    case Expression::NUMBER:
    case Expression::STRING:
        return expression;

    case Expression::IDENT:
    {
        Symbol s = Attachments::get<IdentifierSymbol>(expression);
        return Parent::processSymbol(s);
    }

    case Expression::COMPOUND:
        break;

    default: assert(false);
    }

    switch (expression.op) {
    case Operator::EQU:
    {
        const Expression& left = process(expression.operands[0]);
        const Expression& right = process(expression.operands[1]);

        if (left == right) return EXPRESSION_TRUE;
        return EXPRESSION_FALSE;
    }

    case Operator::NE:
    {
        const Expression& left = process(expression.operands[0]);
        const Expression& right = process(expression.operands[1]);

        if (left == right) return EXPRESSION_FALSE;
        return EXPRESSION_TRUE;
    }

    case Operator::LOGIC_AND:
    {
        // Only the single-operand form is supported during interpretation.
        assert(expression.operands.size() == 1);
        return process(expression.operands[0]);
    }

        //        case Operator::LOGIC_OR:
    case Operator::CALL:
    {
        // Interpret every argument, then interpret the callee with those argument values.
        const std::string &fnName = expression.getValueString();
        ManagedFnPtr fnAst = man->root->findFunction(fnName);
        InterpretationFunction* fnUnit = this->function->__pass->getFunction(fnAst);

        vector<Expression> args;
        args.reserve(expression.getOperands().size());
        for (const Expression& argCrude : expression.getOperands()) {
            args.push_back(process(argCrude));
        }

        return fnUnit->process(args);
    }

    case Operator::CALL_INTRINSIC:
    {
      const Expression& opCallIntrCrude = expression;
      vector<Expression> argsActual;
      argsActual.reserve(opCallIntrCrude.getOperands().size());
      for (const auto& op: opCallIntrCrude.getOperands()) {
        argsActual.push_back(process(op));
      }

      Expression opCallIntr(Operator::CALL_INTRINSIC, {});
      opCallIntr.setValueDouble(opCallIntrCrude.getValueDouble());
      opCallIntr.operands = argsActual;

      compilation::IntrinsicCompiler compiler(man);
      return compiler.interpret(opCallIntr);
    }

    case Operator::QUERY:
    {
      return Expression();
//        return IntrinsicQueryInstruction(dynamic_cast<InterpretationFunction*>(this->function))
//            .process(expression);
    }

    case Operator::QUERY_LATE:
    {
        assert(false && "Can't be interpretated");
        return Expression();
    }

    case Operator::IF:
    {
        CodeScope* scopeResult = processOperatorIf(expression);
        return function->getScope(scopeResult)->processScope();
    }

    case Operator::SWITCH:
    {
        CodeScope* scopeResult = processOperatorSwitch(expression);
        return function->getScope(scopeResult)->processScope();
    }

    case Operator::SWITCH_VARIANT:
    {
        CodeScope* scopeResult = processOperatorSwitchVariant(expression);
        return function->getScope(scopeResult)->processScope();
    }

    case Operator::VARIANT:
    {
      Expression result{Operator::VARIANT, {}};
      result.setValueDouble(expression.getValueDouble());

      result.operands.reserve(expression.operands.size());
      for (const Expression& op: expression.operands) {
        result.operands.push_back(process(op));
      }

      return result;
    }

    case Operator::INDEX:
    {
      Expression aggrE = process(expression.operands[0]);

      for (size_t keyId = 1; keyId < expression.operands.size(); ++keyId) {
        const Expression& keyE = process(expression.operands[keyId]);

        if (keyE.__state == Expression::STRING) {
            // Record access by field name.
            const string& fieldExpectedS = keyE.getValueString();
            const auto itField = std::find(aggrE.bindings.begin(), aggrE.bindings.end(), fieldExpectedS);
            assert(itField != aggrE.bindings.end());
            const auto fieldId = std::distance(aggrE.bindings.begin(), itField);

            // Copy the element out before reassigning, because it is owned by aggrE itself.
            aggrE = Expression(aggrE.operands.at(static_cast<size_t>(fieldId)));
            continue;
        }

        if (keyE.__state == Expression::NUMBER) {
            // Positional access.
            const int opId = static_cast<int>(keyE.getValueDouble());
            aggrE = Expression(aggrE.operands.at(opId));
            continue;
        }

        assert(false && "Inappropriate key");
      }

      return aggrE;
    }

    case Operator::FOLD:
    {
        const Expression& exprInput = process(expression.getOperands()[0]);
        const Expression& exprInit = process(expression.getOperands()[1]);

        const std::string& argEl = expression.bindings[0];
        const std::string& argAccum = expression.bindings[1];

        InterpretationScope* body = function->getScope(expression.blocks.front());

        Expression accum = exprInit;
        for (size_t size = exprInput.getOperands().size(), i = 0; i < size; ++i) {
            body->overrideBindings({
                {exprInput.getOperands()[i], argEl},
                {accum, argAccum}
            });

            accum = body->processScope();
        }

        return accum;
    }

    case Operator::LIST:
    case Operator::LIST_RANGE:
    {
        // Interpret every element and keep the field names (bindings) as they are.
        Expression result(expression.op, {});
        result.bindings = expression.bindings;

        result.operands.reserve(expression.operands.size());
        for (const Expression& opCurrent : expression.operands) {
            result.operands.push_back(process(opCurrent));
        }

        return result;
    }

        //        case Operator::MAP: {
        //            break;
        //        }

    default: break;
    }

    return expression;
}

/**
 * Returns the InterpretationFunction associated with the brute compilation unit `unit`.
 * On the first request it creates one and registers it by unit and by AST function id.
 */
InterpretationFunction*
TargetInterpretation::getFunction(IBruteFunction* unit) {
    const auto itFound = __dictFunctionsByUnit.find(unit);
    if (itFound != __dictFunctionsByUnit.end()) {
        return itFound->second;
    }

    InterpretationFunction* f = new InterpretationFunction(unit->getASTFn(), this);
    __dictFunctionsByUnit.emplace(unit, f);

    // The registration must not live inside assert(): under NDEBUG it would be compiled out.
    const bool flagInserted = __functions.emplace(unit->getASTFn().id(), f).second;
    assert(flagInserted && "Interpretation function is already registered for this AST function");
    (void) flagInserted;

    return f;
}

/**
 * Returns the partially interpreted specialization described by `sig`. The first time
 * a given signature is requested, the specialization is created.
 *
 * A new specialization gets the number of specializations created so far as its id.
 */
PIFunction*
TargetInterpretation::getFunction(PIFnSignature&& sig) {
    const auto itFound = __pifunctions.find(sig);
    if (itFound != __pifunctions.end()) {
        return itFound->second;
    }

    PIFunction* result = new PIFunction(PIFnSignature(sig), __pifunctions.size(), this);
    __pifunctions.emplace(move(sig), result);

    // Register the specialization's brute unit so that contexts of that unit resolve back
    // to this PIFunction. This must not live inside assert(): under NDEBUG the
    // registration would be compiled out.
    const bool flagRegistered = __dictFunctionsByUnit.emplace(result->fnRaw, result).second;
    assert(flagRegistered && "Brute unit of the specialization is already registered");
    (void) flagRegistered;

    return result;
}

/// Maps a brute compilation context to the interpretation scope of the same code scope,
/// using the interpretation function obtained through getFunction().
InterpretationScope*
TargetInterpretation::transformContext(const Context& c) {
    auto fnI12n = this->getFunction(c.function);
    CodeScope* scopeAst = c.scope->scope;

    return fnI12n->getScope(scopeAst);
}

/// Compiles `expression` by delegating to the interpretation scope that corresponds to `ctx`.
llvm::Value*
TargetInterpretation::compile(const Expression& expression, const Context& ctx, const std::string& hintAlias) {
    InterpretationScope* scopeI12n = transformContext(ctx);
    return scopeI12n->compile(expression, ctx, hintAlias);
}

/// Creates the interpretation counterpart of the AST function `function`, owned by `target`.
InterpretationFunction::InterpretationFunction(const ManagedFnPtr& function, Target<TargetInterpretation>* target)
    : Function<TargetInterpretation>(function, target)
{}

/**
 * Interprets the whole function and returns the resulting expression.
 *
 * `args` are bound positionally to the argument names of the entry scope before the
 * entry scope is processed.
 */
Expression
InterpretationFunction::process(const std::vector<Expression>& args) {
    InterpretationScope* entryI12n = getScope(__function->__entry);
    const auto& argNames = entryI12n->scope->__bindings;

    list<pair<Expression, string>> bindings;
    size_t argNo = 0;
    for (const Expression& argValue : args) {
        bindings.emplace_back(argValue, argNames.at(argNo));
        ++argNo;
    }

    entryI12n->overrideBindings(bindings);
    return entryI12n->processScope();
}

//                  Partial function interpretation

/// Base compilation unit for partially interpreted functions (produces an llvm::Function).
using BruteFunction = BasicBruteFunction;

class PIBruteFunction : public BruteFunction{
public:
  PIBruteFunction(ManagedFnPtr f, std::set<size_t>&& arguments, size_t id, CompilePass* p)
    : BruteFunction(f, p), argumentsActual(move(arguments)), __id(id) { }

  virtual std::string
  prepareName() override {
      return BruteFunction::prepareName() + "_" + std::to_string(__id);
  }

protected:
    std::vector<llvm::Type*>
    prepareSignature() override {
        LLVMLayer* llvm = BruteFunction::pass->man->llvm;
        AST* ast = BruteFunction::pass->man->root;
        CodeScope* entry = IBruteFunction::__entry;
        std::vector<llvm::Type*> signature;

        for(size_t no : argumentsActual) {
            VNameId argId = entry->__identifiers.at(entry->__bindings.at(no));
            ScopedSymbol arg{argId, versions::VERSION_NONE};

            signature.push_back(llvm->toLLVMType(ast->expandType(entry->__declarations.at(arg).type)));
        }

        return signature;
    }

    llvm::Function::arg_iterator
    prepareBindings() override{
        CodeScope* entry = IBruteFunction::__entry;
        IBruteScope* entryCompilation = BruteFunction::getBruteScope(entry);
        llvm::Function::arg_iterator fargsI = BruteFunction::raw->arg_begin();

        for(size_t no : argumentsActual) {
            ScopedSymbol arg{entry->__identifiers.at(entry->__bindings.at(no)), versions::VERSION_NONE};

            entryCompilation->bindArg(&*fargsI, arg);
            fargsI->setName(entry->__bindings.at(no));
            ++fargsI;
        }

        return fargsI;
    }

private:
    std::set<size_t> argumentsActual;
    size_t __id;
} ;

/**
 * Creates specialization `id` of the function `sig.declaration` for the interpreted
 * argument values `sig.bindings`.
 *
 * Arguments marked INTR_ONLY in the function's interpretation signature are substituted.
 * Their values, taken from `sig.bindings` in order, are bound in the entry interpretation
 * scope and are also overridden as declarations in the entry brute scope. All other
 * arguments remain runtime parameters of `fnRaw`.
 */
PIFunction::PIFunction(PIFnSignature&& sig, size_t id, TargetInterpretation* target)
: InterpretationFunction(sig.declaration, target), instance(move(sig)) {
    const I12nFunctionSpec& functionData = FunctionInterpretationHelper::getSignature(instance.declaration);

    // Arguments that are not INTR_ONLY become parameters of the generated function.
    std::set<size_t> argumentsActual;
    for (size_t no = 0, size = functionData.signature.size(); no < size; ++no) {
        if (functionData.signature.at(no) != INTR_ONLY) {
            argumentsActual.insert(no);
        }
    }

    fnRaw = new PIBruteFunction(instance.declaration, move(argumentsActual), id, target->pass);
    CodeScope* entry = instance.declaration->__entry;
    auto entryUnit = Decorators<CachedScopeDecoratorTag>::getInterface<>(fnRaw->getEntry());
    InterpretationScope* entryIntrp = InterpretationFunction::getScope(entry);

    list<pair<Expression, std::string>> bindingsPartial;
    list<pair<Symbol, Expression>> declsPartial;

    // INTR_ONLY arguments consume instance.bindings in order; sigNo is the next unused value.
    for (size_t no = 0, sigNo = 0, size = entry->__bindings.size(); no < size; ++no) {
        if (functionData.signature.at(no) != INTR_ONLY) {
            continue;
        }

        // Bounds-checked: a signature/bindings mismatch throws instead of reading past the end.
        const Expression& argValue = instance.bindings.at(sigNo);
        const auto& argAlias = entry->__bindings.at(no);
        bindingsPartial.push_back({argValue, argAlias});

        VNameId argId = entry->__identifiers.at(argAlias);
        Symbol argSymbol{ScopedSymbol{argId, versions::VERSION_NONE}, entry};
        declsPartial.push_back({argSymbol, argValue});

        ++sigNo;
    }

    entryIntrp->overrideBindings(bindingsPartial);
    entryUnit->overrideDeclarations(declsPartial);
}

/// Compiles the specialization's brute function and returns the resulting llvm::Function.
llvm::Function*
PIFunction::compile() {
    return fnRaw->compile();
}

/// Orders specialization signatures first by function id, then by the interpreted
/// argument values.
bool operator<(const PIFnSignature& lhs, const PIFnSignature& rhs) {
    const auto lhsId = lhs.declaration.id();
    const auto rhsId = rhs.declaration.id();

    return (lhsId != rhsId) ? (lhsId < rhsId) : (lhs.bindings < rhs.bindings);
}

// Mixed comparisons between a signature and an existing PIFunction. Both compare
// against the PIFunction's own signature (`instance`).
bool operator<(const PIFnSignature& lhs, PIFunction * const rhs) {
    const PIFnSignature& rhsSignature = rhs->instance;
    return lhs < rhsSignature;
}

bool operator<(PIFunction * const lhs, const PIFnSignature& rhs) {
    const PIFnSignature& lhsSignature = lhs->instance;
    return lhsSignature < rhs;
}

}
}

/** \class xreate::interpretation::InterpretationFunction
 *
 * Holds a list of xreate::interpretation::InterpretationScope objects, each focused on the interpretation of an individual code scope.
 *
 * The particular subclass PIFunction is intended to represent partially interpreted functions.
 *\sa TargetInterpretation, [Interpretation Concept](/d/concepts/interpretation/)
 */

/** \class xreate::interpretation::TargetInterpretation
 *
 * TargetInterpretation is executed during compilation and is intended to preprocess the parts of the source code that are eligible for interpretation.
 *
 * Keeps a list of InterpretationFunction / PIFunction objects that represent the interpretation of individual functions.
 *
 * There is \ref InterpretationScopeDecorator that embeds interpretation to an overall compilation process.
 * \sa InterpretationPass, compilation::Target, [Interpretation Concept](/d/concepts/interpretation/)
 *
 */
