import orekit;
orekit.initVM();

from orekit.pyhelpers import setup_orekit_curdir;
setup_orekit_curdir();

from org.hipparchus.geometry.euclidean.threed import Vector3D;
from org.hipparchus.ode.nonstiff import DormandPrince853Integrator;

from org.orekit.attitudes import BodyCenterPointing;
from org.orekit.bodies import CelestialBodyFactory, OneAxisEllipsoid;
from org.orekit.forces.drag import DragForce, IsotropicDrag;
from org.orekit.forces.gravity import HolmesFeatherstoneAttractionModel, OceanTides, SolidTides, ThirdBodyAttraction;
from org.orekit.forces.gravity.potential import GravityFieldFactory;
from org.orekit.forces.maneuvers import ConstantThrustManeuver;
from org.orekit.forces.radiation import SolarRadiationPressure, IsotropicRadiationSingleCoefficient;
from org.orekit.frames import FramesFactory;
from org.orekit.models.earth.atmosphere import HarrisPriester;
from org.orekit.orbits import CircularOrbit, PositionAngle;
from org.orekit.propagation import PropagationType, SpacecraftState;
from org.orekit.propagation.numerical import NumericalPropagator;
from org.orekit.propagation.semianalytical.dsst import DSSTPropagator;
from org.orekit.propagation.semianalytical.dsst.forces import DSSTAtmosphericDrag, DSSTSolarRadiationPressure, DSSTTesseral, DSSTThirdBody, DSSTZonal;
from org.orekit.time import AbsoluteDate, TimeScalesFactory;
from org.orekit.utils import Constants, IERSConventions;

from MyDSSTManuever import DSSTConstantThrustManeuver;

from datetime import datetime;
from java.util import Arrays;
from math import acos, pi, sqrt;
from numpy import arange;
import matplotlib.pyplot as plt;

# %% CLOSE ALL FIGURES
# =============================================================================
# =============================================================================
# plt.close("all");
"""
# =============================================================================
# NOTES
# =============================================================================
"""

# %% INITIALIZATION
# =============================================================================
# =============================================================================
# Reference epoch of the whole scenario, parsed as a UTC timestamp.
initDate = AbsoluteDate("2024-06-08T12:00:00.0Z", TimeScalesFactory.getUTC());
day = 86400.; # seconds in one day (not referenced elsewhere in this script)

saveToFile = False; # NOTE(review): not referenced elsewhere in this script

# Data: Earth constants from Orekit's Constants class.
# J2 is the negated (unnormalized) EGM96 C20 zonal coefficient.
J2 = -Constants.EGM96_EARTH_C20;
Re, mu = Constants.WGS84_EARTH_EQUATORIAL_RADIUS, Constants.WGS84_EARTH_MU;
# NOTE(review): posAngle is not referenced elsewhere; the initial orbit is built
# with PositionAngle.TRUE instead - confirm which angle type is intended.
posAngle = PositionAngle.MEAN;

# Spacecraft properties.
# NOTE(review): area, CD and CR are not used by any force model in this script
# (no drag or radiation-pressure model is added to either propagator).
mass = 100.; # [kg]
area = 1.; # [m^2] smallest cross-section
CD, CR = 2.2, 1.2;

def computeSSOinclination(a, e=0., *, j2=None, radius=None, gm=None, year=None):
    """Return the Sun-synchronous inclination [rad] for semi-major axis ``a`` [m].

    The secular J2 nodal regression
        dRAAN/dt = -3/2 * n * J2 * (R / p)^2 * cos(i),
        p = a * (1 - e^2),  n = sqrt(mu / a^3)
    is set equal to one revolution per year (2*pi / year), which gives
        cos(i) = -4*pi * a^(7/2) * (1 - e^2)^2 / (3 * year * J2 * R^2 * sqrt(mu)).

    Parameters
    ----------
    a : float
        Semi-major axis [m].
    e : float, optional
        Eccentricity. The default of 0 gives the circular-orbit result.
    j2, radius, gm, year : float, optional
        Keyword-only overrides for J2, the equatorial radius [m], the
        gravitational parameter [m^3/s^2] and the year length [s]. When
        omitted, the module-level ``J2``, ``Re`` and ``mu`` and
        ``Constants.JULIAN_YEAR`` are used.

    Returns
    -------
    float
        Inclination in radians, in [0, pi].

    Raises
    ------
    ValueError
        If no Sun-synchronous inclination exists (|cos i| > 1), e.g. for an
        orbit that is too high for J2 to precess the node fast enough.
    """
    j2 = J2 if j2 is None else j2;
    radius = Re if radius is None else radius;
    gm = mu if gm is None else gm;
    # NOTE: the Julian year only approximates the Sun's apparent period
    # (tropical year); the difference is negligible at this level of modelling.
    year = Constants.JULIAN_YEAR if year is None else year;

    # The exponent is written as 3.5 explicitly: 7/2 is integer division under Python 2.
    cosI = -4*pi/year*a**3.5*(1 - e**2)**2/(3*j2*radius**2*sqrt(gm));
    if abs(cosI) > 1:
        raise ValueError("no Sun-synchronous inclination for a = %g m, e = %g (cos i = %g)"
                         % (a, e, cosI));
    return acos(cosI);

# Earth shape: WGS84 ellipsoid attached to ITRF (IERS 2010 conventions).
earth = OneAxisEllipsoid(
    Re,
    Constants.WGS84_EARTH_FLATTENING,
    FramesFactory.getITRF(IERSConventions.IERS_2010, True));
# Attitude law shared by both propagators (BodyCenterPointing on `earth`).
earthCenterAttitudeLaw = BodyCenterPointing(FramesFactory.getEME2000(), earth);

# Initial orbit (used as the MEAN state for DSST): 500 km above the equatorial
# radius, small eccentricity vector, Sun-synchronous inclination.
a0 = Re + 500e3;
ex0 = ey0 = 1e-3;
i0 = computeSSOinclination(a0);
raan0, u0 = 0., 1.;

# NOTE(review): u0 is interpreted as a TRUE argument of latitude here, while the
# otherwise unused posAngle is MEAN - confirm which one is intended for a mean state.
initOrbitMean = CircularOrbit(
    a0, ex0, ey0, i0, raan0, u0, PositionAngle.TRUE,
    FramesFactory.getEME2000(), initDate, mu);
initStateMean = SpacecraftState(initOrbitMean, mass);

# %% DYNAMICAL MODEL
# =============================================================================
# =============================================================================
# Degree and order of the constant geopotential, shared by both propagators.
n, m = 4, 4;

# Numerical propagator: Holmes-Featherstone model fed with NORMALIZED coefficients.
nonSphericalGravity = HolmesFeatherstoneAttractionModel(
    FramesFactory.getITRF(IERSConventions.IERS_2010, True),
    GravityFieldFactory.getConstantNormalizedProvider(n, m));
forces = [nonSphericalGravity];

# DSST counterparts (zonal + tesseral), both fed with UNNORMALIZED coefficients.
zonalDsst = DSSTZonal(GravityFieldFactory.getConstantUnnormalizedProvider(n, m));
tesseralDsst = DSSTTesseral(
    FramesFactory.getITRF(IERSConventions.IERS_2010, True),
    Constants.WGS84_EARTH_ANGULAR_VELOCITY,
    GravityFieldFactory.getConstantUnnormalizedProvider(n, m));
forcesDsst = [zonalDsst, tesseralDsst];

# %% PROPAGATORS
# =============================================================================
# =============================================================================
# Step-size bounds [s] and scalar tolerances for the Dormand-Prince 8(5,3) integrator.
minStep, maxStep = 1e-2, 1e+4;
absTol = relTol = 1e-9;

# NOTE(review): this single integrator object is later handed to both the
# numerical and the DSST propagator; separate instances would keep them independent.
integrator = DormandPrince853Integrator(minStep, maxStep, absTol, relTol);
integrator.setInitialStepSize(60.0);

# Propagation window of two Julian days, sampled every 60 s (arange excludes the end point).
tSpan = Constants.JULIAN_DAY*2;
finalDate = initDate.shiftedBy(tSpan);
tVec = arange(0, tSpan, 60.);

# Osculating initial state for the numerical run: the mean state plus the
# short-periodic terms of the DSST models in forcesDsst (zonal and tesseral
# only - the maneuver is not part of that list).
initStateOsc = DSSTPropagator.computeOsculatingState(initStateMean, earthCenterAttitudeLaw, Arrays.asList(forcesDsst));
# Constant-thrust burn starting 1e4 s after initDate. Arguments follow Orekit's
# ConstantThrustManeuver(date, duration [s], thrust [N], isp [s], direction).
maneuver1 = ConstantThrustManeuver(initDate.shiftedBy(1e4), 600., 10., 1e3, Vector3D.PLUS_I);
# maneuver2 = ConstantThrustManeuver(initDate.shiftedBy(1e4 + initStateMean.getKeplerianPeriod()*3/2), 600., 10., 1e3, Vector3D.PLUS_I);

propagator = NumericalPropagator(integrator);
propagator.setAttitudeProvider(earthCenterAttitudeLaw);
for force in forces:
    propagator.addForceModel(force);
propagator.addForceModel(maneuver1);
# propagator.addForceModel(maneuver2);

# Pre-runs: record the true longitude argument (getLv) at the start and end of
# the burn. The values come from the states returned by propagate(), so they
# do not depend on whether the propagator resets its initial state at the end
# of each run.
propagator.setInitialState(initStateOsc);
lMin1 = propagator.propagate(maneuver1.getStartDate()).getLv();
lMax1 = propagator.propagate(maneuver1.getEndDate()).getLv();
# lMin2 = propagator.propagate(maneuver2.getStartDate()).getLv();
# lMax2 = propagator.propagate(maneuver2.getEndDate()).getLv();

# Main run over the full span, restarted from the osculating initial state.
# The ephemeris generator is registered only now, so it unambiguously records
# this run rather than the pre-runs above.
propagator.resetInitialState(initStateOsc);
ephGen = propagator.getEphemerisGenerator();
startTime = datetime.now();
propagator.propagate(finalDate);
print("Runtime: ", datetime.now() - startTime);

# Sample the numerical (osculating) ephemeris at every instant of tVec and
# record semi-major axis [m], eccentricity and inclination [rad].
aOsc, eOsc, iOsc = [], [], [];
elOsc = [aOsc, eOsc, iOsc];
# Loop-invariant lookups, hoisted out of the sampling loop.
ephemerisOsc = ephGen.getGeneratedEphemeris();
sampleFrame = initStateOsc.getFrame();
sampleMu = initStateOsc.getMu();
for t in tVec:
    pv = ephemerisOsc.getPVCoordinates(initDate.shiftedBy(float(t)), sampleFrame);
    # a, e and i are read directly from the orbit; wrapping it in a
    # SpacecraftState would return the same values.
    orbit = CircularOrbit(pv, sampleFrame, sampleMu);
    aOsc.append(orbit.getA());
    eOsc.append(orbit.getE());
    iOsc.append(orbit.getI());

# %%

# TODO: figure out what is going on with the longitude limits (lMin1/lMax1) and
# why they cannot capture the evolution of the mean elements.
# NOTE(review): lMin1/lMax1 come from getLv() (true longitude argument) of the
# osculating numerical states, whereas this DSST run propagates mean elements;
# check which longitude DSSTConstantThrustManeuver (MyDSSTManuever, not shown
# here) compares them against.
# maneuverDsst1 = DSSTConstantThrustManeuver(maneuver1, lMin1, lMin1 + 2*pi);
maneuverDsst1 = DSSTConstantThrustManeuver(maneuver1, lMin1, lMax1);
# maneuverDsst2 = DSSTConstantThrustManeuver(maneuver2, lMin2, lMin2 + 2*pi);
# maneuverDsst2 = DSSTConstantThrustManeuver(maneuver2, lMin2, lMax2);

# Semi-analytical run in MEAN elements: the same integrator object, the DSST
# gravity models, then the DSST version of the maneuver.
propagatorDsst = DSSTPropagator(integrator, PropagationType.MEAN);
propagatorDsst.setAttitudeProvider(earthCenterAttitudeLaw);
for force in forcesDsst:
    propagatorDsst.addForceModel(force);
propagatorDsst.addForceModel(maneuverDsst1);
# propagatorDsst.addForceModel(maneuverDsst2);
ephGenDsst = propagatorDsst.getEphemerisGenerator();
propagatorDsst.setInitialState(initStateMean, PropagationType.MEAN);
# finalDate is initDate shifted by tSpan, i.e. the same end epoch as the numerical run.
startTime = datetime.now();
propagatorDsst.propagate(finalDate);
print("Runtime: ", datetime.now() - startTime);

# Sample the DSST ephemeris (run with PropagationType.MEAN) at every instant of
# tVec and record semi-major axis [m], eccentricity and inclination [rad].
aMean, eMean, iMean = [], [], [];
elMean = [aMean, eMean, iMean];
# Loop-invariant lookups, hoisted out of the sampling loop; the output frame
# and mu are taken from initStateOsc so both histories use the same ones.
ephemerisMean = ephGenDsst.getGeneratedEphemeris();
sampleFrame = initStateOsc.getFrame();
sampleMu = initStateOsc.getMu();
for t in tVec:
    pv = ephemerisMean.getPVCoordinates(initDate.shiftedBy(float(t)), sampleFrame);
    # a, e and i are read directly from the orbit; wrapping it in a
    # SpacecraftState would return the same values.
    orbit = CircularOrbit(pv, sampleFrame, sampleMu);
    aMean.append(orbit.getA());
    eMean.append(orbit.getE());
    iMean.append(orbit.getI());

# %% PLOT
# =============================================================================
# =============================================================================
# One panel per element, comparing the numerical (osculating) history with the
# DSST (mean) history over tVec.
# NOTE: the labels assume elOsc/elMean hold [a, e, i] in that order.
elLabels = ["a [m]", "e [-]", "i [rad]"];
fig, axs = plt.subplots(len(elOsc), 1, sharex = True);
for ax, oscSeries, meanSeries, yLabel in zip(axs, elOsc, elMean, elLabels):
    ax.plot(tVec, oscSeries, label = "Numerical (osculating)");
    ax.plot(tVec, meanSeries, label = "DSST (mean)");
    ax.set_ylabel(yLabel);
    ax.grid();
axs[0].legend();
axs[-1].set_xlabel("t [s]");
# Needed to display the figure when run as a plain script; harmless in interactive sessions.
plt.show();
