/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import org.apache.sysds.lops.PMapMult;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBlock;
import org.apache.sysds.runtime.instructions.spark.functions.IsBlockInRange;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import scala.Tuple2;

/**
 * Spark instruction for partitioned map-side matrix multiplication (pmapmm), {@code in1 %*% in2}.
 *
 * <p>The left input {@code in1} is processed in horizontal chunks of {@link #NUM_ROWBLOCKS} row
 * blocks. Each chunk is collected into a {@link PartitionedBlock} and broadcast. Every block of
 * the right input {@code in2} is then multiplied with the broadcast chunk, and the partial
 * products are summed per output index. The per-chunk results are unioned into the output RDD.
 */
public class PMapmmSPInstruction
extends BinarySPInstruction {
    /** Number of row blocks of the left input that are broadcast together as one chunk. */
    private static final int NUM_ROWBLOCKS = 4;

    private PMapmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(SPInstruction.SPType.PMAPMM, op, in1, in2, out, opcode, istr);
    }

    /**
     * Parses a pmapmm instruction string of the form {@code opcode, in1, in2, out}.
     *
     * @param str the serialized instruction
     * @return the parsed instruction, using a single-threaded matrix multiplication operator
     * @throws DMLRuntimeException if the opcode is not {@link PMapMult#OPCODE}
     */
    public static PMapmmSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase(PMapMult.OPCODE)) {
            throw new DMLRuntimeException("PMapmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
        return new PMapmmSPInstruction(aggbin, in1, in2, out, opcode, str);
    }

    /**
     * Executes the chunked broadcast multiplication and binds the persisted, materialized
     * result RDD to the output variable.
     *
     * @param ec execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the number of rows of the left input is not positive
     *         (unknown or empty), because then no chunk, and hence no output RDD, is produced
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input2.getName());
        DataCharacteristics mc1 = sec.getDataCharacteristics(this.input1.getName());

        long rows = mc1.getRows();
        if (rows <= 0) {
            // Without this guard, the chunk loop below never runs and 'out' stays null.
            throw new DMLRuntimeException("PMapmmSPInstruction.processInstruction():: "
                + "Invalid number of rows of left input: " + rows);
        }
        int blen = mc1.getBlocksize();
        // Use long arithmetic so that row offsets cannot overflow int for very tall inputs.
        long chunkRows = (long) NUM_ROWBLOCKS * blen;

        StorageLevel pmapmmStorageLevel = StorageLevel.MEMORY_AND_DISK();
        // in2 is scanned once per chunk, so repartition and cache it up front.
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2Cached = in2
            .repartition(sec.getSparkContext().defaultParallelism())
            .persist(pmapmmStorageLevel);

        ArrayList<JavaPairRDD<MatrixIndexes, MatrixBlock>> partials = new ArrayList<>();
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
        for (long rowOffset = 0; rowOffset < rows; rowOffset += chunkRows) {
            // Number of row blocks preceding this chunk: used to rebase the chunk's block
            // indexes to 1..NUM_ROWBLOCKS and to restore global indexes in the output.
            long blockOffset = rowOffset / blen;

            // Collect the current row chunk of in1 and broadcast it.
            JavaPairRDD<MatrixIndexes, MatrixBlock> chunk = in1
                .filter(new IsBlockInRange(rowOffset + 1, rowOffset + chunkRows, 1L, mc1.getCols(), mc1))
                .mapToPair(new PMapMMRebaseBlocksFunction(blockOffset));
            int rlen = (int) Math.min(rows - rowOffset, chunkRows);
            PartitionedBlock<MatrixBlock> pmb = SparkExecutionContext.toPartitionedMatrixBlock(
                chunk, rlen, (int) mc1.getCols(), blen, -1L);
            Broadcast<PartitionedBlock<MatrixBlock>> bpmb = sec.getSparkContext().broadcast(pmb);

            // Multiply every in2 block with the broadcast chunk, then sum the partial products per index.
            JavaPairRDD<MatrixIndexes, MatrixBlock> partial = in2Cached
                .flatMapToPair(new PMapMMFunction(bpmb, blockOffset));
            partial = RDDAggregateUtils.sumByKeyStable(partial, false);
            // Materialize this chunk's result before its broadcast is unpersisted.
            partial.persist(pmapmmStorageLevel).count();
            bpmb.unpersist(false);

            partials.add(partial);
            out = (out == null) ? partial : out.union(partial);
        }

        out = out.persist(pmapmmStorageLevel);
        out.count();

        // 'out' is now cached and materialized, so release the intermediate caches; Spark can
        // still recompute from lineage if needed. With a single chunk, 'out' IS that chunk's
        // RDD, so the chunk cache must not be released in that case.
        in2Cached.unpersist(false);
        if (partials.size() > 1) {
            for (JavaPairRDD<MatrixIndexes, MatrixBlock> partial : partials) {
                partial.unpersist(false);
            }
        }

        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), this.input1.getName());
        sec.addLineageRDD(this.output.getName(), this.input2.getName());
        this.updateBinaryMMOutputDataCharacteristics(sec, true);
    }

    /**
     * Multiplies one block of the right input with every row block of the broadcast left chunk.
     * For a right block with row index r, this emits one product per chunk row block i (left
     * block (i, r) times the right block), re-indexed to the global output row block
     * {@code offset + i}.
     */
    private static class PMapMMFunction
    implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -4520080421816885321L;
        private final AggregateBinaryOperator _op;
        private final Broadcast<PartitionedBlock<MatrixBlock>> _pbc;
        private final long _offset;

        /**
         * @param binput broadcast left chunk, with row block indexes rebased to start at 1
         * @param offset number of left row blocks preceding the chunk
         */
        public PMapMMFunction(Broadcast<PartitionedBlock<MatrixBlock>> binput, long offset) {
            this._pbc = binput;
            this._offset = offset;
            this._op = InstructionUtils.getMatMultOperator(1);
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            PartitionedBlock<MatrixBlock> pm = this._pbc.value();
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn = arg0._2();
            int numRowBlocks = pm.getNumRowBlocks();
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>(numRowBlocks);
            for (int i = 1; i <= numRowBlocks; ++i) {
                MatrixBlock left = pm.getBlock(i, (int) ixIn.getRowIndex());
                // matMult writes into the given output objects, so each emitted tuple needs its own
                // index and block. Reusing them across iterations would make every tuple alias
                // the last product.
                MatrixIndexes ixOut = new MatrixIndexes();
                MatrixBlock blkOut = new MatrixBlock();
                OperationsOnMatrixValues.matMult(new MatrixIndexes(i, ixIn.getRowIndex()), left, ixIn, blkIn, ixOut, blkOut, this._op);
                // Shift the chunk-local row block index back to the global one.
                ixOut.setIndexes(this._offset + i, ixOut.getColumnIndex());
                ret.add(new Tuple2<>(ixOut, blkOut));
            }
            return ret.iterator();
        }
    }

    /**
     * Shifts the row block index of each block down by a fixed offset. The blocks of a left
     * chunk thereby get chunk-local row indexes; the blocks themselves are passed through as-is.
     */
    private static class PMapMMRebaseBlocksFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 98051757210704132L;
        private final long _offset;

        /** @param offset number of row blocks to subtract from each row block index */
        public PMapMMRebaseBlocksFunction(long offset) {
            this._offset = offset;
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ixIn = arg0._1();
            MatrixIndexes ixOut = new MatrixIndexes(ixIn.getRowIndex() - this._offset, ixIn.getColumnIndex());
            return new Tuple2<>(ixOut, arg0._2());
        }
    }
}

