#include "parser.h"
#include "AST_utils.h"
#include "expr.h"
#include "builtins.h"
#include "c-parse.h"
#include "constants.h"
#include "analyze.h"
#include "cqual.h"
#include "hash.h"
#include "qerror.h"
#include "qtype.h"
#include "utils.h"
#include "pam.h"
#include "effect.h"

/* Define DEBUG to trace scope discovery on stdout: equality checks against
   pending expressions, pushes/pops of pending expressions and inserted
   scopes. */
/* #define DEBUG */
/* Controls whether change_type expressions discovered inside an if
   statement propagate to the enclosing block; see the kind_if_stmt case of
   discover_scopes_statement. */
#define CONDITIONAL_HAS_NO_EFFECT

/* A change_type expression still waiting for its matching partner, together
   with the position (0-based index in the enclosing compound statement's
   statement chain) of the statement that reported it. */
typedef struct Expr_and_pos {
  expression e;
  int pos;
} *expr_and_pos;


/* Forward declarations: the declaration, statement and expression walkers
   below are mutually recursive. */
static void discover_scopes_declaration(declaration d);
static expression discover_scopes_statement(statement s);
static expression discover_scopes_expression(expression e);
static expression discover_scopes_unary_expression(unary e);
static expression discover_scopes_binary_expression(binary e);

/* Some handy functions defined elsewhere (source noted where known).
   Several of them appear to be unused in this file. */
void unparse_start(FILE *);                  /* unparse.c */
void prt_variable_decl(variable_decl);
void prt_expression(expression, int);
const char *binary_op_name(ast_kind);
bool is_void_parms(declaration);             /* semantics.c */
bool equal_expressions(expression e1, expression e2); /* eq_expressions.c */
bool isassignment(binary e); /* analyze.c */
compound_stmt new_compound_stmt(region r, location loc, id_label id_labels, declaration decls, statement stmts, environment env); /* AST_defs.h */

/* Region holding the pending-expression lists built while scanning compound
   statements (the dd_list nodes and their struct Expr_and_pos entries).
   discover_scopes creates it on entry and deletes it before returning, so it
   is only valid during that call. */
static region pending_region;

/* Entry point of the pass.

   Prepares unparsing to stdout and the AST parent links, then walks every
   top-level declaration of PROGRAM, inserting explicit scopes (new compound
   statements) around the regions delimited by matching change_type
   statements.  The pending-expression bookkeeping lives in pending_region,
   which exists only for the duration of this call. */
void discover_scopes(declaration program)
{
  declaration top;

  unparse_start(stdout);
  AST_set_parents(CAST(node, program));

  pending_region = newregion();
  scan_declaration(top, program)
    {
      discover_scopes_declaration(top);
    }
  deleteregion(pending_region);
}

/* Only function definitions carry statements: descend into the body of a
   function_decl (which must be a compound statement) and ignore every other
   declaration kind.  The body's pending change_type, if any, is dropped. */
static void discover_scopes_declaration(declaration d)
{
  function_decl fn;

  if (d->kind != kind_function_decl)
    return;

  fn = CAST(function_decl, d);
  assert(is_compound_stmt(fn->stmt));
  (void) discover_scopes_statement(fn->stmt);
}

/* Fold the statements that follow a case/default label into that label's
   arm.  LAST_CASE must be the labeled statement at position POS1 - 1 of the
   chain STMTS (in this file, the statement chain of a switch body), so POS1
   must not be 0.

   The statements at positions POS1..POS2 (inclusive, 0-based) are chained
   after LAST_CASE's own statement and the whole run is wrapped in a new
   compound statement (allocated in parse_region with environment ENV and
   the location of LAST_CASE's original statement).  That is, the block

     {last_case->stmt; stmts[pos1]; ...; stmts[pos2];}

   becomes LAST_CASE->stmt, and LAST_CASE->next is relinked to the statement
   that followed stmts[pos2].

   Returns TRUE if a scope was inserted (POS1 <= POS2); otherwise FALSE,
   leaving the chain untouched. */
static bool insert_new_scope_labeled_stmt(statement stmts,statement last_case,
					  int pos1, int pos2, environment env)
{
  int pos = 0;
  statement stmt = NULL, new_block_end = NULL, new_block_snd = NULL;
  statement succ = NULL,prev = NULL,pred = NULL, new_compound = NULL;
  labeled_stmt ls = CAST(labeled_stmt,last_case);

  assert(pos1 != 0);
  assert(is_labeled_stmt(last_case));

  /* Nothing between this label and the next one (case e_1: s1; case ...),
     so there is no new scope to add. */
  if (pos1 > pos2)
    return FALSE;

  scan_statement(stmt, stmts) {
    
    if (pos == pos1) 
      {
	assert(prev);
	assert(is_labeled_stmt(prev));
	assert(prev == last_case);
	/* Cut the chain after the label; ls->next is relinked below. */
	pred = prev;
	pred->next = NULL;
	new_block_snd = stmt;
      }

    /* Deliberately not "else if": when pos1 == pos2 both branches apply
       to the same statement. */
    if (pos == pos2) 
      {
	new_block_end = stmt;
	/* Remember what follows the run, then detach the run from it. */
	succ = (statement)new_block_end->next;
	new_block_end->next = NULL;
	break;
      }

    prev = stmt;
    pos++;
  }
  
  /* Chain the run after the label's own statement and wrap it all in a new
     compound statement. */
  ls->stmt->next = (node)new_block_snd;
  new_compound = (statement) new_compound_stmt(parse_region,
					       ls->stmt->loc,
					       NULL,NULL,
					       ls->stmt,env);
  /* The compound becomes the label's statement; reconnect the label to the
     rest of the chain. */
  ls->stmt = new_compound;
  ls->next = (node)succ;

  return TRUE;
}

/* Wrap the statements at positions POS1..POS2 (inclusive, 0-based, with
   POS1 < POS2) of the chain STMTS in a new compound statement, allocated in
   parse_region with environment ENV and the location of stmts[POS1].  The
   compound statement takes the run's place in the chain.

   Returns the head of the resulting chain: STMTS itself, or the new compound
   statement when POS1 == 0. */
static statement insert_new_scope(statement stmts, int pos1, int pos2,
				  environment env)
{
  int pos = 0;
  statement stmt, prev = NULL, pred = NULL, succ = NULL;
  statement new_block = NULL, new_block_end = NULL, new_compound;

  assert (pos1 < pos2);

  /*  fprintf(stderr,"New scopes surrounding %d,%d\n",pos1, pos2); */

  scan_statement(stmt,stmts) {

    if (pos == pos1) {
      /* Detach the run from its predecessor (there is none when
	 pos1 == 0). */
      if (prev)
	{
	  pred = prev;
	  pred->next = NULL;
	}
      new_block = stmt;
    }
    
    else if (pos == pos2) {
      /* Remember what follows the run, then detach the run from it. */
      new_block_end = stmt;
      succ = (statement)new_block_end->next;
      new_block_end->next = NULL;
      break;
    }
    prev = stmt;
    pos++;
  }
  
  new_compound = (statement) new_compound_stmt(parse_region,new_block->loc,
					       NULL, NULL,
					       new_block,env);


  /* Splice the compound statement in where the run used to be. */
  if (pred) 
    pred->next = (node) new_compound;
  else
    stmts = new_compound;

  new_compound->next = (node) succ;

  return stmts;
}

/* Give every case/default arm of a switch body its own scope, i.e. rewrite

     case e_1: stmt_1; ...; stmt_n; case e_2: ...

   into

     case e_1: { stmt_1; ...; stmt_n; } case e_2: ...

   S must be the switch's compound body.  Once every arm is a single
   (compound) statement, each labeled statement's body is analyzed with
   discover_scopes_statement.

   Returns the last non-NULL change_type expression reported by an arm, or
   NULL if no arm reported one.  When several arms report one, a warning is
   printed for each that is dropped. */
static expression discover_scopes_switch_statement(statement s)
{
  compound_stmt cs = CAST(compound_stmt, s);
  statement stmt, last_case;
  int pos, last_case_pos;
  bool changed;
  expression result = NULL;

  /* Repeat until no arm needs wrapping.  Every successful insertion moves
     at least one statement out of cs->stmts, so this terminates. */
  do
    {
      changed = FALSE;
      last_case_pos = -1;
      last_case = NULL;
      pos = 0;

      scan_statement(stmt, cs->stmts)
        {
          if (is_labeled_stmt(stmt)
              && (is_case_label(CAST(labeled_stmt, stmt)->label)
                  || is_default_label(CAST(labeled_stmt, stmt)->label)))
            {
              /* Fold the statements between the previous label and this
                 one into the previous label's arm.  The chain was modified,
                 so restart the scan. */
              if (last_case_pos >= 0
                  && insert_new_scope_labeled_stmt(cs->stmts, last_case,
                                                   last_case_pos + 1, pos - 1,
                                                   cs->env))
                {
                  changed = TRUE;
                  break;
                }
              last_case = stmt;
              last_case_pos = pos;
            }
          pos++;
        }

      /* Fold the statements after the last label into its arm.  The test
         is ">= 0", not "> 0": a switch whose only label is its first
         statement must be handled too (pos1 = last_case_pos + 1 is never
         0, as insert_new_scope_labeled_stmt requires). */
      if (!changed && last_case_pos >= 0)
        changed = insert_new_scope_labeled_stmt(cs->stmts, last_case,
                                                last_case_pos + 1, pos - 1,
                                                cs->env);
    }
  while (changed);

  scan_statement(stmt, cs->stmts)
    {
      /* After wrapping, only statements before the first label (if any)
         are unlabeled; they are skipped. */
      if (is_labeled_stmt(stmt))
        {
          expression arm =
            discover_scopes_statement(CAST(labeled_stmt, stmt)->stmt);

          /* Keep the latest non-NULL result, so an arm without a
             change_type does not erase an earlier arm's. */
          if (arm)
            {
              if (result)
                fprintf(stderr, "Warning: discarding change_type in switch statement\n");
              result = arm;
            }
        }
    }

  return result;
}


/* Look through PENDING, in list order, for an entry whose expression is
   equal to E according to equal_expressions.  Returns the first matching
   entry, or NULL if there is none. */
static expr_and_pos find_in_pending(dd_list pending, expression e)
{
  dd_list_pos cursor;

  dd_scan(cursor, pending)
    {
      expr_and_pos candidate = DD_GET(expr_and_pos, cursor);

#ifdef DEBUG
      printf("Checking equality: ");
      prt_expression(e, 0);
      printf(", ");
      prt_expression(candidate->e, 0);
      printf("\n");
#endif /* DEBUG */

      if (equal_expressions(e, candidate->e))
        return candidate;
    }

  return NULL;
}

static expression discover_scopes_statement(statement s)
{
  switch (s->kind)
    {
    case kind_asm_stmt:
      return NULL;
      break;
    case kind_compound_stmt:
      {
	bool found_pair = FALSE;
	compound_stmt cs = CAST(compound_stmt, s);
	statement stmt;
	int pos = 0,pos1 = -1, pos2 = -1;
	int num_stmts = chain_length((node) cs->stmts);
	/* dd_list pending = dd_new_list(pending_region); */
	dd_list pending = dd_new_list(pending_region);

	scan_statement(stmt, cs->stmts)
	  {
	    expression e = discover_scopes_statement(stmt);

	    if (e)
	      {
		expr_and_pos ep =  find_in_pending(pending,e);

		if (ep)
		  pos1 = ep->pos;
		    
		if (pos1 >= 0) /* remove e from the stack, add scopes */
		  {
		    assert(ep);
#ifdef DEBUG
		    printf("Popping: ");
		    prt_expression(e,0);
		    printf("\n");
#endif /* DEBUG */
		    dd_remove(dd_find(pending,ep));
		    pos2 = pos;
		    found_pair = TRUE;
		    break;
		  }
		else /* push pos onto the stack */
		  { 
		    assert(ep == NULL);
#ifdef DEBUG
		    printf("Pushing: ");
		    prt_expression(e,0);
		    printf(",%d\n",pos);
#endif /* DEBUG */
		    ep = ralloc(pending_region,struct Expr_and_pos);
		    ep->pos = pos;
		    ep->e = e;
		    dd_add_last(pending_region,pending,ep);

		  }
	      }
	    pos++;
	  }
	if (found_pair) 
	  {
	    assert(pos1 >= 0 && pos2 >= 0);
	    if (pos1 == 0 && pos2 == (num_stmts -1))
	      goto END;
	    else 
	      {
#ifdef DEBUG
		printf("Adding new scope at %d,%d\n",pos1,pos2);
#endif
		cs->stmts = insert_new_scope(cs->stmts,pos1,pos2,cs->env);
		discover_scopes_statement(s); /* iterate until no pairs found */
	      }
	  }

	END:
	{
	  if (dd_is_empty(pending))
	    return NULL;
	  else 
	    return DD_GET(expr_and_pos, dd_last(pending))->e;	
	}
      }
    
    case kind_if_stmt:
      {
	expression result = NULL;
	expression tmp = NULL;
	if_stmt is = CAST(if_stmt, s);

#ifdef CONDITIONAL_HAS_NO_EFFECT       
	discover_scopes_statement(is->stmt1);
	discover_scopes_expression(is->condition);
#else
	result = discover_scopes_statement(is->stmt1);

	tmp = discover_scopes_expression(is->condition);
#endif	

	if (tmp && result)
	  {
	    fprintf(stderr,"Warning: discarding change_type in if statement\n");
	    result = tmp;
	  }
	else if (tmp)
	  result = tmp;

	if (is->stmt2)
	  {
	    tmp = discover_scopes_statement(is->stmt2);
	    if (tmp && result)
	      {
		fprintf(stderr,"Warning: discarding change_type in if statement\n");
		result = tmp;
	      }
	    else if (tmp)
	      result = tmp;
	  }

	return result;
      }
    case kind_labeled_stmt:
      {
	labeled_stmt ls = CAST(labeled_stmt, s);

	return discover_scopes_statement(ls->stmt);
      }
    case kind_expression_stmt:
      {
	expression_stmt es = CAST(expression_stmt, s);

	return discover_scopes_expression(es->arg1);
      }
    case kind_while_stmt:
      {
	expression tmp = NULL;
	expression result = NULL;
	while_stmt ws = CAST(while_stmt, s);
	
	result = discover_scopes_expression(ws->condition);
	tmp = discover_scopes_statement(ws->stmt);
	
	if (tmp && result)
	  {
	    fprintf(stderr,"Warning: discarding change_type in do while statement\n");
	    result = tmp;
	  }
	else if (tmp)
	  result = tmp;

	return result;
      }
    case kind_dowhile_stmt:
      {
	expression tmp = NULL;
	expression result = NULL;
	dowhile_stmt dws = CAST(dowhile_stmt, s);
	
	result = discover_scopes_statement(dws->stmt);

	tmp = discover_scopes_expression(dws->condition);

	if (tmp && result)
	  {
	    fprintf(stderr,"Warning: discarding change_type in do while statement\n");
	    result = tmp;
	  }
	else if (tmp)
	  result = tmp;

	return result;
      }
    case kind_switch_stmt: 
      {
	expression tmp = FALSE;
	expression result = FALSE;
	switch_stmt ss = CAST(switch_stmt, s);
	
	result = discover_scopes_expression(ss->condition);
	
	assert(is_compound_stmt(ss->stmt));

	tmp = discover_scopes_switch_statement(ss->stmt);

	if (tmp && result)
	  {
	    fprintf(stderr, "Warning: discarding change_type in switch statement\n");
	    result = tmp;
	  }
	else if (tmp)
	  result = tmp;
	
	return result;
      }
    case kind_for_stmt:
      {
	expression tmp = NULL;
	expression result = NULL;
	for_stmt fs = CAST(for_stmt, s);
	
	result = discover_scopes_statement(fs->stmt);

	if (fs->arg1)
	  {
	    tmp = discover_scopes_expression(fs->arg1);

	    if (tmp && result)
	      {
		fprintf(stderr, "Warning: discarding change_type in for statement\n");
		result = tmp;
	      }
	    else if (tmp)
	      result = tmp;
	  }


	if (fs->arg2)
	  {
	    tmp = discover_scopes_expression(fs->arg2);
	    if (tmp && result)
	      {
		fprintf(stderr, "Warning: discarding change_type in for statement\n");
		result = tmp;
	      }
	    else if (tmp)
	      result = tmp;
	  }

	if (fs->arg3)
	  {
	    tmp = discover_scopes_expression(fs->arg3);
	    if (tmp && result)
	      {
		fprintf(stderr, "Warning: discarding change_type in for statement\n");
		result = tmp;
	      }
	    else if (tmp)
	      result = tmp;
	  }
	
	return result;
      }
    case kind_return_stmt:
      {
	return_stmt rs = CAST(return_stmt, s);

	if (rs->arg1)
	    return discover_scopes_expression(rs->arg1);

	return NULL;
      }
    case kind_computed_goto_stmt:
      {
	computed_goto_stmt cgs = CAST(computed_goto_stmt, s);

	return discover_scopes_expression(cgs->arg1);
      }
    case kind_break_stmt:
      {
	return NULL;
      }
    case kind_continue_stmt:
      {
	return NULL;
      }
    case kind_goto_stmt:
      {
	return NULL;
      }
    case kind_empty_stmt:
      {
	return NULL;
      }
    case kind_change_type_stmt:
      {
	change_type_stmt cts = CAST(change_type_stmt, s);
	/* fprintf(stderr,"Found change type\n"); */
	return cts->arg1;
      }
    case kind_assert_type_stmt:
      {
	return NULL;
      }
    case kind_deep_restrict_stmt:
      {
	return NULL;
      }
    default:
      fail_loc(s->loc, "Unexpected statement kind 0x%x\n", s->kind);
      return NULL;
    }
}

/* Discover scopes in expression E.

   This function never originates a change_type expression itself.  A
   non-NULL result comes only from the statements of a statement expression
   (compound_expr), forwarded upward through the enclosing operators.
   Returns the change_type expression E leaves pending, or NULL. */
static expression discover_scopes_expression(expression e)
{
  switch (e->kind)
    {
    /* Leaves, and forms whose operands are not analyzed. */
    case kind_sizeof_type:
    case kind_alignof_type:
    case kind_label_address:
    case kind_cast_list:
    case kind_identifier:
    case kind_lexical_cst:
    case kind_string:
      return NULL;

    case kind_cast:
      return discover_scopes_expression(CAST(cast, e)->arg1);

    case kind_field_ref:
      return discover_scopes_expression(CAST(field_ref, e)->arg1);

    case kind_comma:
      {
        /* Later operands win over earlier ones. */
        expression found = NULL;
        expression operand;

        scan_expression(operand, CAST(comma, e)->arg1)
          {
            expression here = discover_scopes_expression(operand);

            if (here != NULL)
              {
                if (found != NULL)
                  fprintf(stderr, "Warning: discarding change_type in comma expression\n");
                found = here;
              }
          }
        return found;
      }

    case kind_conditional:
      {
        conditional cond = CAST(conditional, e);
        expression found = discover_scopes_expression(cond->condition);
        expression here;

        /* arg1 may be NULL (presumably the GNU "a ?: b" form). */
        if (cond->arg1)
          {
            here = discover_scopes_expression(cond->arg1);
            if (here != NULL)
              {
                if (found != NULL)
                  fprintf(stderr,"Warning: discarding change_type in conditional expression\n");
                found = here;
              }
          }

        here = discover_scopes_expression(cond->arg2);
        if (here != NULL)
          {
            if (found != NULL)
              fprintf(stderr,"Warning: discarding change_type in conditional expression\n");
            found = here;
          }
        return found;
      }

    case kind_compound_expr:
      {
        compound_expr ce = CAST(compound_expr, e);
        compound_stmt body = CAST(compound_stmt, ce->stmt);
        expression found = NULL;
        declaration decl;
        statement st;

        if (body->id_labels)
          fail_loc(body->loc, "Unimplemented: id_labels\n", 0);

        scan_declaration(decl, body->decls)
          {
            assert(decl->kind != kind_asm_decl); /* asm_decl only at toplevel */
            discover_scopes_declaration(decl);
          }

        /* The final (value-producing) statement needs no special case.  For
           an expression statement, discover_scopes_statement just analyzes
           its argument expression.  Unlike the other combinators here, the
           FIRST change_type found is kept; later ones are only reported. */
        for (st = body->stmts; st != NULL; st = CAST(statement, st->next))
          {
            expression here = discover_scopes_statement(st);

            if (here != NULL)
              {
                if (found != NULL)
                  fprintf(stderr,"Warning: discarding change_type in compound expression\n");
                else
                  found = here;
              }
          }
        return found;
      }

    case kind_function_call:
      {
        function_call call = CAST(function_call, e);
        expression found;
        expression arg;

        if (call->va_arg_call)
          return NULL;

        /* Only the callee expression's change_type propagates; one found in
           an argument is reported and dropped. */
        found = discover_scopes_expression(call->arg1);
        scan_expression(arg, call->args)
          {
            if (discover_scopes_expression(arg))
              fprintf(stderr,"Warning: discarding change_type in function arguments\n");
          }
        return found;
      }

    case kind_array_ref:
      {
        array_ref ref = CAST(array_ref, e);
        expression found = discover_scopes_expression(ref->arg1);

        if (discover_scopes_expression(ref->arg2))
          fprintf(stderr,"Warning: discarding change_type in array ref\n");
        return found;
      }

    case kind_init_list:
      fail_loc(e->loc, "Unexpected init list\n", 0);
      return NULL;

    case kind_init_index:
      fail_loc(e->loc, "Unexpected init index\n", 0);
      return NULL;

    case kind_init_field:
      fail_loc(e->loc, "Unexpected init field\n", 0);
      return NULL;

    default:
      if (is_unary(e))
        return discover_scopes_unary_expression(CAST(unary, e));
      if (is_binary(e))
        return discover_scopes_binary_expression(CAST(binary, e));
      break;
    }

  /* Any other expression kind contributes nothing. */
  return NULL;
}


/* Unary operators.  The operands of sizeof and alignof are not analyzed.
   Every other supported operator forwards whatever its operand reports.
   An unknown unary kind is a fatal error. */
static expression discover_scopes_unary_expression(unary e)
{
  switch (e->kind)
    {
    case kind_sizeof_expr:
    case kind_alignof_expr:
      return NULL;

    case kind_dereference:
    case kind_address_of:
    case kind_extension_expr:
    case kind_realpart:
    case kind_imagpart:
    case kind_unary_minus:
    case kind_unary_plus:
    case kind_conjugate:
    case kind_bitnot:
    case kind_not:
    case kind_preincrement:
    case kind_postincrement:
    case kind_predecrement:
    case kind_postdecrement:
      return discover_scopes_expression(e->arg1);

    default:
      fail_loc(e->loc, "Unexpected unary op kind 0x%x\n", e->kind);
      return NULL;
    }
}

/* Binary operators.  Both operands are analyzed, left then right, but only
   the left operand's change_type expression propagates.  One found in the
   right operand is reported on stderr and dropped. */
static expression discover_scopes_binary_expression(binary e)
{
  expression left = discover_scopes_expression(e->arg1);

  if (discover_scopes_expression(e->arg2) != NULL)
    fprintf(stderr,"Warning: discarding change_type in binary expression\n");

  return left;
}
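/* Obsolete earlier version of the kind_compound_stmt case of
   discover_scopes_statement, kept for reference only (not compiled).  It
   returned a bool ("the block contains a change_type") rather than the
   pending change_type expression, and it paired any two change_type
   statements without comparing their arguments.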
case kind_compound_stmt:
      {
	compound_stmt cs = CAST(compound_stmt, s);
	statement stmt;
	int pos = 0,pos1 = -1, pos2 = -1;
	bool has_change_type = FALSE;
	int num_stmts = chain_length((node) cs->stmts);

	scan_statement(stmt, cs->stmts)
	  {
	    bool b = discover_scopes_statement(stmt);
	    has_change_type |= b;
	    
	    if (b && (pos1 >= 0))
	      {
		pos2 = pos;
	      }
	    else if (b) {
	      pos1 = pos;
	    }
	    if (pos1 >= 0 && pos2 >= 0)
	      break;
	    pos++;
	  }
	if (has_change_type) {
	  if (pos1 == 0 && pos2 == (num_stmts -1))
	    return TRUE;
	  else if (  (pos1 >= 0) && (pos2 >= 0) ) {
	    cs->stmts = insert_new_scope(cs->stmts,pos1,pos2,cs->env);
	    discover_scopes_statement(s);
	  }
	  
	}
	return has_change_type;
      }


*/
