import os
import jpype
import jpype.imports
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path

os.environ['JAVA_HOME'] = "/usr/lib/jvm/java-1.11.0-openjdk-amd64"

# Directory holding the Orekit / Hipparchus jars, taken from the OREKIT_PATH
# environment variable (current directory if unset).
# NOTE(review): adjust the fallback to wherever the jars actually live.
OREKIT_PATH = os.environ.get("OREKIT_PATH", ".")

_OREKIT_JARS = (
    "orekit-13.1.jar",
    "hipparchus-core-4.0.1.jar",
    "hipparchus-geometry-4.0.1.jar",
    "hipparchus-ode-4.0.1.jar",
    "hipparchus-fitting-4.0.1.jar",
    "hipparchus-optim-4.0.1.jar",
    "hipparchus-filtering-4.0.1.jar",
    "hipparchus-stat-4.0.1.jar",
)

# The class path has to be complete *before* the JVM is started, otherwise the
# `from org.orekit... import ...` statements below cannot find the classes.
for _jar in _OREKIT_JARS:
    jpype.addClassPath(str(Path(OREKIT_PATH, _jar)))

# NOTE(review): Orekit also needs its data (EOP, leap seconds, the eigen-6s.gfc
# gravity field) registered with its default data context; nothing in this file
# does so — presumably configured elsewhere, verify.
if not jpype.isJVMStarted():
    jpype.startJVM("-ea")

from org.orekit.frames import FramesFactory
from org.orekit.utils import IERSConventions, Constants, PVCoordinates
from org.orekit.models.earth import ReferenceEllipsoid
import matplotlib.pyplot as plt
from org.hipparchus.ode.nonstiff import DormandPrince853Integrator
from org.orekit.propagation.numerical import NumericalPropagator
from org.orekit.forces.gravity.potential import GravityFieldFactory, ICGEMFormatReader
from org.orekit.forces.gravity import HolmesFeatherstoneAttractionModel
from org.orekit.propagation import SpacecraftState
from org.hipparchus.geometry.euclidean.threed import Vector3D
from org.orekit.orbits import CartesianOrbit
from org.orekit.time import AbsoluteDate, TimeScalesFactory

from org.orekit.propagation.semianalytical.dsst.forces import DSSTTesseral, DSSTZonal, DSSTForceModel
from org.orekit.forces.gravity.potential import ICGEMFormatReader
from org.orekit.bodies import CelestialBodyFactory
from org.orekit.propagation.semianalytical.dsst import DSSTPropagator

class PropagateurDopri:
    """Numerical orbit propagator (Dormand-Prince 8(5,3)) with a truncated gravity field.

    Wraps an Orekit ``NumericalPropagator`` whose only force model is a
    Holmes-Featherstone attraction built from the ``eigen-6s.gfc`` ICGEM
    gravity field truncated to ``degree`` x ``order``, expressed in ITRF
    (IERS 2010 conventions).
    """

    # Set once the ICGEM reader has been added to Orekit's process-wide
    # GravityFieldFactory, so that building several propagators does not
    # register one duplicate reader per instance.
    _gravity_reader_registered = False

    def __init__(self, degree, order, msat=1000., step=timedelta(seconds=30), min_step=1e-3, max_step=300.):
        """Configure the integrator, the numerical propagator and the gravity model.

        :param degree: maximum degree of the spherical-harmonics gravity field.
        :param order: maximum order of the spherical-harmonics gravity field.
        :param msat: spacecraft mass given to the initial state.
        :param step: initial integrator step, as a ``timedelta``.
        :param min_step: minimum integrator step, in seconds.
        :param max_step: maximum integrator step, in seconds.
        """
        self.mu = Constants.WGS84_EARTH_MU
        self.msat = msat
        self.step = step
        self.order = order
        self.degree = degree

        # Fetched once and shared by the ellipsoid and the gravity model.
        itrf = FramesFactory.getITRF(IERSConventions.IERS_2010, True)

        # Not used anywhere in this file; kept as a public attribute.
        self.ell_ref = ReferenceEllipsoid(Constants.WGS84_EARTH_EQUATORIAL_RADIUS,
                                          Constants.WGS84_EARTH_FLATTENING,
                                          itrf,
                                          Constants.WGS84_EARTH_MU,
                                          Constants.WGS84_EARTH_ANGULAR_VELOCITY)

        # NOTE(review): scalar tolerances apply identically to every integrated
        # state component. The propagator's orbit type is not set here, so the
        # state may mix metres with dimensionless/angle elements, for which an
        # absolute tolerance of 0.1 is very loose — consider per-component
        # tolerances derived from the orbit. TODO confirm.
        err_abs = 0.1
        err_rel = 1e-8
        integrator = DormandPrince853Integrator(min_step, max_step, err_abs, err_rel)
        integrator.setInitialStepSize(self.step.total_seconds())
        self.propagator = NumericalPropagator(integrator)
        self.propagator.setMu(self.mu)

        # Register the gravity-field reader with the global factory only once.
        if not PropagateurDopri._gravity_reader_registered:
            GravityFieldFactory.addPotentialCoefficientsReader(ICGEMFormatReader("eigen-6s.gfc", True))
            PropagateurDopri._gravity_reader_registered = True
        grav_provider = GravityFieldFactory.getNormalizedProvider(self.degree, self.order)
        self.propagator.addForceModel(HolmesFeatherstoneAttractionModel(itrf, grav_provider))

    @staticmethod
    def _to_absolute_date(date):
        """Convert a naive ``datetime`` (read as UTC) into an Orekit ``AbsoluteDate``."""
        return AbsoluteDate(date.year, date.month, date.day, date.hour, date.minute,
                            date.second + date.microsecond / 1e6, TimeScalesFactory.getUTC())

    def propagate(self, orb, date_prop):
        """Propagate ``orb`` up to ``date_prop`` and return the generated ephemeris.

        :param orb: initial Orekit orbit, used (with ``self.msat``) as the initial state.
        :param date_prop: target ``datetime``, interpreted in UTC.
        :return: the ephemeris generated over the propagation span.
        """
        date_abs_prop = self._to_absolute_date(date_prop)

        self.propagator.setInitialState(SpacecraftState(orb, self.msat))

        # The generator is requested before propagate() so it records the steps.
        generator = self.propagator.getEphemerisGenerator()
        self.propagator.propagate(date_abs_prop)

        return generator.getGeneratedEphemeris()

# Set once osc_to_mean has added its ICGEM reader to Orekit's process-wide
# GravityFieldFactory; without it every call would register another reader.
_OSC_TO_MEAN_READER_REGISTERED = False


def osc_to_mean(osc_orb, pot_degree, pot_order, msat=1000.):
    """Convert an osculating orbit into DSST mean elements (gravity field only).

    The conversion is done by ``DSSTPropagator.computeMeanState`` with the
    zonal and tesseral terms of the ``eigen-6s.gfc`` field truncated to
    ``pot_degree`` x ``pot_order``.

    :param osc_orb: osculating Orekit orbit.
    :param pot_degree: maximum degree of the gravity field.
    :param pot_order: maximum order of the gravity field.
    :param msat: spacecraft mass given to the state; defaults to 1000., the
        value used by this file's ``__main__`` script.
    :return: the mean orbit produced by the DSST conversion.
    """
    global _OSC_TO_MEAN_READER_REGISTERED
    if not _OSC_TO_MEAN_READER_REGISTERED:
        GravityFieldFactory.addPotentialCoefficientsReader(ICGEMFormatReader("eigen-6s.gfc", True))
        _OSC_TO_MEAN_READER_REGISTERED = True

    # DSST force models are built from an unnormalized provider.
    grav_provider = GravityFieldFactory.getUnnormalizedProvider(pot_degree, pot_order)

    # Same Earth-fixed frame as the numerical propagator's gravity model
    # (ITRF, IERS 2010), so the tesseral terms see the same Earth rotation.
    earth_frame = FramesFactory.getITRF(IERSConventions.IERS_2010, True)

    force_models = jpype.JClass('java.util.ArrayList')()
    force_models.add(DSSTZonal(grav_provider))
    force_models.add(DSSTTesseral(earth_frame, Constants.WGS84_EARTH_ANGULAR_VELOCITY, grav_provider))

    state_osc = SpacecraftState(osc_orb, msat)
    # No attitude provider is passed; presumably fine since only gravity is modelled.
    state_mean = DSSTPropagator.computeMeanState(state_osc, None, force_models)
    return state_mean.getOrbit()



if __name__ == "__main__":

    j2000 = FramesFactory.getEME2000()
    tod = FramesFactory.getTOD(IERSConventions.IERS_2010, True)
    utc = TimeScalesFactory.getUTC()

    def to_abs_date(d):
        """Convert a naive ``datetime`` (read as UTC) into an Orekit ``AbsoluteDate``."""
        return AbsoluteDate(d.year, d.month, d.day, d.hour, d.minute,
                            d.second + d.microsecond / 1e6, utc)

    # Initial position/velocity given in EME2000, converted to a TOD Cartesian orbit.
    start = datetime(2025, 5, 21, 16, 2, 23)
    pv_start_j2000 = np.array([-1043795.889490802, 3405121.591840935, -6845132.308004311, -5195.150541053999,
                               -4712.509151136999, -1551.0498985929999])
    pv_start_j2000_coords = PVCoordinates(Vector3D(*pv_start_j2000[:3]), Vector3D(*pv_start_j2000[3:]))

    start_abs = to_abs_date(start)
    pv_start_tod_coords = j2000.getTransformTo(tod, start_abs).transformPVCoordinates(pv_start_j2000_coords)
    orb_start = CartesianOrbit(pv_start_tod_coords, tod, start_abs, Constants.WGS84_EARTH_MU)

    # One sample per minute over two hours (end excluded).
    dates = np.arange(start, start + timedelta(hours=2), timedelta(minutes=1)).tolist()

    degrees_orders = [(2, 0), (2, 2), (16, 16)]
    # Module-level on purpose: an osc_to_mean without an ``msat`` argument reads this global.
    msat = 1000.

    fig, axes = plt.subplots(2, 1)

    smas_osc_tot = []
    labels_tot = []
    for degree, order in degrees_orders:
        print(degree, order)
        prop_orekit = PropagateurDopri(degree, order, msat=msat)
        ephem = prop_orekit.propagate(orb_start, dates[-1])

        smas_mean = []
        smas_osc = []
        for d in dates:
            orb_osc = ephem.propagate(to_abs_date(d)).getOrbit()
            smas_mean.append(osc_to_mean(orb_osc, degree, order).getA())
            smas_osc.append(orb_osc.getA())

        label = f'{degree}x{order}'
        smas_osc_tot.append(smas_osc)
        labels_tot.append(label)
        axes[0].plot(dates, smas_mean, label=f'DSST {label}')

    # Osculating SMA difference between consecutive gravity-field truncations.
    smas_osc_tot = np.vstack(smas_osc_tot)
    rows = list(zip(labels_tot, smas_osc_tot))
    for (label_lo, sma_lo), (label_hi, sma_hi) in zip(rows, rows[1:]):
        axes[1].plot(dates, sma_hi - sma_lo, label=f'Osc. {label_hi} - Osc. {label_lo}')

    axes[0].set_xlabel("Date", fontsize=11, fontweight="bold")
    axes[0].set_ylabel("Mean Sma (m)", fontsize=11, fontweight="bold")
    axes[0].legend()

    axes[1].set_xlabel("Date", fontsize=11, fontweight="bold")
    axes[1].set_ylabel("Osc. Sma diff. (m)", fontsize=11, fontweight="bold")
    axes[1].legend()

    plt.show()








