/*****************************************************************************/
/*!
 * \file expr_transform.cpp
 * 
 * Author: Ying Hu, Clark Barrett
 * 
 * Created: Jun 05 2003
 *
 * <hr>
 * Copyright (C) 2003 by the Board of Trustees of Leland Stanford
 * Junior University and by New York University. 
 *
 * License to use, copy, modify, sell and/or distribute this software
 * and its documentation for any purpose is hereby granted without
 * royalty, subject to the terms and conditions defined in the \ref
 * LICENSE file provided with this distribution.  In particular:
 *
 * - The above copyright notice and this permission notice must appear
 * in all copies of the software and related documentation.
 *
 * - THE SOFTWARE IS PROVIDED "AS-IS", WITHOUT ANY WARRANTIES,
 * EXPRESSED OR IMPLIED.  USE IT AT YOUR OWN RISK.
 * 
 * <hr>
 * 
 */
/*****************************************************************************/


#include "expr_transform.h"
#include "theory_core.h"
#include "dictionary.h"
#include "hash.h"
#include "command_line_flags.h"
#include "core_proof_rules.h"


using namespace std;
using namespace CVCL;


// Builds a transformer bound to the given core theory.  The common and
// core proof-rule objects are fetched from the core once, here, so the
// transformations below can use them directly.
ExprTransform::ExprTransform(TheoryCore* core) : d_core(core)
{
  d_commonRules = core->getCommonRules();
  d_rules = core->getCoreRules();
}


// Runs the preprocessing pipeline on e and returns a theorem e <=> e'.
// Stages, chained together by transitivity:
//   1. push negations inward (only if flag "pp-pushneg" is set);
//   2. simplify with the core (always);
//   3. convert Boolean connectives to ITEs (only if flag "ite-ify" is set);
//   4. simplify ITEs using care sets (only if flag "pp-ite" is set).
Theorem ExprTransform::preprocess(const Expr& e)
{
  Theorem result = d_core->getFlags()["pp-pushneg"].getBool()
    ? pushNegation(e)
    : d_commonRules->reflexivityRule(e);

  result = d_commonRules->transitivityRule(result,
                                           d_core->simplify(result.getRHS()));

  if (d_core->getFlags()["ite-ify"].getBool())
    result = d_commonRules->transitivityRule(result,
                                             ite_convert(result.getRHS()));

  if (d_core->getFlags()["pp-ite"].getBool())
    result = d_commonRules->transitivityRule(result,
                                             ite_simplify(result.getRHS()));

  return result;
}


// Preprocesses the formula proved by thm.  The preprocessing equivalence
// e <=> e' is combined with thm by iffMP, yielding a theorem of e'.
Theorem ExprTransform::preprocess(const Theorem& thm)
{
  const Theorem equiv = preprocess(thm.getExpr());
  return d_commonRules->iffMP(thm, equiv);
}


// Pushes negations in formula e inward (see pushNegationRec) and returns
// the theorem e <=> e'.  Terms are returned unchanged via reflexivity.
// The memo table d_pushNegCache is assumed to be empty on entry; it is
// cleared again after the recursive rewrite completes.
Theorem ExprTransform::pushNegation(const Expr& e) {
  if(!e.isTerm()) {
    const Theorem pushed(pushNegationRec(e, false));
    d_pushNegCache.clear();
    return pushed;
  }
  return d_commonRules->reflexivityRule(e);
}


// Recursively descend into the expression e, keeping track of whether
// we are under an even or odd number of negations ('neg' == true means
// odd, i.e. the context is "negative").
//
// Produces a proof of e <==> e' or !e <==> e', depending on whether
// neg is false or true, respectively.  Results are memoized in
// d_pushNegCache, keyed by the formula being rewritten (!e or e).
Theorem ExprTransform::pushNegationRec(const Expr& e, bool neg) {
  TRACE("pushNegation", "pushNegationRec(", e,
	", neg=" + string(neg? "true":"false") + ") {");
  DebugAssert(!e.isTerm(), "pushNegationRec: not boolean e = "+e.toString());
  // The formula this call rewrites (!e in a negative context, e otherwise).
  // It is both the cache key and the argument of the rewrite rules, so it
  // is built once.
  const Expr key(neg? !e : e);
  ExprMap<Theorem>::iterator i = d_pushNegCache.find(key);
  if(i != d_pushNegCache.end()) { // Found cached result
    TRACE("pushNegation", "pushNegationRec [cached] => ", (*i).second, "}");
    return (*i).second;
  }
  // Every branch of both switches below assigns res (each has a default
  // that falls back to reflexivity), so no placeholder theorem is built.
  Theorem res;
  if(neg) {
    switch(e.getKind()) {
      case TRUE: res = d_commonRules->rewriteNotTrue(key); break;
      case FALSE: res = d_commonRules->rewriteNotFalse(key); break;
      case NOT: 
        res = pushNegationRec(d_commonRules->rewriteNotNot(key), false);
        break;
      case AND:
        res = pushNegationRec(d_rules->rewriteNotAnd(key), false);
        break;
      case OR: 
        res = pushNegationRec(d_rules->rewriteNotOr(key), false);
        break;
      case IMPLIES: {
        // Rewrite the implication under the NOT, then keep pushing the
        // negation into the rewritten body.
        vector<Theorem> thms;
        thms.push_back(d_rules->rewriteImplies(e));
        res = pushNegationRec
          (d_commonRules->substitutivityRule(NOT, thms), true);
        break;
      }
//       case IFF:
// 	// Preserve equivalences to explore structural similarities
// 	if(e[0].getKind() == e[1].getKind())
// 	  res = d_commonRules->reflexivityRule(!e);
// 	else
// 	  res = pushNegationRec(d_commonRules->rewriteNotIff(!e), false);
//         break;
      case ITE:
        res = pushNegationRec(d_rules->rewriteNotIte(key), false);
        break;

      // Replace LETDECL with its definition.  The
      // typechecker makes sure it's type-safe to do so.
      case LETDECL: {
        vector<Theorem> thms;
        thms.push_back(d_rules->rewriteLetDecl(e));
        res = pushNegationRec
          (d_commonRules->substitutivityRule(NOT, thms), true);
        break;
      }
      default:
        res = d_commonRules->reflexivityRule(key);
    } // end of switch(e.getKind())
  } else { // if(!neg)
    switch(e.getKind()) {
      case NOT: res = pushNegationRec(e[0], true); break;
      case AND:
      case OR:
      case IFF:
      case ITE: {
        // Positive context: rewrite every child in place.
        Op op = e.getOp();
        vector<Theorem> thms;
        for(Expr::iterator k=e.begin(), kend=e.end(); k!=kend; ++k)
          thms.push_back(pushNegationRec(*k, false));
        res = d_commonRules->substitutivityRule(op, thms);
        break;
      }
      case IMPLIES:
        res = pushNegationRec(d_rules->rewriteImplies(e), false);
        break;
      case LETDECL:
        res = pushNegationRec(d_rules->rewriteLetDecl(e), false);
        break;
      default:
        res = d_commonRules->reflexivityRule(e);
    } // end of switch(e.getKind())
  }
  TRACE("pushNegation", "pushNegationRec => ", res, "}");
  d_pushNegCache[key] = res;
  return res;
}


// Continues a rewrite.  Given thm : a <=> b, pushes negations through b and
// returns a <=> b' by transitivity.  When neg is true, b must be a NOT, and
// its argument is processed in a negative context.
Theorem ExprTransform::pushNegationRec(const Theorem& thm, bool neg) {
  DebugAssert(thm.isRewrite(), "pushNegationRec(Theorem): bad theorem: "
	      + thm.toString());
  const Expr rhs(thm.getRHS());
  Theorem rest;
  if(!neg) {
    rest = pushNegationRec(rhs, false);
  } else {
    DebugAssert(rhs.isNot(), 
		"pushNegationRec(Theorem, neg = true): bad Theorem: "
		+ thm.toString());
    rest = pushNegationRec(rhs[0], true);
  }
  return d_commonRules->transitivityRule(thm, rest);
}


// Pushes the top-level negation of e (which must be a NOT) one step inward
// and returns the theorem e <=> e'.  Unlike pushNegationRec, this does not
// recurse into the result, except for LETDECL: after the definition is
// substituted, the resulting negation is pushed one step as well.
Theorem ExprTransform::pushNegation1(const Expr& e) {
  TRACE("pushNegation1", "pushNegation1(", e, ") {");
  DebugAssert(e.isNot(), "pushNegation1("+e.toString()+")");
  const Expr arg(e[0]);
  Theorem res;
  switch(arg.getKind()) {
  case TRUE:
    res = d_commonRules->rewriteNotTrue(e);
    break;
  case FALSE:
    res = d_commonRules->rewriteNotFalse(e);
    break;
  case NOT:
    res = d_commonRules->rewriteNotNot(e);
    break;
  case AND:
    res = d_rules->rewriteNotAnd(e);
    break;
  case OR:
    res = d_rules->rewriteNotOr(e);
    break;
  case ITE:
    res = d_rules->rewriteNotIte(e);
    break;
  case IMPLIES: {
    // Rewrite the implication under the NOT, then push the NOT through
    // the resulting OR.
    vector<Theorem> under(1, d_rules->rewriteImplies(arg));
    const Theorem lifted = d_commonRules->substitutivityRule(e.getOp(), under);
    res = d_commonRules->transitivityRule(lifted,
                                          d_rules->rewriteNotOr(lifted.getRHS()));
    break;
  }
  case LETDECL: {
    // Replace LETDECL with its definition (the typechecker makes sure it's
    // type-safe to do so), then handle the new negation.
    vector<Theorem> under(1, d_rules->rewriteLetDecl(arg));
    const Theorem lifted = d_commonRules->substitutivityRule(e.getOp(), under);
    res = d_commonRules->transitivityRule(lifted,
                                          pushNegation1(lifted.getRHS()));
    break;
  }
  default:
    res = d_commonRules->reflexivityRule(e);
  }
  TRACE("pushNegation1", "pushNegation1 => ", res.getExpr(), " }");
  return res;
}


// Care set: maps a condition expression (the if-part of an enclosing ITE)
// to the value it is known to take on the path leading to the current
// expression.  In ite_simplify the values stored are d_core->trueExpr()
// and d_core->falseExpr().
typedef Hash_Table<Expr, Expr> CareSet;
typedef Hash_Ptr<Expr, Expr> CareSetPtr;     // iterator over a CareSet
// Work queue for ite_simplify: each queued expression owns a heap-allocated
// care set.  Entries are ordered by the comparison function the queue is
// constructed with (cf in ite_simplify).
typedef Dict<Expr, CareSet*> Queue;
typedef Dict_Ptr<Expr, CareSet*> QueuePtr;   // iterator over a Queue

// Substitution / memo table: maps an expression e to a theorem e <=> e'.
typedef ExprMap<Theorem> Table;

// Three-way comparison of expressions by getIndex(); this is the ordering
// of the ite_simplify work queue.  Returns -1, 0 or 1.
static int cf(Expr x, Expr y) {
  if (x.getIndex() < y.getIndex())
    return -1;
  if (x.getIndex() == y.getIndex())
    return 0;
  return 1;
}

// Hash function for CareSet keys; delegates to the expression manager.
static size_t hf(const Expr x) {
  return x.getEM()->hash(x);
}

// Key-equality predicate for CareSet: nonzero iff x == y.
static size_t mf(const Expr x, const Expr y) {
  const bool same = (x == y);
  return same;
}

// Schedules expression e in queue q under care set *cs_prime.
// The first time e is queued, it receives its own heap-allocated copy of
// *cs_prime.  On later calls its existing care set is narrowed to the
// entries that *cs_prime also contains with the same value.  That is the
// intersection (AND) of the facts known along the different paths.
static void update_queue(Queue *q, Expr e, CareSet *cs_prime){
  CareSet **slot = q->Fetch(e);

  if (!slot) {
    q->Insert(e, new CareSet(*cs_prime));
    return;
  }

  for (CareSetPtr it(*slot); it != NULL; ) {
    Expr key = it->Key();
    const bool drop = !cs_prime->Fetch(key) ||
                      (*cs_prime)[key] != it->Data();
    // Step past the entry before (possibly) deleting it.
    ++it;
    if (drop)
      (*slot)->Delete(key);
  }
}

// Applies the substitution table built by ite_simplify to e and returns
// the theorem e <=> e'.  Only ITE structure is descended into (condition,
// then-part, else-part).  Any other expression is either replaced via the
// table or left unchanged.
//
// Call with init_st != NULL to start a new substitution (the "top-level"
// call).  The recursive calls pass no table and share the state installed
// by that top-level call.
Theorem ExprTransform::substitute(Expr e, Table *init_st) {
  // State shared by the recursive calls of one top-level invocation:
  // 'memoize' caches results already computed, and 'st' is the caller's
  // substitution table.  Both are NULL between top-level invocations.
  static Table *memoize = NULL;
  static Table *st = NULL;

  // Top-level entry: install a fresh memo table and the caller's table,
  // run the recursive walk, then release both.  Neither the memoized
  // theorems nor the pointer to the caller's (typically stack-allocated)
  // table may outlive this call.
  if (init_st) {
    delete memoize;
    memoize = new Table();
    st = init_st;
    try {
      Theorem res(substitute(e));
      delete memoize;
      memoize = NULL;
      st = NULL;
      return res;
    } catch (...) {
      delete memoize;
      memoize = NULL;
      st = NULL;
      throw;
    }
  }

  // atomic expressions don't change
  if(e.isNull() || e.isBoolConst()
     || e.isString() || e.isRational())
    return d_commonRules->reflexivityRule(e);

  DebugAssert(memoize != NULL && st != NULL,
              "substitute: called without a substitution table");

  // has the result been memoized
  ExprMap<Theorem>::iterator i = memoize->find(e);
  if(i != memoize->end()) { // Found cached result
    return (*i).second;
  }

  // do substitution: take the theorem stored by ite_simplify and keep
  // substituting inside its right-hand side
  i = st->find(e);
  if (i != st->end()) {
    return d_commonRules->transitivityRule((*i).second, substitute((*i).second.getRHS()));
  }

  // build new expr with all sub-exprs substituted
  Theorem thm;
  if(e.getKind() != ITE)
    thm = d_commonRules->reflexivityRule(e);
  else{
    vector<Theorem> thms;
    vector<unsigned> changed;
    Expr tmp;
    // condition
    thm = substitute(e[0]);
    if(thm.getRHS() == e[0]){
      thm = d_commonRules->reflexivityRule(e);
      tmp = e;
    }
    else{
      thms.push_back(thm);
      changed.push_back(0);
      thm = d_commonRules->substitutivityRule(e, changed, thms);
      tmp = thm.getRHS();
    }
    // then-part
    Theorem thm1 = d_rules->rewriteIteThen(tmp, substitute(tmp[1]));
    thm = d_commonRules->transitivityRule(thm, thm1);
    tmp = thm.getRHS();
    // else-part
    thm1 = d_rules->rewriteIteElse(tmp, substitute(tmp[2]));
    thm = d_commonRules->transitivityRule(thm, thm1);
  }

  (*memoize)[e] = thm;
  return thm;
}

// Simplifies nested ITEs in e using "care sets".  While walking down from
// the root, remember the truth value that enclosing ITE conditions must
// have: TRUE in a then-part, FALSE in an else-part.  Care sets from
// different paths are intersected by update_queue.  An ITE whose condition
// already has a value in its care set is replaced by the corresponding
// branch.  The replacements are collected in a substitution table and
// applied by substitute(), which returns the theorem e <=> e'.
// NOTE(review): the replacements are justified with assumpRule on the
// condition; confirm how those assumptions are discharged by callers.
Theorem ExprTransform::ite_simplify(Expr e){
  Queue q(cf);
  Table st;

  // go through all the exprs in reverse topological order
  // and build up the substitution table
  q.Insert(e, new CareSet(hf, mf));
  for (QueuePtr p(&q); p!=NULL; ++p) {
    // Named 'cur' so it does not shadow the parameter 'e', the root that
    // is handed to substitute() after the loop.
    Expr cur(p->Key());

    // Atoms have nothing to simplify.  Their care sets are released by the
    // clean-up loop below.
    if(cur.isNull() || cur.isBoolConst() || cur.isVar()
       || cur.isString() || cur.isRational() || cur.isApply())
      continue;

    CareSet cs(*p->Data());

    // free up memory so that we don't run out of the stuff
    delete p->Data();
    p->Data() = NULL;

    if (cur.getKind() != ITE)
      continue;

    // simplify an ite if its if-part belongs to the care set
    Expr *v = cs.Fetch(cur[0]);
    if (v) {
      vector<Theorem> thms;
      vector<unsigned> changed;
      if((*v).isTrue()){
        // cur[0] |- cur[0] ==> cur[0] IFF TRUE
        Theorem thm0 = d_commonRules->iffTrue(d_commonRules->assumpRule(cur[0]));
        thms.push_back(thm0);
        changed.push_back(0);
        thm0 = d_commonRules->substitutivityRule(cur, changed, thms);
        Theorem thm1 = d_rules->rewriteIteTrue(thm0.getRHS());
        thm0 = d_commonRules->transitivityRule(thm0, thm1);
        st[cur] = thm0;
        if(cur[1].getKind() != TRUE && cur[1].getKind() != FALSE)
          update_queue(&q, cur[1], &cs);
      }
      else { // *v is FALSE
        // cur[0] IFF FALSE |- cur[0] IFF FALSE
        Theorem thm0 = d_commonRules->assumpRule(cur[0].iffExpr(d_core->falseExpr()));
        thms.push_back(thm0);
        changed.push_back(0);
        thm0 = d_commonRules->substitutivityRule(cur, changed, thms);
        Theorem thm1 = d_rules->rewriteIteFalse(thm0.getRHS());
        thm0 = d_commonRules->transitivityRule(thm0, thm1);
        st[cur] = thm0;
        if(cur[2].getKind() != TRUE && cur[2].getKind() != FALSE)
          update_queue(&q, cur[2], &cs);
      }
      continue;
    }

    // add children to the queue, updating their care sets
    CareSet cs_prime(cs);
    if(cur[0].getKind() != TRUE && cur[0].getKind() != FALSE)
      update_queue(&q, cur[0], &cs);
    // add if-part==TRUE to the care-set for the then-part
    cs_prime.Insert(cur[0], d_core->trueExpr());
    if(cur[1].getKind() != TRUE && cur[1].getKind() != FALSE)
      update_queue(&q, cur[1], &cs_prime);
    // add if-part==FALSE to the care-set for the else-part
    cs_prime[cur[0]]=d_core->falseExpr();
    if(cur[2].getKind() != TRUE && cur[2].getKind() != FALSE)
      update_queue(&q, cur[2], &cs_prime);
  }

  // Release care sets still owned by the queue.  Entries skipped as atoms
  // above never had theirs freed; entries already processed hold NULL
  // (deleting NULL is a no-op).
  for (QueuePtr p(&q); p!=NULL; ++p) {
    delete p->Data();
    p->Data() = NULL;
  }

  // perform all the substitutions
  return substitute(e, &st);
}

// Converts the Boolean connectives NOT, AND, OR, IMPLIES and IFF into ITE
// form, bottom-up, and returns the theorem e <=> e'.  ITE nodes are kept
// but have their children converted.  Everything else (TRUE, FALSE,
// atoms, other kinds) is returned unchanged via reflexivity.
Theorem ExprTransform::ite_convert(const Expr& e){
  switch(e.getKind()){
  case ITE:
  case NOT:
  case AND:
  case OR:
  case IMPLIES:
  case IFF:
    break;
  default:
    return d_commonRules->reflexivityRule(e);
  }

  if (e.getKind() == AND) {
    DebugAssert(e.arity() > 0, "Expected non-empty AND");
  } else if (e.getKind() == OR) {
    DebugAssert(e.arity() > 0, "Expected non-empty OR");
  }

  // Convert every child, in order, and lift the results to e.
  vector<Theorem> kids;
  for(Expr::iterator k=e.begin(), kend=e.end(); k!=kend; ++k)
    kids.push_back(ite_convert(*k));
  const Theorem lifted = d_commonRules->substitutivityRule(e.getOp(), kids);

  // Then rewrite the connective itself into an ITE.
  switch(e.getKind()){
  case NOT:
    return d_commonRules->transitivityRule(lifted, d_rules->NotToIte(lifted.getRHS()));
  case AND:
    return d_commonRules->transitivityRule(lifted, d_rules->AndToIte(lifted.getRHS()));
  case OR:
    return d_commonRules->transitivityRule(lifted, d_rules->OrToIte(lifted.getRHS()));
  case IMPLIES:
    return d_commonRules->transitivityRule(lifted, d_rules->ImpToIte(lifted.getRHS()));
  case IFF:
    return d_commonRules->transitivityRule(lifted, d_rules->IffToIte(lifted.getRHS()));
  default: // ITE: already in the target form
    return lifted;
  }
}

// Randomly reorders (nested) ITE expressions into equivalent ITE
// expressions.  Works purely on Expr; no proof is produced.
// At each ITE node, roughly half the time the node is rebuilt unchanged in
// shape from its recursively reordered children.  Otherwise it is
// rewritten by one of the identities noted on the branches below.
// NOTE(review): random() is POSIX, not ISO C++.  The order in which the
// several ite_reorder calls inside one expression are evaluated, and so
// which random numbers each subtree consumes, is unspecified before C++17.
Expr ExprTransform::ite_reorder(const Expr& e){
  switch(e.getKind()){
  case TRUE:
  case FALSE:
    return e; break;
  case ITE:
    {
      // POSIX random() returns values in [0, 2^31 - 1], so the test below
      // succeeds about half the time.  2 ^ 30 = 1073741824
      int x = random();
      if(x < 1073741824)
	return ite_reorder(e[0]).iteExpr(ite_reorder(e[1]), ite_reorder(e[2]));
      // ite(TRUE, a, b) = a
      if(e[0].isTrue()){
	return ite_reorder(e[1]);
      }
      // ite(FALSE, a, b) = b
      else if(e[0].isFalse()){
	return ite_reorder(e[2]);
      }
      // ite(c, TRUE, b) = ite(b, TRUE, c)        (both are c OR b)
      else if(e[1].isTrue())
	return ite_reorder(e[2]).iteExpr(e[1], ite_reorder(e[0]));
      // ite(c, FALSE, TRUE) is NOT c: keep the shape, reorder c
      else if(e[1].isFalse() && e[2].isTrue())
	return ite_reorder(e[0]).iteExpr(e[1], e[2]);
      // ite(c, FALSE, b) = ite(!b, FALSE, !c)    (both are !c AND b)
      else if(e[1].isFalse()){
	Expr e0 = ite_reorder(e[2]);
	if(e0.isTrue())
	  return getNeg(ite_reorder(e[0]));
	else if(e0.isFalse())
	  return e[1];
	else
	  return getNeg(e0).iteExpr(e[1], getNeg(ite_reorder(e[0])));
      }
      // ite(c, a, TRUE) = ite(!a, !c, TRUE)      (both are !c OR a)
      else if(e[2].isTrue()){
	Expr e0 = ite_reorder(e[1]);
	if(e0.isTrue())
	  return e[2];
	else if(e0.isFalse())
	  return getNeg(ite_reorder(e[0])) ;
	else
	  return getNeg(e0).iteExpr(getNeg(ite_reorder(e[0])), e[2]);
      }
      // ite(c, a, FALSE) = ite(a, c, FALSE)      (both are c AND a)
      else if(e[2].isFalse()){
	Expr e0 = ite_reorder(e[1]);
	if(e0.isTrue())
	  return ite_reorder(e[0]);
	else if(e0.isFalse())
	  return e[2];
	else
	  return e0.iteExpr(ite_reorder(e[0]), e[2]);
      }
      // general case: ite(c, a, b) = ite(!c, b, a)
      else{
	Expr e0 = getNeg(ite_reorder(e[0]));
	if(e0.isTrue())
	  return ite_reorder(e[2]);
	else if(e0.isFalse())
	  return ite_reorder(e[1]);
	else
	  return e0.iteExpr(ite_reorder(e[2]), ite_reorder(e[1]));
      }
    }
    break;
  default:
    return e;
  } 
}

// Returns an expression for the negation of e, expressed with ITEs.
// Constants are flipped.  ite(c, FALSE, TRUE) (i.e. NOT c) unwraps to c.
// Anything else becomes ite(e, FALSE, TRUE).
Expr ExprTransform::getNeg(const Expr& e){
  if (e.getKind() == TRUE)
    return d_core->falseExpr();
  if (e.getKind() == FALSE)
    return d_core->trueExpr();

  const bool alreadyNegation =
    e.getKind() == ITE && e[1].isFalse() && e[2].isTrue();
  if (alreadyNegation)
    return e[0];

  return e.iteExpr(d_core->falseExpr(), d_core->trueExpr());
}


