#include "combine.h"

#include "block.h"
#include "gradient.h"
#include "par.h"
#include "parse.h"
#include "ref.h"
#include "simple.h"

/* Upper bound on the number of "input <n> <par file>" lines (n is 1-based). */
#define  MAX_NINPUT_FILES  4

/* Data geometry.  ndim, npoints_in/_out and (for blocked data)
   block_size_in/_out are copied from the first input par file in
   input_parse(); later inputs must agree with it.  npoints_out[dim] is
   passed by pointer to init_gradient(), which presumably may change it
   (TODO confirm).  For unblocked data determine_params() overwrites the
   block sizes. */
static int ndim;
static int block_size_in[MAX_NDIM];
static int block_size_out[MAX_NDIM];
static int npoints_in[MAX_NDIM];
static int npoints_out[MAX_NDIM];

/* Reference info from the last input par file parsed; copied into the
   output par file by main(). */
static Ref_info *ref;

/* One row buffer per input file plus one for the output.  The lengths are
   set in determine_params() (npoints along dim, doubled for COMPLEX_DATA)
   and the buffers are allocated in allocate_memory(). */
static int nrow_in;
static int nrow_out;
static float *row_in[MAX_NINPUT_FILES];
static float *row_out;

/* One store buffer per input file plus one for the output, sized in
   determine_params() and allocated in allocate_memory(). */
static int nstore_in;
static int nstore_out;
static float *store_in[MAX_NINPUT_FILES];
static float *store_out;

/* Combining function selected by the combine file.  dim is 0-based once a
   combining-function line has been parsed (-1 before that).  type is
   REAL_DATA for add/subtract and is filled in by init_gradient() for
   gradient. */
static int dim = -1;
static int type;
static Combine_func combine_func;
static Line combine_name;

/* Input/output data file names and handles, plus file-format settings
   copied from the first input par file.  swapped, integer, blocked,
   header and deflated must match across all inputs; level is taken from
   the first input only. */
static int ninput_files;
static Line input_file[MAX_NINPUT_FILES];
static Line output_file;
static FILE *file_in[MAX_NINPUT_FILES];
static FILE *file_out;
static Bool swapped;
static Bool integer;
static Bool blocked;
static int header;
static Bool deflated;
static float level;

/* Parse-state flags, reset at the start of read_combine_file(). */
static Bool input_found[MAX_NINPUT_FILES];
static Bool output_found;
static Bool par_found;
static Bool combine_found;
/* Output par file name from the optional "par" line. */
static Line output_par_file;
/* Name of the combine file being parsed (argv[1], set in main()). */
static char *combine_file;

/* Argument-type lists used by the combine_table entries. */
static int parse_int[] = { PARSE_INT };
static int parse_string[] = { PARSE_STRING };
static int parse_int_string[] = { PARSE_INT, PARSE_STRING };

/* Error-reporting helpers for the combine-file parse callbacks.  Each
   formats a message into the caller's `error_msg` (which must be in scope)
   and makes the enclosing function return ERROR.  They are wrapped in
   do { ... } while (0) so that "if (x) MACRO(...); else ..." parses as
   intended, and every macro argument is parenthesized. */

/* `string` appeared a second time in the combine file. */
#define  FOUND_TWICE(string) \
	 do {   sprintf(error_msg, "in \"%s\" '%s' found twice", \
				combine_file, (string));  return  ERROR;   } while (0)

/* `string1` appeared before the line `string2` it must follow. */
#define  FOUND_BEFORE(string1, string2) \
	 do {   sprintf(error_msg, "in \"%s\" '%s' found before '%s'", \
			combine_file, (string1), (string2));  return  ERROR;   } while (0)

/* A required line `string` never appeared. */
#define  NOT_FOUND(string) \
	 do {   sprintf(error_msg, "in \"%s\" no '%s' found", \
				combine_file, (string));  return  ERROR;   } while (0)

/* Fail unless scalars n1 and n2 are equal; `string` names the setting. */
#define  CHECK_EQUAL(n1, n2, string) \
	 do {   if ((n1) != (n2)) \
	     {   sprintf(error_msg, "in \"%s\": '%s' has inconsistent value", \
				combine_file, (string));  return  ERROR;   }   } while (0)

/* Fail unless the first n entries of v1 and v2 are equal; the message
   reports the first mismatching entry as 1-based. */
#define  CHECK_EQUAL_VECTORS(n, v1, v2, string) \
	 do {   int check_i_; \
	     for (check_i_ = 0; check_i_ < (n); check_i_++) \
	     {   if ((v1)[check_i_] != (v2)[check_i_]) \
	         {   sprintf(error_msg, \
			"in \"%s\": '%s[%d]' has inconsistent value", \
			combine_file, (string), check_i_+1);  return  ERROR;   }   }   } while (0)

/* Allocate one store buffer and one row buffer per input file, then one of
   each for the output, using the sizes computed by determine_params().
   error_msg is primed before each MALLOC so that it names the buffer if
   that allocation fails (MALLOC presumably returns ERROR then - confirm).
   Returns OK when every buffer has been allocated. */
static Status allocate_memory(String error_msg)
{
    int f = 0;

    while (f < ninput_files)
    {
	sprintf(error_msg, "allocating memory for store_in[%d]", f);
	MALLOC(store_in[f], float, nstore_in);

	sprintf(error_msg, "allocating memory for row_in[%d]", f);
	MALLOC(row_in[f], float, nrow_in);

	f++;
    }

    /* single output buffers */
    sprintf(error_msg, "allocating memory for store_out");
    MALLOC(store_out, float, nstore_out);

    sprintf(error_msg, "allocating memory for row_out");
    MALLOC(row_out, float, nrow_out);

    return  OK;
}

/* Work out the block sizes (for unblocked data), the store sizes and the
   row lengths used later by block_process().  Called after the combine
   file has been read, because it reads dim, type and npoints_out. */
static void determine_params()
{
    int d;

    if (!blocked)
    {
	/* bit of a fudge, but it ought to work: treat all of dimension 1 as
	   one block and every other dimension as blocks of size 1, so the
	   blocked processing path can be used for unblocked data too */
	block_size_in[0] = npoints_in[0];
	block_size_out[0] = npoints_out[0];

	for (d = 1; d < ndim; d++)
	{
	    block_size_out[d] = 1;
	    block_size_in[d] = 1;
	}
    }

    /* store size: VECTOR_PRODUCT presumably gives the points per block,
       and BLOCK the number of blocks along dim (TODO confirm) */
    VECTOR_PRODUCT(nstore_in, block_size_in, ndim);
    nstore_in *= BLOCK(npoints_in[dim], block_size_in[dim]);

    VECTOR_PRODUCT(nstore_out, block_size_out, ndim);
    nstore_out *= BLOCK(npoints_out[dim], block_size_out[dim]);

    /* a row along dim holds two floats per point for complex data */
    if (type == COMPLEX_DATA)
    {
	nrow_in = 2 * npoints_in[dim];
	nrow_out = 2 * npoints_out[dim];
    }
    else
    {
	nrow_in = npoints_in[dim];
	nrow_out = npoints_out[dim];
    }
}

/* Handle "input <n> <par file>": read the par file for the n-th input and
   record its data file name.  n is 1-based and must lie in
   1..MAX_NINPUT_FILES.  Inputs must come before "output" and in the order
   1, 2, ...  The first input defines the geometry and file format; every
   later input must match it.  Deflated input is rejected.

   var[0]: pointer to the int n;  var[1]: the par file name.
   Returns OK, or ERROR with error_msg set. */
static Status input_parse(Generic_ptr *var, String error_msg)
{
    int n = *((int *) var[0]);
    String input_par_file = (String) (var[1]);
    Line msg1, msg2;
    Par_info par_info;

    if (output_found)
	FOUND_BEFORE("output", "input");

    /* n (after n--) indexes input_found[], input_file[] etc., so it must be
       in 1..MAX_NINPUT_FILES; n <= 0 would index before their start */
    if ((n < 1) || (n > MAX_NINPUT_FILES))
    {
	sprintf(error_msg,
		"in \"%s\" have 'input %d', but must be between 1 and %d",
					combine_file, n, MAX_NINPUT_FILES);
	return  ERROR;
    }
	
    n--;  /* 0-based index from here on */
    if (input_found[n])
    {
	sprintf(msg1, "input %d", n+1);
	FOUND_TWICE(msg1);
    }

    /* inputs must appear in order 1, 2, ... */
    if ((n > 0) && !input_found[n-1])
    {
	sprintf(msg1, "input %d", n+1);
	sprintf(msg2, "input %d", n);
	FOUND_BEFORE(msg1, msg2);
    }

    CHECK_STATUS(read_par_file(input_par_file, &par_info, error_msg));

    strcpy(input_file[n], par_info.file);

    if (n == 0)
    {
	/* the first input defines the geometry and file format */
	ndim = par_info.ndim;
	COPY_VECTOR(npoints_in, par_info.npoints, ndim);
	COPY_VECTOR(npoints_out, par_info.npoints, ndim);
	swapped = par_info.swapped;
	integer = par_info.integer;
	blocked = par_info.blocked;
	header = par_info.header;
	deflated = par_info.deflated;
	level = par_info.level;

	if (blocked)
	{
	    COPY_VECTOR(block_size_in, par_info.block_size, ndim);
	    COPY_VECTOR(block_size_out, par_info.block_size, ndim);
	}
    }
    else
    {
	/* every later input must agree with the first */
	CHECK_EQUAL(ndim, par_info.ndim, "ndim");
	CHECK_EQUAL(swapped, par_info.swapped, "swap");
	CHECK_EQUAL(integer, par_info.integer, "int");
	CHECK_EQUAL(blocked, par_info.blocked, "blocking");
	CHECK_EQUAL(header, par_info.header, "head");
	CHECK_EQUAL(deflated, par_info.deflated, "deflation");
	CHECK_EQUAL_VECTORS(ndim, npoints_in, par_info.npoints, "npts");

	if (blocked)
	 CHECK_EQUAL_VECTORS(ndim, block_size_in, par_info.block_size, "block");
    }

    if (deflated)
    {
	sprintf(error_msg, "input file (#%d) cannot be deflated", n+1);
	return  ERROR;
    }

    ref = par_info.ref;  /* the last input's reference info is kept */
    input_found[n] = TRUE;

    ninput_files = n + 1;

    return  OK;
}

/* Handle "output <file>": remember the name of the output data file.
   Only one "output" line is allowed.  var[0] is the file name.
   Returns OK, or ERROR with error_msg set. */
static Status output_parse(Generic_ptr *var, String error_msg)
{
    String output_name = (String) var[0];

    if (output_found)
	FOUND_TWICE("output");

    output_found = TRUE;
    strcpy(output_file, output_name);

    return  OK;
}

/* Handle "par <file>": name the par file to be written for the output.
   If no "par" line is given, main() passes NULL to write_par_file().
   Must follow the first "input" line and may appear only once.
   var[0] is the file name.  Returns OK, or ERROR with error_msg set. */
static Status par_parse(Generic_ptr *var, String error_msg)
{
    String name = (String) (*var);

    /* input_found is an array: test its first element.  Testing the array
       itself (a non-null pointer) is always true, so that check never fired. */
    if (!input_found[0])
	FOUND_BEFORE("par", "input");

    if (par_found)
	FOUND_TWICE("par");

    strcpy(output_par_file, name);

    par_found = TRUE;

    return  OK;
}

/* NOTE(review): SEPARATOR is not referenced anywhere in this file as
   shown; possibly a leftover - confirm before removing. */
#define  SEPARATOR      " \t\n"

/* Handle "gradient <dim>": choose the gradient combining function along
   the 1-based dimension <dim>.
   - Must follow "output" and be the only combining function.
   - Dimension 1 must have an even number of points.
   On success, leaves dim 0-based and calls init_gradient() with the output
   size along that dimension, passing &type and &combine_func for it to
   fill in.  Returns its status, or ERROR with error_msg set. */
static Status gradient_parse(Generic_ptr *var, String error_msg)
{
    int *dim_ptr = (int *) var[0];

    if (!output_found)
	FOUND_BEFORE("gradient", "output");

    if (combine_found)
	FOUND_TWICE("combining function");

    combine_found = TRUE;

    if ((npoints_in[0] % 2) != 0)
    {
	sprintf(error_msg,
	    "'gradient': in dim 1 must have even number of points but have %d",
	    npoints_in[0]);
	return  ERROR;
    }

    dim = *dim_ptr;  /* still 1-based here */

    if (dim > ndim)
    {
	sprintf(error_msg, "in \"%s\" have 'gradient %d', but 'ndim' = %d",
	    combine_file, dim, ndim);
	return  ERROR;
    }
    else if (dim < 1)
    {
	sprintf(error_msg, "in \"%s\" have 'gradient %d'", combine_file, dim);
	return  ERROR;
    }

    dim -= 1;  /* convert to 0-based */
    strcpy(combine_name, "gradient");

    return  init_gradient(ninput_files, &npoints_out[dim], &type,
						&combine_func, error_msg);
}

/* Handle "add": select the add combining function.  It works on real data
   along dimension 1 (dim = 0).  Must follow "output" and be the only
   combining function.  Returns init_add()'s status, or ERROR with
   error_msg set. */
static Status add_parse(Generic_ptr *var, String error_msg)
{
    Status status;

    if (!output_found)
	FOUND_BEFORE("add", "output");

    if (combine_found)
	FOUND_TWICE("combining function");

    combine_found = TRUE;
    type = REAL_DATA;
    dim = 0;
    strcpy(combine_name, "add");

    status = init_add(ninput_files, npoints_out[0], &combine_func, error_msg);

    return  status;
}

/* Handle "subtract": select the subtract combining function.  It works on
   real data along dimension 1 (dim = 0).  Must follow "output" and be the
   only combining function.  Returns init_subtract()'s status, or ERROR
   with error_msg set. */
static Status subtract_parse(Generic_ptr *var, String error_msg)
{
    Status status;

    if (!output_found)
	FOUND_BEFORE("subtract", "output");

    if (combine_found)
	FOUND_TWICE("combining function");

    combine_found = TRUE;
    type = REAL_DATA;
    dim = 0;
    strcpy(combine_name, "subtract");

    status = init_subtract(ninput_files, npoints_out[0], &combine_func,
								error_msg);

    return  status;
}

/* Keyword dispatch table passed to parse_file() by read_combine_file().
   Each entry: keyword, number of arguments, argument-type list, handler.
   The final NULL-keyword entry presumably marks the end of the table for
   parse_file() (confirm in parse.h). */
static Parse_line combine_table[] =
{
    { "input",		2,	parse_int_string,	input_parse },
    { "output",		1,	parse_string,		output_parse },
    { "par",		1,	parse_string,		par_parse },
    { "gradient",	1,	parse_int,		gradient_parse },
    { "add",		0,	(int *) NULL,		add_parse },
    { "subtract",	0,	(int *) NULL,		subtract_parse },
    { (String) NULL,	0,	(int *) NULL,		no_parse_func }
};

/* Parse the combine file named by combine_file, dispatching each line
   through combine_table.  Clears the "found" flags first, then checks that
   an "input", an "output" and a combining-function line were all present.
   Returns OK, or ERROR with error_msg set. */
static Status read_combine_file(String error_msg)
{
    int j;

    output_found = FALSE;
    par_found = FALSE;
    combine_found = FALSE;

    for (j = 0; j < MAX_NINPUT_FILES; j++)
	input_found[j] = FALSE;

    CHECK_STATUS(parse_file(combine_file, combine_table, TRUE, error_msg));

    /* mandatory lines, reported in this order */
    if (!input_found[0])
	NOT_FOUND("input");

    if (!output_found)
	NOT_FOUND("output");

    if (!combine_found)
	NOT_FOUND("combining function");

    return  OK;
}

/* Echo the data dimensionality and the chosen combining function. */
static void print_combine_info()
{
    printf("number of dimensions of data = %d\ncombining function = %s\n",
						ndim, combine_name);
}

void main(int argc, char **argv)
{
    int i;
    Line error_msg;
    Size_info size_info;
    Store_info store_info;
    File_info file_info;
    Combine_info combine_info;
    Par_info par_info;
    String par_file;

    printf(product);

    if (help_request(argc, argv, help_table))
	exit (0);

    if (argc != 2)
    {
        sprintf(error_msg, "correct usage: %s <combine file>", argv[0]);
        ERROR_AND_EXIT(error_msg);
    }

    combine_file = argv[1];

    if (read_combine_file(error_msg) == ERROR)
        ERROR_AND_EXIT(error_msg);

    determine_params();

    if (allocate_memory(error_msg) == ERROR)
        ERROR_AND_EXIT(error_msg);

    for (i = 0; i < ninput_files; i++)
    {
	if (OPEN_FOR_BINARY_READING(file_in[i], input_file[i]))
	{
	    sprintf(error_msg, "opening \"%s\" for reading", input_file[i]);
	    ERROR_AND_EXIT(error_msg);
	}
    }

    if (OPEN_FOR_BINARY_WRITING(file_out, output_file))
    {
	sprintf(error_msg, "opening \"%s\" for writing", output_file);
	ERROR_AND_EXIT(error_msg);
    }

    size_info.ndim = ndim;
    size_info.block_size_in = block_size_in;
    size_info.block_size_out = block_size_out;
    size_info.npoints_in = npoints_in;
    size_info.npoints_out = npoints_out;

    store_info.row_in = row_in;
    store_info.row_out = row_out;
    store_info.store_in = store_in;
    store_info.store_out = store_out;

    file_info.ninput_files = ninput_files;
    file_info.input_file = input_file;
    file_info.output_file = output_file;
    file_info.file_in = file_in;
    file_info.file_out = file_out;
    file_info.swapped = swapped;
    file_info.integer = integer;
    file_info.blocked = TRUE;  /* even if not blocked */
    file_info.header = header;

    combine_info.dim = dim;
    combine_info.type = type;
    combine_info.combine_func = combine_func;

    print_combine_info();
    FLUSH;

    if (block_process(&size_info, &store_info, &file_info,
				&combine_info, error_msg) == ERROR)
	ERROR_AND_EXIT(error_msg);

    par_info.file = output_file;
    par_info.ndim = ndim;
    par_info.npoints = npoints_out;
    par_info.block_size = block_size_out;
    par_info.ref = ref;
    par_info.blocked = blocked;
    par_info.deflated = deflated;
    par_info.level = level;
    par_info.param_dim = -1;

    if (par_found)
	par_file = output_par_file;
    else
	par_file = NULL;

    if (write_par_file(par_file, &par_info, error_msg) == ERROR)
	ERROR_AND_EXIT(error_msg);
}
