import org.hipparchus.analysis.differentiation.Gradient;
import org.hipparchus.analysis.differentiation.GradientField;
import org.hipparchus.geometry.euclidean.threed.FieldVector3D;
import org.hipparchus.geometry.euclidean.threed.Vector3D;
import org.hipparchus.ode.FieldODEIntegrator;
import org.hipparchus.ode.nonstiff.DormandPrince853FieldIntegrator;
import org.hipparchus.util.FastMath;
import org.orekit.data.DataContext;
import org.orekit.data.DataProvider;
import org.orekit.data.DirectoryCrawler;
import org.orekit.frames.Frame;
import org.orekit.frames.FramesFactory;
import org.orekit.orbits.*;
import org.orekit.propagation.FieldSpacecraftState;
import org.orekit.propagation.analytical.FieldKeplerianPropagator;
import org.orekit.propagation.numerical.FieldNumericalPropagator;
import org.orekit.time.AbsoluteDate;
import org.orekit.time.FieldAbsoluteDate;
import org.orekit.time.TimeScalesFactory;
import org.orekit.utils.Constants;
import org.orekit.utils.FieldAbsolutePVCoordinates;
import org.orekit.utils.FieldPVCoordinates;
import org.orekit.utils.PVCoordinates;

import java.io.File;
import java.util.Arrays;

/**
 * Prints, for a very short time shift, the partial derivatives of the final position with
 * respect to the initial position and velocity, as obtained by several approaches: a
 * constant-acceleration shift of absolute PV coordinates, {@code KeplerianMotionCartesianUtility},
 * a Cartesian orbit shift, an analytical Keplerian propagator and a numerical propagator.
 *
 * <p>Results are only printed to standard output; nothing is asserted.
 */
public class KeplerianMotionDerivativesTest {

    /** Orekit data directory, relative to the user's home directory. */
    private static final String OREKIT_DATA_DIR = "Documents/data/orekit-data";

    /** Whether the Orekit data directory has already been registered with the default data context. */
    private static boolean orekitDataLoaded = false;

    /**
     * Runs the comparison twice: once with the nominal initial state, then with a perturbed
     * initial velocity.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        run(false);
        run(true);
    }

    /**
     * Builds the initial state with {@link Gradient} partials w.r.t. the six initial position and
     * velocity components (indices 0-2 position, 3-5 velocity), shifts/propagates it by a small
     * field time offset with each method, and prints the resulting position gradients.
     *
     * @param makeEccentric if {@code true}, {@code -1000} m/s is added to the Y component of the
     *                      nominal initial velocity (the nominal state has roughly circular speed
     *                      with velocity nearly perpendicular to position, so this makes it eccentric)
     */
    public static void run(final boolean makeEccentric) {

        // load Orekit data (registered only once, even when run() is called several times)
        loadOrekitData();

        // reference frame and epoch (J2000.0 epoch expressed in UTC)
        final Frame frame = FramesFactory.getGCRF();
        final AbsoluteDate date = new AbsoluteDate("2000-01-01T11:58:55.816Z", TimeScalesFactory.getUTC());

        // nominal initial position and velocity vectors
        final Vector3D pos = new Vector3D(549559.923311143, -1096995.5695899995, 6685464.8243444795);
        Vector3D vel = new Vector3D(-6846.728109323677, -3429.9932278655106, 0.0);
        if (makeEccentric) {
            vel = vel.add(new Vector3D(0.0, -1e3, 0.0));
        }

        // standard gravitational parameter
        final double mu = Constants.WGS84_EARTH_MU;

        // nominal Keplerian orbit
        final Orbit keplerOrbit = new KeplerianOrbit(new PVCoordinates(pos, vel), frame, date, mu);
        System.out.println("\nInitial " + keplerOrbit);

        // nominal Keplerian (point-mass) acceleration at the initial position
        final Vector3D keplerAcc = pos.scalarMultiply(-mu / FastMath.pow(pos.getNorm(), 3));

        // Gradient field to compute the partials with respect to the initial position and velocity
        final int freeParams = 6;
        final GradientField field = GradientField.getField(freeParams);

        // field time shift: value -6.0e-4 s with prescribed partials w.r.t. the six initial state
        // components. NOTE(review): the origin of these numbers is not shown here (presumably an
        // upstream computation such as a light-time correction) — TODO confirm.
        final Gradient fieldDt = new Gradient(
                -6.004268800567434e-4,
                -2.468257391808438e-9,
                -6.613853266115959e-10,
                2.1441086746419246e-9,
                1.4820080849406567e-12,
                3.9711352817274305e-13,
                -1.2873804820179552e-12
        );

        // field initial and target epochs
        final FieldAbsoluteDate<Gradient> fieldDate = new FieldAbsoluteDate<>(field, date);
        final FieldAbsoluteDate<Gradient> fieldTargetDate = fieldDate.shiftedBy(fieldDt);

        // initial position and velocity as independent variables (indices 0-2 and 3-5)
        final FieldVector3D<Gradient> fieldPos = new FieldVector3D<>(
                Gradient.variable(freeParams, 0, pos.getX()),
                Gradient.variable(freeParams, 1, pos.getY()),
                Gradient.variable(freeParams, 2, pos.getZ())
        );
        final FieldVector3D<Gradient> fieldVel = new FieldVector3D<>(
                Gradient.variable(freeParams, 3, vel.getX()),
                Gradient.variable(freeParams, 4, vel.getY()),
                Gradient.variable(freeParams, 5, vel.getZ())
        );
        // acceleration components are constants (no partials), so the dependence of the
        // acceleration on the initial position is deliberately not carried in the field
        final FieldVector3D<Gradient> fieldAcc = new FieldVector3D<>(
                Gradient.constant(freeParams, keplerAcc.getX()),
                Gradient.constant(freeParams, keplerAcc.getY()),
                Gradient.constant(freeParams, keplerAcc.getZ())
        );

        // initial field absolute PV coordinates
        final FieldAbsolutePVCoordinates<Gradient> fieldAbsPva = new FieldAbsolutePVCoordinates<>(
                frame,
                fieldDate,
                fieldPos,
                fieldVel,
                fieldAcc
        );

        // shifted field absolute PV coordinates
        final FieldAbsolutePVCoordinates<Gradient> shiftedFieldAbsPva = fieldAbsPva.shiftedBy(fieldDt);

        // initial field Cartesian orbit
        final FieldCartesianOrbit<Gradient> fieldOrbit =
                new FieldCartesianOrbit<>(fieldAbsPva, frame, field.getZero().add(mu));

        // predicted PV coordinates
        final FieldPVCoordinates<Gradient> predictedFieldPV = KeplerianMotionCartesianUtility
                .predictPositionVelocity(fieldDt, fieldPos, fieldVel, field.getZero().add(mu));

        // shifted field Cartesian orbit
        final FieldCartesianOrbit<Gradient> shiftedFieldOrbit = fieldOrbit.shiftedBy(fieldDt);

        // analytically propagated Cartesian orbit
        final FieldKeplerianPropagator<Gradient> fieldKepProp = new FieldKeplerianPropagator<>(fieldOrbit);
        final FieldOrbit<Gradient> kepPropFieldOrbit = fieldKepProp.propagate(fieldTargetDate).getOrbit();

        // numerically propagated Cartesian orbit
        // (min step 1e-16 s, max step 1e6 s, absolute and relative tolerances 1e-14)
        final FieldODEIntegrator<Gradient> fieldDP853 =
                new DormandPrince853FieldIntegrator<>(field, 1e-16, 1e6, 1e-14, 1e-14);
        final FieldNumericalPropagator<Gradient> fieldNumProp =
                new FieldNumericalPropagator<>(field, fieldDP853);
        fieldNumProp.setInitialState(new FieldSpacecraftState<>(fieldOrbit));
        final FieldOrbit<Gradient> numPropFieldOrbit = fieldNumProp.propagate(fieldTargetDate).getOrbit();

        System.out.println("\nPartials of final position w.r.t. initial position and velocity");
        printGradient(shiftedFieldAbsPva.getPosition(), "constant acceleration");
        printGradient(predictedFieldPV.getPosition(), "predicted PV (analytical Keplerian motion)");
        printGradient(shiftedFieldOrbit.getPosition(), "shifted orbit (analytical Keplerian motion)");
        printGradient(kepPropFieldOrbit.getPosition(), "analytical Keplerian propagator");
        printGradient(numPropFieldOrbit.getPosition(), "numerical propagator");
    }

    /**
     * Registers the Orekit data directory with the default data context, at most once.
     *
     * <p>The data providers manager is process-wide state; registering a new crawler on every
     * {@link #run(boolean)} call would add duplicate providers for the same directory. The flag
     * is set only after a successful registration, so a failure can be retried on the next call.
     */
    private static void loadOrekitData() {
        if (orekitDataLoaded) {
            return;
        }
        final File orekitData = new File(System.getProperty("user.home"), OREKIT_DATA_DIR);
        final DataProvider dirCrawler = new DirectoryCrawler(orekitData);
        DataContext.getDefault().getDataProvidersManager().addProvider(dirCrawler);
        orekitDataLoaded = true;
    }

    /**
     * Prints a title followed by the gradient of each Cartesian component of {@code vec}.
     *
     * @param vec field vector whose components carry partials w.r.t. the initial state
     *            (indices 0-2 initial position, 3-5 initial velocity)
     * @param str title printed above the three gradient lines
     */
    private static void printGradient(final FieldVector3D<Gradient> vec, final String str) {
        System.out.println("\n" + str + ":");
        System.out.println("dx/dX0" + Arrays.toString(vec.getX().getGradient()));
        System.out.println("dy/dX0" + Arrays.toString(vec.getY().getGradient()));
        System.out.println("dz/dX0" + Arrays.toString(vec.getZ().getGradient()));
    }

}
