
#include "pialign/model-length.h"

using namespace std;
using namespace pialign;
using namespace gng;


void LengthModel::setMaxLen(int maxLenE, int maxLenF) {
    int maxLen = max(maxLenE,maxLenF);
    sepPhrases_ = std::vector< PyDist< WordId,PySparseIndex<WordId> > >(maxLen*2, PyDist< WordId,PySparseIndex<WordId> >(1.0,0.95));
    sepFor_ = std::vector< DirichletDist<int> >(maxLen*2,DirichletDist<int>(1.0,2)) ;
    sepTerm_ = std::vector< DirichletDist<int> >(maxLen*2,DirichletDist<int>(1.0,2)); 
    sepType_ = std::vector< std::vector<Prob> >(maxLen*2,std::vector<Prob>(3,0));
    sepFallbacks_ = std::vector<Prob>(maxLen*2,0);
	sepSplits_ = std::vector<Prob>(maxLen*2);
    for(int i = 0; i < maxLen*2; i++) 
        sepSplits_[i] = -1*log(std::max(i,1));
    sentPen_ = -1*log(maxLen*2);
}


// Adds the derivation rooted at node to the model and returns its log
// probability, which is also left in node->prob.
//  e, f:     the two sentences; node->span picks the substring of each
//  node:     root of the (sub)derivation; a null node, or one whose add flag
//            is false, contributes 0 and leaves the model untouched
//  ePhrases, fPhrases, pairs: id tables queried for the two phrases and for
//            the phrase pair
//  base:     base measure, used for nodes whose type is none of TYPE_REG,
//            TYPE_INV and TYPE_GEN
Prob LengthModel::addSentence(const WordString & e, const WordString & f, SpanNode* node, StringWordSet & ePhrases, StringWordSet & fPhrases, PairWordSet & pairs, BaseMeasure* base) {
    typedef PyDist< WordId,PySparseIndex<WordId> > PhraseDist;
    if(!node || !node->add) return 0;
    node->prob = 0;
    // ids of the two phrases covered by this node's span
    const Span & span = node->span;
    const int eBegin = span.es, eEnd = span.ee, fBegin = span.fs, fEnd = span.fe;
    WordId eId = ePhrases.getId(e.substr(eBegin,eEnd-eBegin),true);
    WordId fId = fPhrases.getId(f.substr(fBegin,fEnd-fBegin),true);
    // id of the phrase pair; saveIdx receives it along with the span length
    // minus one, and its result selects the per-length tables used below
    node->phraseid = pairs.getId(WordPairHash(eId, fId, GlobalVars::maxPhrase),true);
    int idx = saveIdx(node->phraseid,node->span.length()-1);
    if(node->type == TYPE_GEN) {
        // the pair is counted as an existing entry of the phrase distribution
        PhraseDist & cache = sepPhrases_[idx];
        node->prob = log(cache.getProb(node->phraseid,0));
        PRINT_DEBUG(" LengthModel::genProb("<<idx<<","<<node->phraseid<<"): "<<log(cache.getProb(node->phraseid,0))<<endl, 2);
        cache.addExisting(node->phraseid);
    } else {
        int toAdd = node->type;
        if(toAdd == TYPE_REG || toAdd == TYPE_INV) {
            // built from two children: add the right one, then the left one,
            // and pay the split score for this length
            node->prob += addSentence(e,f,node->right,ePhrases,fPhrases,pairs,base);
            node->prob += addSentence(e,f,node->left, ePhrases,fPhrases,pairs,base);
            node->prob += sepSplits_[idx];
            PRINT_DEBUG(" LengthModel::splitProb("<<idx<<"): "<<sepSplits_[idx]<<endl, 2);
        } else {
            // any other type goes through the base measure and is recorded
            // as TYPE_TERM from here on
            base->add(node->span,node->phraseid,node->baseProb,node->baseElems);
            node->prob += node->baseProb;
            PRINT_DEBUG(" LengthModel::baseProb("<<node->phraseid<<"): "<<node->baseProb<<endl, 2);
            toAdd = TYPE_TERM;
        }
        // child phrase ids, -1 standing in for a missing child
        WordId lId = -1, rId = -1;
        if(node->left) lId = node->left->phraseid;
        if(node->right) rId = node->right->phraseid;
        // fallback probability, then the type probability, then the new entry
        PhraseDist & dist = sepPhrases_[idx];
        node->prob += log(dist.getFallbackProb());
        PRINT_DEBUG(" LengthModel::fallProb("<<idx<<"): "<<log(dist.getFallbackProb())<<endl, 2);
        Prob addProb = addType(toAdd,idx);
        PRINT_DEBUG(" LengthModel::addProb("<<idx<<","<<toAdd<<"):"<<addProb<<endl, 2);
        node->prob += addProb;
        dist.addNew(node->phraseid,lId,rId,toAdd);
    }
    addAverageDerivation(node->phraseid,sepPhrases_[idx].getTotal(node->phraseid),node->prob);
    return node->prob;
}


// Removes one occurrence of phrase pair jId from the model and rebuilds the
// derivation that had produced it.
//  jId:   id of the phrase pair; a negative id means "nothing to remove"
//  base:  base measure, asked for the base probability of pairs that were
//         generated directly from it
// Returns 0 for a negative jId. Otherwise returns a node allocated with new
// (this function never frees it, so the caller is responsible for it) whose
// phraseid, type, prob and, for split entries, left/right children are filled.
SpanNode* LengthModel::removePhrasePair(WordId jId, BaseMeasure* base) {
    if(jId < 0) return 0;
#ifdef DEBUG_ON
    if(jId >= (int)phraseIdxs_.size())
        THROW_ERROR("Overflown phraseIdx in LengthModel: "<<jId << " >= " << phraseIdxs_.size() << std::endl);
#endif
    SpanNode* ret = new SpanNode(Span(0,0,0,0));
    ret->phraseid = jId;
    int idx = phraseIdxs_[jId];
    PyDist< WordId,PySparseIndex<WordId> > & dist = sepPhrases_[idx];
    ret->prob = dist.remove(jId);
    // no table was removed: this occurrence was generated from the cache
    if(!dist.isRemovedTable()) {
        ret->type = TYPE_GEN;
        return ret;
    }
    // a table was removed: this occurrence was generated from the fallback
    const PyTable<WordId> & table = dist.getLastTable();
    // NOTE(review): the child ids are copied before recursing. If
    // getLastTable() hands back a reference into the distribution, a nested
    // remove() on the same distribution could overwrite it - confirm.
    const WordId leftId = table.left, rightId = table.right;
    ret->prob += removeType(table.type,idx);
    ret->type = table.type;
    if(rightId >= 0) {
        // generated by breaking down: remove the right child, then the left
        ret->right = removePhrasePair(rightId,base);
        ret->prob += ret->right->prob;
        ret->left = removePhrasePair(leftId,base);
        // a negative left id yields a null child (addSentence stores -1 for a
        // missing child), so it must not be dereferenced
        if(ret->left)
            ret->prob += ret->left->prob;
        ret->prob += sepSplits_[idx];
    } else {
        // generated directly from the base measure
        ret->type = TYPE_BASE;
        ret->baseProb = base->getBase(jId);
        ret->prob += ret->baseProb;
    }
    return ret;
}


void LengthModel::initialize(const WordString & e, const WordString & f, 
        ParseChart & chart, WordString & jIds, int* tCounts) {
    int len = e.length()+f.length();
    for(int i = 1; i < len; i++) {
        sepFallbacks_[i] = log(sepPhrases_[i].getFallbackProb());
        sepType_[i][0] = log(sepTerm_[i].getProb(0));
        sepType_[i][1] = log(sepTerm_[i].getProb(1)*sepFor_[i].getProb(0));
        sepType_[i][2] = log(sepTerm_[i].getProb(1)*sepFor_[i].getProb(1));
    }
}


void LengthModel::printStats(std::ostream &out) const {
    out << " s =";
    for(int i = 0; i < (int)sepPhrases_.size(); i++) 
        out<<" "<<sepPhrases_[i].getStrength();
    out << std::endl << " d =";
    for(int i = 0; i < (int)sepPhrases_.size(); i++)
        out<<" "<<sepPhrases_[i].getDiscount();
    out << std::endl << " t =";
    for(int i = 0; i < (int)sepPhrases_.size(); i++)
        out<<" "<<exp(sepType_[i][0])<<"/"<<exp(sepType_[i][1])<<"/"<<exp(sepType_[i][2]);
    out << std::endl;
}



void LengthModel::calcPhraseTable(const PairWordSet & jPhrases, std::vector<Prob> & eProbs, std::vector<Prob> & fProbs, std::vector<Prob> & jProbs, std::vector<Prob> & dProbs) {
    Prob myProb;
    for(PairWordSet::const_iterator it = jPhrases.begin(); it != jPhrases.end(); it++) {
        int idx = phraseIdxs_[it->second];
        if(idx || rememberNull_) {
            myProb = sepPhrases_[idx].getProb(it->second,0);
            if(myProb != 0.0) {
                int first = WordPairFirst(it->first, GlobalVars::maxPhrase), second = WordPairSecond(it->first, GlobalVars::maxPhrase);
                if((int)jProbs.size() <= it->second) jProbs.resize(it->second+1,0);
                jProbs[it->second] = myProb;
                if((int)eProbs.size() <= first) eProbs.resize(first+1,0);
                eProbs[first] += myProb;
                if((int)fProbs.size() <= second) fProbs.resize(second+1,0);
                fProbs[second] += myProb;
            }
        }
    }
    dProbs = derivations_;
}
