//
//  Copyright(C) 2011 Taro Watanabe <taro.watanabe@nict.go.jp>
//

//
// filter for transforming (mapping) the terminals of a forest...
// we assume a parse forest, such as one generated by cicada_filter_{penntreebank,charniak}
//


#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "cicada/hypergraph.hpp"
#include "cicada/sentence.hpp"
#include "cicada/span_edge.hpp"
#include "cicada/vocab.hpp"

#include <boost/program_options.hpp>
#include <boost/filesystem.hpp>
#include <boost/tokenizer.hpp>
#include <boost/shared_ptr.hpp>

#include "utils/program_options.hpp"
#include "utils/compress_stream.hpp"
#include "utils/bithack.hpp"

// Short aliases for the cicada types this filter works with.
typedef cicada::HyperGraph hypergraph_type;
typedef cicada::Sentence   sentence_type;
typedef cicada::Vocab      vocab_type;

typedef boost::filesystem::path path_type;

// One span per hyperedge, filled in by cicada::span_edge(). main() reads
// `first` as the sentence index of the first terminal covered by an edge and
// `second` as the index at which the next sibling symbol starts.
typedef std::pair<int, int> span_type;
typedef std::vector<span_type, std::allocator<span_type> > span_set_type;
// Scratch buffer type used to assemble a rewritten rule right-hand side.
typedef sentence_type phrase_type;

// Paths set from the command line by options().
// "-" is the default for input/output (presumably stdin/stdout as interpreted
// by utils::compress_{i,o}stream -- confirm); map_file has an empty default.
path_type input_file = "-";
path_type output_file = "-";
path_type map_file;

// Debug level set by --debug; not read anywhere else in the visible code.
int debug = 0;

// Parses the command line into the globals above (defined after main).
void options(int argc, char** argv);

// Entry point.
//
// Reads hypergraphs from input_file and sentences from map_file in lock-step.
// For every valid hypergraph, each edge's rule is rebuilt so that its
// non-epsilon terminals are replaced, left to right, by the words of the
// paired sentence at the positions given by the edge spans; non-terminals are
// kept as they are and epsilon terminals are dropped. Invalid hypergraphs are
// written through unchanged. One hypergraph is written per line to output_file.
//
// Returns 0 on success, 1 if any std::exception is raised (the message is
// printed to std::cerr). Errors include: missing map file, a different number
// of hypergraphs and sentences, and a forest whose terminal positions or tail
// indices do not fit the paired sentence / edge.
int main(int argc, char** argv)
{
  try {
    options(argc, argv);

    // Check the map file up front, before any stream (in particular the
    // output) is opened.
    if (map_file != "-" && ! boost::filesystem::exists(map_file))
      throw std::runtime_error("no map file: " + map_file.string());

    utils::compress_istream is(input_file, 1024 * 1024);
    utils::compress_istream ms(map_file);
    utils::compress_ostream os(output_file);

    // Buffers are hoisted out of the loop and reused for every forest.
    hypergraph_type hypergraph;
    sentence_type   sentence;
    span_set_type   spans;
    phrase_type     rhs;

    for (;;) {
      is >> hypergraph;
      ms >> sentence;

      if (! is || ! ms) break;

      if (! hypergraph.is_valid()) {
        os << hypergraph << '\n';
        continue;
      }

      // map terminals...
      // we will first compute spans, then, perform terminal mapping..

      spans.clear();
      spans.resize(hypergraph.edges.size());

      cicada::span_edge(hypergraph, spans);

      for (size_t edge_id = 0; edge_id != hypergraph.edges.size(); ++ edge_id) {
        hypergraph_type::edge_type& edge = hypergraph.edges[edge_id];
        const hypergraph_type::symbol_type lhs = edge.rule->lhs;

        rhs.clear();

        // pos: running count of non-terminals seen so far, used as the tail
        // index when a non-terminal carries no explicit index.
        // span_pos: sentence position of the next terminal to be mapped.
        int pos = 0;
        int span_pos = spans[edge_id].first;

        hypergraph_type::rule_type::symbol_set_type::const_iterator riter_end = edge.rule->rhs.end();
        for (hypergraph_type::rule_type::symbol_set_type::const_iterator riter = edge.rule->rhs.begin(); riter != riter_end; ++ riter) {
          if (riter->is_non_terminal()) {
            const int non_terminal_index = riter->non_terminal_index();
            const int non_terminal_pos = utils::bithack::branch(non_terminal_index <= 0, pos, non_terminal_index - 1);
            ++ pos;

            // Guard against rules whose non-terminals do not match the tails.
            if (static_cast<size_t>(non_terminal_pos) >= edge.tails.size())
              throw std::runtime_error("non-terminal index exceeds # of tails");
            if (hypergraph.nodes[edge.tails[non_terminal_pos]].edges.empty())
              throw std::runtime_error("antecedent node has no incoming edge");

            rhs.push_back(*riter);

            // Skip over the words covered by the antecedent node: continue
            // right after the span of its first incoming edge.
            span_pos = spans[hypergraph.nodes[edge.tails[non_terminal_pos]].edges.front()].second;
          } else if (*riter != vocab_type::EPSILON) {
            // The map sentence must be at least as long as the forest yield.
            if (span_pos < 0 || static_cast<size_t>(span_pos) >= sentence.size())
              throw std::runtime_error("terminal position is out of range of the map sentence");

            rhs.push_back(sentence[span_pos]);
            ++ span_pos;
          }
        }

        edge.rule = hypergraph_type::rule_type::create(hypergraph_type::rule_type(lhs, rhs.begin(), rhs.end()));
      }

      os << hypergraph << '\n';
    }

    // The loop ends when either stream fails; if the other one is still
    // good, the two inputs had different lengths.
    if (is || ms)
      throw std::runtime_error("# of hypergraphs and # of sentences do not match");
  }
  catch (const std::exception& err) {
    std::cerr << "error: " << err.what() << std::endl;
    return 1;
  }
  return 0;
}


void options(int argc, char** argv)
{
  namespace po = boost::program_options;
  
  po::options_description desc("options");
  desc.add_options()
    ("input",     po::value<path_type>(&input_file)->default_value(input_file),   "input file")
    ("output",    po::value<path_type>(&output_file)->default_value(output_file), "output")
    ("map",       po::value<path_type>(&map_file)->default_value(map_file), "map terminal symbols")
    
    ("debug", po::value<int>(&debug)->implicit_value(1), "debug level")
        
    ("help", "help message");
  
  po::variables_map vm;
  po::store(po::parse_command_line(argc, argv, desc, po::command_line_style::unix_style & (~po::command_line_style::allow_guessing)), vm);
  po::notify(vm);
  
  if (vm.count("help")) {
    std::cout << argv[0] << " [options]" << '\n' << desc << '\n';
    exit(0);
  }
}
