#include "astgenvisitor1.h"
#include "ast.h"
#include "strutils.h"

#include <assert.h>
#include <ctype.h> // for isupper

using std::vector;
using std::string;

// Builds a visitor that writes generated AST declarations to m_out.
// filename     - used to construct m_out (presumably an output file stream
//                opened on this path — confirm against the header) and kept
//                in m_filename for the include-guard name.
// visitorFilename - header name #included by the generated file.
astgenvisitor1::astgenvisitor1(const char* filename, const char* visitorFilename)
	: m_out(filename), m_filename(filename), m_visitorFilename(visitorFilename)
{
	// All state is set up by the member initialisers above.
}

// Destroys the visitor; members clean up through their own destructors.
// NOTE(review): the "#ifndef ... _HPP_GUARD_" opened in
// visit_grammar_grammar_production is not closed anywhere in this file;
// presumably the matching "#endif" is written by another pass — confirm.
astgenvisitor1::~astgenvisitor1()
{
	// Intentionally empty.
}

void astgenvisitor1::visit_lhs_IDENT_SEPARATOR(const lhs_IDENT_SEPARATOR *plhs_IDENT_SEPARATOR)
{
	m_rulename = *plhs_IDENT_SEPARATOR->m_IDENT;
}

/*
 * Visits a grammar node whose last element is a production.
 *
 * m_grammar holds the part of the grammar preceding this production and is
 * visited first. A null m_grammar marks the start of the AST; in that case the
 * header preamble is written instead. The preamble consists of the include-guard
 * opening, the #includes, and the abstract `node` base class with a pure
 * accept(Visitor*). The production itself is visited last in both cases.
 */
void astgenvisitor1::visit_grammar_grammar_production(const grammar_grammar_production *pgrammar_grammar_production)
{
	if (pgrammar_grammar_production->m_grammar.get() != 0)
	{
		// Everything before this production is handled first.
		pgrammar_grammar_production->m_grammar->accept(this);
	}
	else
	{
		// Start of the AST: open the generated header.
		m_out << "#ifndef " << strToIdentifier(m_filename) << "_HPP_GUARD_\n"
		      << "#define " << strToIdentifier(m_filename) << "_HPP_GUARD_\n"
		      << "#include <string>\n"
		      << "#include <list>\n"
		      << "#include \"" << m_visitorFilename << "\"\n\n";

		// Common abstract base for every generated AST class.
		m_out << "class node\n"
		      << "{\n"
		      << "	public:\n"
		      << "		virtual ~node() {}\n"
		      << "		virtual void accept( Visitor * ) const = 0;\n"
		      << "};\n\n";
	}

	pgrammar_grammar_production->m_production->accept(this);
}

void astgenvisitor1::visit_grammar_grammar_COMMENT(const grammar_grammar_COMMENT *pgrammar_grammar_COMMENT)
{
	return;
	/*
	if (pgrammar_grammar_COMMENT->m_grammar.get() != 0)
	{
		pgrammar_grammar_COMMENT->m_grammar->accept(this);
	}
	*/
}

void astgenvisitor1::visit_expression_base_OPT(const expression_base_OPT *pexpression_base_OPT)
{
	// not yet implemented
	assert(0);
	pexpression_base_OPT->m_base->accept(this);
}

void astgenvisitor1::visit_production_lhs_expressionListList_TERMINATOR(const production_lhs_expressionListList_TERMINATOR *pproduction_lhs_expressionListList_TERMINATOR)
{

	pproduction_lhs_expressionListList_TERMINATOR->m_lhs->accept(this);
	// m_rulename is now set (by astgenvisitor1::visit_lhs_IDENT_SEPARATOR
	if (!beginsWithStr(m_rulename))
	{
		if (pproduction_lhs_expressionListList_TERMINATOR->m_expressionListList.get() != 0)
		{
			if (pproduction_lhs_expressionListList_TERMINATOR->m_expressionListList->size() > 1)
			{
				if (!endsWithList(m_rulename))
				{
					// First time we'll collect up the class names so we
					// can use them in the base class.
					vector<string> names;
					for (vector<vector<expression*>*>::const_iterator i = pproduction_lhs_expressionListList_TERMINATOR->m_expressionListList->begin();
						  i != pproduction_lhs_expressionListList_TERMINATOR->m_expressionListList->end();
						  ++i)
					{
						if (*i)
						{
							vector<string> dummy;
							string name = create_class_name(i, dummy);
							names.push_back(name);
						}
					}

					// Now write out the base class
					if (!isSpecialRule(m_rulename))
					{
						m_out << "class " << m_rulename << ": public node\n";
						m_out << "{\n";
						m_out << "\tpublic:\n";
						/*
						m_out << "\t\tenum clsType {\n";
						for (vector<string>::const_iterator i = names.begin();
								i != names.end(); ++i)
							m_out << "\t\t\t" << *i << "Type,\n";
						m_out << "\t\t};\n\n";
						m_out << "\t\t" << m_rulename << "(const clsType t)\n";
						m_out << "\t\t\t: type(t) {}\n";
						*/
						m_out << "\t\t" << m_rulename << "()\n";
						m_out << "\t\t\t{}\n";
						m_out << "\t\tvirtual ~" << m_rulename << "() {}\n\n";
						/*
						m_out << "\t\tclsType type;\n\n";
						for (vector<string>::const_iterator i = names.begin();
								i != names.end(); ++i)
						{
							m_out << "\t\tvirtual " << *i << "* get" << *i << "()\n";
							m_out << "\t\t{\n";
							m_out << "\t\t\treturn 0;\n";
							m_out << "\t\t}\n";
						}
						*/
						m_out << "};\n\n";
					}

					// Second time we'll actually write out the classes.
					for (vector<vector<expression*>*>::const_iterator i = pproduction_lhs_expressionListList_TERMINATOR->m_expressionListList->begin();
						  i != pproduction_lhs_expressionListList_TERMINATOR->m_expressionListList->end();
						  ++i)
					{
						if (*i)
						{
							write_class(i, m_rulename);
						}
					}
				}
				else
				{
					vector<vector<expression*>*>::const_iterator i =  pproduction_lhs_expressionListList_TERMINATOR->m_expressionListList->begin();
					if ((*i)->size() > 0 && *(*i)->begin())
					{
						(*(*i)->begin())->accept(this); // fills out m_ident
						if (beginsWithStr(m_ident))
							addListType(m_rulename, "std::string");
						else
							addListType(m_rulename, m_ident);
					}
				}
				
			}
			else
			{
				write_class(pproduction_lhs_expressionListList_TERMINATOR->m_expressionListList->begin(), "node");
			}
		}
	}
}

void astgenvisitor1::write_argument(const vector<string>::const_iterator& i, int arg_num)
{
	if (isAllCaps(*i) || beginsWithStr(*i))
	{
		m_out << "\t\t\tstd::string* pNew" << *i << arg_num;
	}
	else if (endsWithList(*i))
	{
		string listType = getListType(*i);
		m_out << "\t\t\tstd::list< " << listType << "* >* pNew" <<
			*i << arg_num;
	}
	else
	{
		m_out << "\t\t\t" << *i << "* pNew" << *i << arg_num;
	}
}

/*
 * Builds the generated class name for the alternative *i and appends the
 * identifiers it references to idents.
 *
 * Each term is visited after clearing m_literal and m_ident. Terms that set
 * m_ident contribute "_<ident>" to the name and are recorded in idents. An
 * alternative with no terms is named "<rule>_empty".
 */
string astgenvisitor1::create_class_name(const vector<vector<expression*>*>::const_iterator i, vector<string>& idents)
{
	const vector<expression*>& terms = **i;
	string name(m_rulename);

	if (terms.empty())
		return name + "_empty";

	for (vector<expression*>::const_iterator term = terms.begin(); term != terms.end(); ++term)
	{
		m_literal.erase();
		m_ident.erase();
		(*term)->accept(this);
		if (m_ident.empty())
			continue; // e.g. a literal: contributes nothing to the name

		name += '_';
		name += m_ident;
		idents.push_back(m_ident);
	}
	return name;
}

void astgenvisitor1::write_class(const vector<vector<expression*>*>::const_iterator i, const string& baseclass)
{
	// first build up the classname and gather the idents
	vector<string> idents;
	string classname;
	if (baseclass == "node")
	{
		// we don't care about the class name, but we need idents to be filled in.
		create_class_name(i, idents);
		classname = m_rulename;
	}
	else
		classname = create_class_name(i, idents);

	m_out << "class " << classname << " : public " << baseclass << "\n";
	m_out << "{\n";

	// next build up the constructor
	m_out << "\tpublic:\n";
	m_out << "\t\t" << classname << "(\n";
	int arg_num = 1;
	if (idents.size() > 0)
	{
		vector<string>::const_iterator i = idents.begin();
		while (i != idents.end())
		{
			write_argument(i, arg_num);
			++i;
			++arg_num;
			if (i != idents.end())
				m_out << ",\n";
			else
				m_out << "\n";
		}
	}
	m_out << "\t\t)\n";

	// now do the initialization list
	vector<string>::const_iterator j = idents.begin();
	arg_num = 1;
	char sep = ':';
	if (baseclass != "node")
	{
		m_out << "\t\t\t: " << baseclass << "(" /*<< classname << "Type*/ ")\n";
		sep = ',';
	}
	while (j != idents.end())
	{
		m_out << "\t\t\t" << sep << " m_p" << *j << arg_num << "(pNew" << *j << arg_num << ")\n";
		sep = ',';
		++j;
		++arg_num;
	}
	m_out << "\t\t{}\n\n";

	// write out the destructor
	m_out << "\t\tvirtual ~" << classname << "();\n\n";

	// write the accept function
	m_out << "\t\tvoid accept( Visitor* v ) const\n";
	m_out << "\t\t{\n";
	if (baseclass == "node")
		m_out << "\t\t\tv->visit_" << m_rulename << "( this );\n";
	else
		m_out << "\t\t\tv->visit_" << classname << "( this );\n";
	m_out << "\t\t}\n\n";

	// write the getpointer function
	/*
	m_out << "\t\tvirtual " << classname << "* get" << classname << "()\n";
	m_out << "\t\t{\n";
	m_out << "\t\t\treturn this;\n";
	m_out << "\t\t}\n\n";
	*/
	
	// write the member variables
	arg_num = 1;
	for (vector<string>::const_iterator i = idents.begin(); i!= idents.end(); ++i, ++arg_num)
	{
		if (isAllCaps(*i) || beginsWithStr(*i))
		{
			m_out << "\t\tstd::string* m_p" << *i << arg_num << ";\n";
		}
		else if (endsWithList(*i))
		{
			m_out << "\t\tstd::list< " << getListType(*i) << "* >* m_p" <<
				*i << arg_num << ";\n";
		}
		else
		{
			m_out << "\t\t" << *i << "* m_p" << *i << arg_num << ";\n";
		}
	}

	// finish off the class
	m_out << "};\n\n";

}

void astgenvisitor1::visit_expression_base_PLUS(const expression_base_PLUS *pexpression_base_PLUS)
{
	// not yet implemented
	assert(0);
	pexpression_base_PLUS->m_base->accept(this);
}

void astgenvisitor1::visit_expression_base(const expression_base *pexpression_base)
{
	pexpression_base->m_base->accept(this);
}

void astgenvisitor1::visit_base_LITERAL(const base_LITERAL *pbase_LITERAL)
{
	m_literal = *pbase_LITERAL->m_LITERAL;
}

void astgenvisitor1::visit_expression_base_STAR(const expression_base_STAR *pexpression_base_STAR)
{
	// not yet implemented
	assert(0);
	pexpression_base_STAR->m_base->accept(this);
}

void astgenvisitor1::visit_base_LPAREN_expressionList_RPAREN(const base_LPAREN_expressionList_RPAREN *pbase_LPAREN_expressionList_RPAREN)
{
	// not yet implemented
	assert(0);
	if (pbase_LPAREN_expressionList_RPAREN->m_expressionList.get() != 0)
	{
		for (vector<expression*>::const_iterator i = pbase_LPAREN_expressionList_RPAREN->m_expressionList->begin();
			  i != pbase_LPAREN_expressionList_RPAREN->m_expressionList->end();
			  ++i)
		{
			(*i)->accept(this);
		}
	}
}

// Comments inside expressions are not supported yet. Debug builds stop at the
// assert; otherwise the comment is ignored.
void astgenvisitor1::visit_expression_COMMENT(const expression_COMMENT *pexpression_COMMENT)
{
	assert(0); // not yet implemented
	static_cast<void>(pexpression_COMMENT->m_COMMENT);
}

void astgenvisitor1::visit_alternation_expression_OR_expression(const alternation_expression_OR_expression *palternation_expression_OR_expression)
{
	// not yet implemented
	assert(0);
	palternation_expression_OR_expression->m_expression1->accept(this);
	palternation_expression_OR_expression->m_expression2->accept(this);
}

void astgenvisitor1::visit_base_IDENT(const base_IDENT *pbase_IDENT)
{
	m_ident = *pbase_IDENT->m_IDENT;
}

void astgenvisitor1::visit_base_LPAREN_alternation_RPAREN(const base_LPAREN_alternation_RPAREN *pbase_LPAREN_alternation_RPAREN)
{
	// not yet implemented
	assert(0);
	pbase_LPAREN_alternation_RPAREN->m_alternation->accept(this);
}

void astgenvisitor1::visit_alternation_alternation_OR_expression(const alternation_alternation_OR_expression *palternation_alternation_OR_expression)
{
	// not yet implemented
	assert(0);
	palternation_alternation_OR_expression->m_alternation->accept(this);
	palternation_alternation_OR_expression->m_expression->accept(this);
}


