package jsdp.sdp.impl.multivariate;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.function.Function;
import jsdp.sdp.Action;
import jsdp.sdp.BackwardRecursion;
import jsdp.sdp.HashType;
import jsdp.sdp.ImmediateValueFunction;
import jsdp.sdp.RandomOutcomeFunction;
import jsdp.sdp.Recursion;
import jsdp.sdp.State;
import jsdp.sdp.ValueRepository;
import jsdp.utilities.probdist.MultiINIDistribution;
import umontreal.ssj.probdistmulti.DiscreteDistributionIntMulti;

/* loaded from: input_file:jsdp/sdp/impl/multivariate/BackwardRecursionImpl.class */
/**
 * Backward-recursion solver for multivariate problems.
 *
 * <p>One {@link DiscreteDistributionIntMulti} is supplied per period, so the horizon length equals
 * the length of that array. One {@link StateSpaceImpl} is created for every period from {@code 0}
 * up to and including the horizon length. Any {@link MultiINIDistribution} among the inputs is
 * discretised eagerly during construction.
 *
 * <p>Values are computed lazily: {@link #getExpectedCost(StateDescriptorImpl)} and
 * {@link #getOptimalAction(StateDescriptorImpl)} trigger {@code recurse(period)} only when the
 * requested state's value is not yet available.
 */
public class BackwardRecursionImpl extends BackwardRecursion {

    /**
     * Creates a solver whose per-period state spaces are built without explicit sizing hints.
     *
     * @param optimisationDirection direction of optimisation, forwarded to the superclass
     * @param distributions one distribution per period; its length defines the horizon
     * @param immediateValueFunction forwarded to the {@link ValueRepository}
     * @param randomOutcomeFunction forwarded to the {@link TransitionProbabilityImpl}
     * @param buildActionList forwarded to every {@link StateSpaceImpl}; presumably builds the
     *     actions available in a state — confirm against {@code StateSpaceImpl}
     * @param idempotentAction forwarded to every {@link StateSpaceImpl}
     * @param discountFactor forwarded to the {@link ValueRepository}; presumably the per-period
     *     discount factor — NOTE(review): confirm against {@code ValueRepository}
     * @param samplingScheme forwarded to every {@link StateSpaceImpl}
     * @param maxSampleSize forwarded to every {@link StateSpaceImpl} (meaning assumed from usage
     *     with the sampling scheme — confirm)
     * @param reductionFactorPerStage forwarded to every {@link StateSpaceImpl} (meaning assumed —
     *     confirm)
     * @param hashType forwarded to every {@link StateSpaceImpl} and to the {@link ValueRepository}
     */
    public BackwardRecursionImpl(Recursion.OptimisationDirection optimisationDirection,
                                 DiscreteDistributionIntMulti[] distributions,
                                 ImmediateValueFunction<State, Action, Double> immediateValueFunction,
                                 RandomOutcomeFunction<State, Action, double[]> randomOutcomeFunction,
                                 Function<State, ArrayList<Action>> buildActionList,
                                 Function<State, Action> idempotentAction,
                                 double discountFactor,
                                 SamplingScheme samplingScheme,
                                 int maxSampleSize,
                                 double reductionFactorPerStage,
                                 HashType hashType) {
        super(optimisationDirection);
        initialiseHorizon(distributions);
        for (int period = 0; period <= this.horizonLength; period++) {
            this.stateSpace[period] = new StateSpaceImpl(period, buildActionList, idempotentAction,
                    samplingScheme, maxSampleSize, hashType, reductionFactorPerStage);
        }
        initialiseTransitionsAndValues(distributions, immediateValueFunction, randomOutcomeFunction,
                discountFactor, hashType);
    }

    /**
     * Creates a solver whose per-period state spaces additionally receive sizing hints.
     *
     * @param optimisationDirection direction of optimisation, forwarded to the superclass
     * @param distributions one distribution per period; its length defines the horizon
     * @param immediateValueFunction forwarded to the {@link ValueRepository}
     * @param randomOutcomeFunction forwarded to the {@link TransitionProbabilityImpl}
     * @param buildActionList forwarded to every {@link StateSpaceImpl}
     * @param idempotentAction forwarded to every {@link StateSpaceImpl}
     * @param discountFactor forwarded to the {@link ValueRepository}; presumably the per-period
     *     discount factor — NOTE(review): confirm against {@code ValueRepository}
     * @param samplingScheme forwarded to every {@link StateSpaceImpl}
     * @param maxSampleSize forwarded to every {@link StateSpaceImpl} (meaning assumed — confirm)
     * @param reductionFactorPerStage forwarded to every {@link StateSpaceImpl} (meaning assumed —
     *     confirm)
     * @param stateSpaceSizeLowerBound forwarded to every {@link StateSpaceImpl}; presumably an
     *     initial-capacity hint for its state table — confirm
     * @param loadFactor forwarded to every {@link StateSpaceImpl}; presumably a hash-table load
     *     factor — confirm
     * @param hashType forwarded to every {@link StateSpaceImpl} and to the {@link ValueRepository}
     */
    public BackwardRecursionImpl(Recursion.OptimisationDirection optimisationDirection,
                                 DiscreteDistributionIntMulti[] distributions,
                                 ImmediateValueFunction<State, Action, Double> immediateValueFunction,
                                 RandomOutcomeFunction<State, Action, double[]> randomOutcomeFunction,
                                 Function<State, ArrayList<Action>> buildActionList,
                                 Function<State, Action> idempotentAction,
                                 double discountFactor,
                                 SamplingScheme samplingScheme,
                                 int maxSampleSize,
                                 double reductionFactorPerStage,
                                 int stateSpaceSizeLowerBound,
                                 float loadFactor,
                                 HashType hashType) {
        super(optimisationDirection);
        initialiseHorizon(distributions);
        for (int period = 0; period <= this.horizonLength; period++) {
            this.stateSpace[period] = new StateSpaceImpl(period, buildActionList, idempotentAction,
                    samplingScheme, maxSampleSize, reductionFactorPerStage, hashType,
                    stateSpaceSizeLowerBound, loadFactor);
        }
        initialiseTransitionsAndValues(distributions, immediateValueFunction, randomOutcomeFunction,
                discountFactor, hashType);
    }

    /**
     * Shared constructor prologue. It sets the horizon length, eagerly discretises every
     * {@link MultiINIDistribution}, and allocates one state-space slot per period plus one.
     */
    private void initialiseHorizon(DiscreteDistributionIntMulti[] distributions) {
        this.horizonLength = distributions.length;
        for (DiscreteDistributionIntMulti distribution : distributions) {
            if (distribution instanceof MultiINIDistribution) {
                ((MultiINIDistribution) distribution).discretizeDistributions();
            }
        }
        // horizonLength + 1 slots: one per period, plus the state space of the final period.
        this.stateSpace = new StateSpaceImpl[this.horizonLength + 1];
    }

    /**
     * Shared constructor epilogue. It must run after every state space has been populated,
     * because the transition probability receives the state-space array.
     */
    private void initialiseTransitionsAndValues(
            DiscreteDistributionIntMulti[] distributions,
            ImmediateValueFunction<State, Action, Double> immediateValueFunction,
            RandomOutcomeFunction<State, Action, double[]> randomOutcomeFunction,
            double discountFactor,
            HashType hashType) {
        this.transitionProbability = new TransitionProbabilityImpl(
                distributions, randomOutcomeFunction, (StateSpaceImpl[]) getStateSpace());
        this.valueRepository = new ValueRepository(immediateValueFunction, discountFactor, hashType);
    }

    /**
     * Returns the transition probability built by the constructor.
     *
     * @return the transition probability, typed as {@link TransitionProbabilityImpl}
     */
    @Override // jsdp.sdp.Recursion
    public TransitionProbabilityImpl getTransitionProbability() {
        return (TransitionProbabilityImpl) this.transitionProbability;
    }

    /**
     * Returns the expected value of the period-0 state with the given description.
     *
     * @param dArr state description of the period-0 state
     * @return the expected value, computing period 0 first if it is not yet available
     */
    public double getExpectedCost(double[] dArr) {
        return getExpectedCost(new StateDescriptorImpl(0, dArr));
    }

    /**
     * Returns the expected value of the described state. The state's period is solved first if its
     * value is not yet available.
     *
     * @param stateDescriptorImpl descriptor of the state (period and description)
     * @return the expected value of that state
     */
    public double getExpectedCost(StateDescriptorImpl stateDescriptorImpl) {
        State state = lookupState(stateDescriptorImpl);
        try {
            return getExpectedValue(state);
        } catch (NullPointerException e) {
            // A NullPointerException from getExpectedValue is treated as "not computed yet":
            // solve the period, then retry once. NOTE(review): this also masks unrelated NPEs;
            // an explicit presence check in the value repository would be cleaner.
            recurse(state.getPeriod());
            return getExpectedValue(state);
        }
    }

    /**
     * Returns the optimal action stored for the described state. The state's period is solved
     * first if its value is not yet available.
     *
     * @param stateDescriptorImpl descriptor of the state (period and description)
     * @return the optimal action recorded in the value repository for that state
     */
    public ActionImpl getOptimalAction(StateDescriptorImpl stateDescriptorImpl) {
        State state = lookupState(stateDescriptorImpl);
        try {
            // Called only to probe whether the state's value exists; the result is discarded.
            getExpectedValue(state);
        } catch (NullPointerException e) {
            // Same "not computed yet" convention as getExpectedCost; see the note there.
            recurse(state.getPeriod());
        }
        return (ActionImpl) getValueRepository().getOptimalAction(state);
    }

    /** Resolves a descriptor to its {@link State} in the state space of the descriptor's period. */
    private State lookupState(StateDescriptorImpl descriptor) {
        return ((StateSpaceImpl) getStateSpace(descriptor.getPeriod())).getState(descriptor);
    }
}
