package graila.orekit;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import org.hipparchus.CalculusFieldElement;
import org.hipparchus.Field;
import org.hipparchus.geometry.euclidean.threed.FieldRotation;
import org.hipparchus.geometry.euclidean.threed.FieldVector3D;
import org.hipparchus.geometry.euclidean.threed.Vector3D;
import org.hipparchus.ode.nonstiff.AdaptiveStepsizeIntegrator;
import org.hipparchus.ode.nonstiff.DormandPrince853Integrator;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartUtils;
import org.jfree.chart.JFreeChart;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.orekit.bodies.CelestialBody;
import org.orekit.bodies.CelestialBodyFactory;
import org.orekit.data.DataContext;
import org.orekit.data.DataProvidersManager;
import org.orekit.data.DataSource;
import org.orekit.data.DirectoryCrawler;
import org.orekit.files.daf.pck.PCK;
import org.orekit.files.daf.pck.PCKParser;
import org.orekit.files.daf.pck.PCKSegment;
import org.orekit.files.daf.spk.SPK;
import org.orekit.files.daf.spk.SPKParser;
import org.orekit.files.daf.spk.SPKSegment;
import org.orekit.files.daf.spk.SPKSegmentDataTypes.SPKSegmentDataMDASingleRecord;
import org.orekit.files.daf.spk.SPKSegmentDataTypes.SPKSegmentDataType1;
import org.orekit.forces.ForceModel;
import org.orekit.forces.gravity.HolmesFeatherstoneAttractionModel;
import org.orekit.forces.gravity.ThirdBodyAttraction;
import org.orekit.forces.gravity.potential.GravityFieldFactory;
import org.orekit.forces.gravity.potential.NormalizedSphericalHarmonicsProvider;
import org.orekit.forces.gravity.potential.SHAFormatReader;
import org.orekit.frames.FieldTransform;
import org.orekit.frames.Frame;
import org.orekit.frames.Transform;
import org.orekit.frames.TransformProvider;
import org.orekit.orbits.CartesianOrbit;
import org.orekit.orbits.Orbit;
import org.orekit.orbits.OrbitType;
import org.orekit.propagation.SpacecraftState;
import org.orekit.propagation.ToleranceProvider;
import org.orekit.propagation.numerical.NumericalPropagator;
import org.orekit.time.AbsoluteDate;
import org.orekit.time.FieldAbsoluteDate;
import org.orekit.utils.TimeStampedAngularCoordinates;
import org.orekit.utils.TimeStampedPVCoordinates;

public class App {

    /** Root of the local Orekit data tree; the GRAIL-A kernels and the Moon PA kernel live under it. */
    private static final String DATA_DIR = "/Users/rafael/Desktop/graila-orekit-tests/orekit-data";

    /** Directory holding the daily GRAIL-A SPK kernels. */
    private static final String KERNEL_DIR = DATA_DIR + "/GRAILA-kernels";

    /** Index (into the date-sorted reference states) of the state the propagation starts from. */
    private static final int START_INDEX = 3647;

    /** Number of reference samples propagated through after START_INDEX (~2 hours per the author's note). */
    private static final int STEP_COUNT = 250;

    /**
     * Propagates GRAIL-A from one reference state extracted from the SPK kernels, compares each
     * propagated position against the following reference states, prints summary statistics and
     * writes an "error vs time" chart to {@code error_vs_time.png} in the working directory.
     *
     * <p>Any failure is reported with {@code printStackTrace()}; nothing is rethrown.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        try {
            File orekitData = new File(DATA_DIR);
            DataProvidersManager manager = DataContext.getDefault().getDataProvidersManager();
            manager.addProvider(new DirectoryCrawler(orekitData));

            // Daily GRAIL-A SPK kernels to process
            String[] spkFiles = {
                KERNEL_DIR + "/gralugf2012_094_2012_095.spk",
                KERNEL_DIR + "/gralugf2012_095_2012_096.spk",
                KERNEL_DIR + "/gralugf2012_096_2012_097.spk",
                KERNEL_DIR + "/gralugf2012_097_2012_098.spk",
                KERNEL_DIR + "/gralugf2012_098_2012_099.spk",
                KERNEL_DIR + "/gralugf2012_099_2012_100.spk",
                KERNEL_DIR + "/gralugf2012_100_2012_101.spk",
                KERNEL_DIR + "/gralugf2012_101_2012_102.spk",
                KERNEL_DIR + "/gralugf2012_102_2012_103.spk"
            };

            List<SPKSegment> grailaSegments = new ArrayList<>();
            for (String spkFile : spkFiles) {
                grailaSegments.addAll(GrailaUtils.getGrailaSegments(spkFile));
            }
            grailaSegments.sort(Comparator.comparing(seg -> seg.getSegmentSummary().getInitialEpoch()));

            List<TimeStampedPVCoordinates> pvList = GrailaUtils.extractPVCoordinates(grailaSegments);
            pvList.sort(Comparator.comparing(TimeStampedPVCoordinates::getDate));

            System.out.println("Total TimeStampedPVCoordinates: " + pvList.size());

            // Propagation window: from START_INDEX through STEP_COUNT further reference samples.
            final int startIndex = START_INDEX;
            final int finalIndex = START_INDEX + STEP_COUNT;

            // A missing or unreadable kernel shrinks pvList; fail up front with a clear message
            // instead of an IndexOutOfBoundsException after most of the propagation has run.
            if (finalIndex >= pvList.size()) {
                throw new IllegalStateException("Propagation window [" + startIndex + ", " + finalIndex
                        + "] needs " + (finalIndex + 1) + " reference states but only "
                        + pvList.size() + " were extracted");
            }

            // Use startIndex (not a literal) so the initial state and the error reference stay in sync.
            final TimeStampedPVCoordinates initialPV = pvList.get(startIndex);

            final CelestialBody moon = CelestialBodyFactory.getMoon();
            final CelestialBody earth = CelestialBodyFactory.getEarth();

            // Moon-centred integration frame.
            // NOTE(review): the SPK states are expressed in the frame recorded in each segment summary,
            // and the PCK angles are relative to that kernel's own base frame. Confirm both agree with
            // moon.getInertiallyOrientedFrame(), which may be oriented by the IAU pole model rather than
            // the ICRF/J2000 axes. TODO confirm before trusting small position errors.
            final Frame SCRF = moon.getInertiallyOrientedFrame();

            String moonPaKernelFile = DATA_DIR + "/moon_pa_de440_200625.bpc";
            DataSource moonPADataSource = new DataSource("moon_pa", () -> new File(moonPaKernelFile).toURI().toURL().openStream());

            PCK moonPApck = new PCKParser().parse(moonPADataSource);

            MoonPATransformProvider moonPATransformProvider = new MoonPATransformProvider(moonPApck);

            // Lunar principal-axes frame, defined relative to SCRF by the PCK orientation data
            final Frame moonPA = new Frame(SCRF, moonPATransformProvider, "MoonPA");

            // Adaptive Dormand-Prince 8(5,3) integrator; tolerances come from Orekit's default
            // provider with dP = 1.0 (a position error, presumably metres)
            double minStep = 0.001;
            double maxStep = 1;
            final double[][] tolerances = ToleranceProvider.getDefaultToleranceProvider(1.0)
                    .getTolerances(initialPV.getPosition(), initialPV.getVelocity());
            final AdaptiveStepsizeIntegrator integrator = new DormandPrince853Integrator(minStep, maxStep, tolerances[0], tolerances[1]);
            integrator.setInitialStepSize(minStep);

            // Moon spherical harmonics (GRGM1200B, truncated to degree/order 350)
            GravityFieldFactory.clearPotentialCoefficientsReaders();
            SHAFormatReader fileReaderGRGM1200B = new SHAFormatReader("sha.grgm1200b_sigma", false);
            GravityFieldFactory.addPotentialCoefficientsReader(fileReaderGRGM1200B);
            NormalizedSphericalHarmonicsProvider gravityProviderGRGM1200B = GravityFieldFactory.getNormalizedProvider(350, 350);
            double moonGM = gravityProviderGRGM1200B.getMu();
            // Currently unused; kept for the disabled solid-tides model below
            double moonAE = gravityProviderGRGM1200B.getAe();
            ForceModel holmesFeatherstoneMoonGRGM1200B = new HolmesFeatherstoneAttractionModel(moonPA, gravityProviderGRGM1200B);

            // Disabled experiment: Earth spherical harmonics
            // GravityFieldFactory.clearPotentialCoefficientsReaders();
            // ICGEMFormatReader fileReaderGGM05C = new ICGEMFormatReader("GGM05C.gfc", false);
            // GravityFieldFactory.addPotentialCoefficientsReader(fileReaderGGM05C);

            // Disabled experiment: Moon solid tides
            // SolidTides solidTidesMoon = new SolidTides(moonPA, moonAE, moonGM, TideSystem.ZERO_TIDE,// gravityProviderGRGM1200B.getTideSystem(),
            //                                            IERSConventions.IERS_2010, TimeScalesFactory.getUT1(IERSConventions.IERS_2010, true),
            //                                            CelestialBodyFactory.getSun(), earth);

            //NormalizedSphericalHarmonicsProvider gravityProviderGGM05C = (NormalizedSphericalHarmonicsProvider) fileReaderGGM05C.getProvider(true, 8, 8);
            // NormalizedSphericalHarmonicsProvider gravityProviderGGM05C = GravityFieldFactory.getNormalizedProvider(8, 8);
            // ForceModel holmesFeatherstoneEarthGGM05C = new HolmesFeatherstoneAttractionModel(earth.getInertiallyOrientedFrame(), gravityProviderGGM05C);
            // double earthGM = gravityProviderGGM05C.getMu();

            ThirdBodyAttraction earthThirdBodyAttraction = new ThirdBodyAttraction(earth);

            final NumericalPropagator perturbedPropagator = new NumericalPropagator(integrator);
            perturbedPropagator.setOrbitType(OrbitType.CARTESIAN);
            perturbedPropagator.removeForceModels();

            perturbedPropagator.addForceModel(holmesFeatherstoneMoonGRGM1200B);
            // perturbedPropagator.addForceModel(solidTidesMoon);
            // perturbedPropagator.addForceModel(holmesFeatherstoneEarthGGM05C);
            perturbedPropagator.addForceModel(earthThirdBodyAttraction);

            // The orbit's mu is the GRGM1200B value read above
            Orbit initialOrbit = new CartesianOrbit(initialPV, SCRF, moonGM);
            double initialMass = 200.0;
            SpacecraftState initialState = new SpacecraftState(initialOrbit, initialMass);
            perturbedPropagator.setInitialState(initialState);

            List<SpacecraftState> propagatedStates = new ArrayList<>();
            List<Double>          positionErrors    = new ArrayList<>();
            SpacecraftState       currentState      = initialState;

            // Record the initial state and its error against the reference
            propagatedStates.add(currentState);
            positionErrors.add(
                currentState.getPVCoordinates().getPosition()
                            .distance(pvList.get(startIndex).getPosition())
            );

            // Step through the reference epochs, comparing each propagated position to the kernel state
            for (int i = startIndex + 1; i <= finalIndex; i++) {
                AbsoluteDate targetDate = pvList.get(i).getDate();
                currentState = perturbedPropagator.propagate(targetDate);
                propagatedStates.add(currentState);
                double err = currentState.getPVCoordinates().getPosition()
                                .distance(pvList.get(i).getPosition());
                positionErrors.add(err);
                perturbedPropagator.setInitialState(currentState);
            }

            System.out.println("Total propagated steps: " + propagatedStates.size());
            System.out.println("Max position error (m): " + Collections.max(positionErrors));

            // Plot error against elapsed time since the starting reference epoch
            AbsoluteDate t0 = pvList.get(startIndex).getDate();
            XYSeries series = new XYSeries("Position error");
            for (int i = 0; i < positionErrors.size(); i++) {
                double dt = pvList.get(startIndex + i).getDate().durationFrom(t0);
                series.add(dt, positionErrors.get(i));
            }
            XYSeriesCollection dataset = new XYSeriesCollection(series);

            JFreeChart chart = ChartFactory.createXYLineChart(
                "Position Error vs Time",
                "Time since t0 (s)",
                "Error (m)",
                dataset
            );

            File out = new File("error_vs_time.png");
            ChartUtils.saveChartAsPNG(out, chart, 800, 600);

        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

class GrailaUtils {

    /** NAIF ID code of the target selected by {@link #getGrailaSegments(String)} (GRAIL-A). */
    static final int GRAILA_NAIF_CODE = -177;

    /**
     * Scale applied to SPK record positions/velocities. This is presumably the km to m conversion,
     * since SPK files conventionally store km and km/s.
     */
    private static final double KM_TO_M = 1000.0;

    private GrailaUtils() {
        // static helpers only; not meant to be instantiated
    }

    /**
     * Reads one SPK file and returns its segments whose target is GRAIL-A (NAIF code -177).
     *
     * <p>Best effort: if the file cannot be read or parsed, the problem is reported on stderr and
     * whatever segments were collected so far (usually none) are returned.
     *
     * @param spkFilePath path of the SPK file
     * @return matching segments, in file order; never {@code null}
     */
    static List<SPKSegment> getGrailaSegments(String spkFilePath) {
        return getSegmentsForTarget(spkFilePath, GRAILA_NAIF_CODE);
    }

    /**
     * Reads one SPK file and returns its segments whose target NAIF code equals {@code targetNaifCode}.
     *
     * <p>Best effort: if the file cannot be read or parsed, the problem is reported on stderr and
     * whatever segments were collected so far (usually none) are returned.
     *
     * @param spkFilePath    path of the SPK file
     * @param targetNaifCode NAIF ID code of the wanted target
     * @return matching segments, in file order; never {@code null}
     */
    static List<SPKSegment> getSegmentsForTarget(String spkFilePath, int targetNaifCode) {
        List<SPKSegment> result = new ArrayList<>();
        File spkFile = new File(spkFilePath);
        try {
            DataSource source = new DataSource(spkFile.getName(), () -> spkFile.toURI().toURL().openStream());
            SPK spk = new SPKParser().parse(source);
            for (SPKSegment segment : spk.getSegments()) {
                if (segment.getSegmentSummary().getTargetNAIFCode() == targetNaifCode) {
                    result.add(segment);
                }
            }
        } catch (Exception e) {
            // Keep going with the other files, but say which one failed: a silently missing kernel
            // shifts every index into the merged, date-sorted state list.
            System.err.println("Failed to read SPK file " + spkFilePath + ": " + e);
            e.printStackTrace();
        }
        return result;
    }

    /**
     * Converts the reference states of SPK type-1 segments into timestamped PV coordinates.
     *
     * <p>Each record's reference position and velocity are multiplied by 1000.0 (presumably km to
     * m and km/s to m/s). The output follows segment order, then record order; it is not sorted here.
     *
     * @param segments segments to read; each must carry {@code SPKSegmentDataType1} data
     * @return a new mutable list with one entry per record
     * @throws IllegalArgumentException if a segment does not hold type-1 data
     */
    static List<TimeStampedPVCoordinates> extractPVCoordinates(List<SPKSegment> segments) {
        List<TimeStampedPVCoordinates> pvList = new ArrayList<>();
        for (SPKSegment segment : segments) {
            Object data = segment.getSegmentData();
            if (!(data instanceof SPKSegmentDataType1)) {
                throw new IllegalArgumentException("Segment " + segment.getSegmentName()
                        + " does not contain SPK type 1 data: " + data);
            }
            SPKSegmentDataType1 segmentData = (SPKSegmentDataType1) data;
            List<SPKSegmentDataMDASingleRecord> records = segmentData.getDataRecords();
            System.out.println("Extracting " + records.size() + " reference positions from segment: " + segment.getSegmentName());
            for (SPKSegmentDataMDASingleRecord record : records) {
                Vector3D pos = record.getReferencePosition().scalarMultiply(KM_TO_M);
                Vector3D vel = record.getReferenceVelocity().scalarMultiply(KM_TO_M);
                AbsoluteDate date = record.getReferenceDate();
                pvList.add(new TimeStampedPVCoordinates(date, pos, vel));
            }
        }
        return pvList;
    }
}

class MoonPATransformProvider implements TransformProvider {

    /** PCK segment whose orientation data is evaluated for every requested date. */
    private final PCKSegment paSegment;

    /**
     * Builds a provider backed by the first segment of a lunar principal-axes PCK.
     *
     * @param moonPApck parsed PCK kernel; only its first segment is used
     * @throws IllegalArgumentException if the kernel contains no segment
     */
    public MoonPATransformProvider(PCK moonPApck) {
        // NOTE(review): only the first segment is used. If the kernel holds several segments covering
        // different time spans, dates outside the first segment's coverage are not handled here.
        if (moonPApck.getSegments().isEmpty()) {
            throw new IllegalArgumentException("PCK kernel contains no segments");
        }
        this.paSegment = moonPApck.getSegments().get(0);
    }

    /**
     * Returns the pure-rotation transform obtained by evaluating the PCK segment at {@code date}.
     *
     * @param date evaluation date
     * @return transform carrying the segment's rotation, rotation rate and rotation acceleration
     */
    @Override
    public Transform getTransform(AbsoluteDate date) {
        // Evaluate the segment at the requested date. The segment data is assumed to produce
        // TimeStampedAngularCoordinates; a ClassCastException otherwise.
        // NOTE(review): confirm the rotation direction of the evaluated angles matches Orekit's
        // parent-to-child Transform convention.
        TimeStampedAngularCoordinates evaluation =
            (TimeStampedAngularCoordinates) paSegment.getSegmentData().evaluate(date);
        // No translation: both frames share the same origin
        return new Transform(
            date,
            evaluation.getRotation(),
            evaluation.getRotationRate(),
            evaluation.getRotationAcceleration());
    }

    /**
     * Field version of {@link #getTransform(AbsoluteDate)}.
     *
     * <p>The orientation is evaluated at the real part of {@code date}, converted to the field, and then
     * shifted by the (zero-valued) field offset {@code date - realDate}. The shift lets any derivatives
     * carried by {@code date} flow through the rotation rate and acceleration instead of being dropped.
     *
     * @param date evaluation date
     * @param <T> field element type
     * @return field transform at {@code date}
     */
    @Override
    public <T extends CalculusFieldElement<T>> FieldTransform<T> getTransform(FieldAbsoluteDate<T> date) {
        final AbsoluteDate realDate = date.toAbsoluteDate();
        final Transform real = getTransform(realDate);
        final Field<T> field = date.getField();

        // Pure-rotation transform at the real date, lifted into the field
        final FieldTransform<T> atRealDate = new FieldTransform<>(
            new FieldAbsoluteDate<>(field, realDate),
            new FieldRotation<>(field, real.getRotation()),
            new FieldVector3D<>(field, real.getRotationRate()),
            new FieldVector3D<>(field, real.getRotationAcceleration()));

        // Re-attach the field offset (value ~0, but possibly non-zero derivatives)
        return atRealDate.shiftedBy(date.durationFrom(realDate));
    }
}
