import orekit
import math


# Start the JVM first; every org.* / java.* import below comes after this call.
orekit.initVM()

from org.hipparchus.geometry.euclidean.threed import Vector3D
from org.hipparchus.ode.nonstiff import DormandPrince853Integrator
from org.orekit.utils import PVCoordinates, Constants, LagrangianPoints, AbsolutePVCoordinates, PVCoordinatesProvider
from org.orekit.bodies import CelestialBodyFactory, CR3BPSystem, CR3BPFactory
from org.orekit.propagation import SpacecraftState
from org.orekit.propagation.numerical import NumericalPropagator
from org.orekit.frames import FramesFactory
from org.orekit.time import AbsoluteDate, TimeScalesFactory
from org.orekit.orbits import HaloOrbit, RichardsonExpansion, LibrationOrbitFamily, LibrationOrbitType, CR3BPDifferentialCorrection
from org.orekit.propagation.numerical.cr3bp import CR3BPForceModel
from org.orekit.data import DirectoryCrawler, DataContext
from org.hipparchus.ode.nonstiff import ClassicalRungeKuttaIntegrator
from java.io import File
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

# Directory handed to DirectoryCrawler as the Orekit data source.
# NOTE(review): machine-specific absolute path — adjust it for the local installation.
orekit_data_path = r'C:\Users\Samuele.Larese1\AppData\Local\anaconda3\pkgs\orekit-13.0.1-py312h275cf98_0\info\test\test\orekit-data'

# Register the data directory with the default data context.
JavaFile = File(orekit_data_path)
manager = DataContext.getDefault().getDataProvidersManager()
try:
    manager.addProvider(DirectoryCrawler(JavaFile))
except Exception as e:
    # Nothing below can work without the data, so stop here. SystemExit with a
    # message writes it to stderr and terminates with a non-zero status, so the
    # failure is visible to whatever launched the script.
    raise SystemExit(f"Error loading Orekit data: {e}") from e


# Epoch (UTC) used as the start date of the propagation in simulate_halo_orbit.
initial_date = AbsoluteDate(2025, 11, 8, 23, 22, 7.10353, TimeScalesFactory.getUTC())

def setup_propagator(initial_state, cr3bp_system):
    """Create a numerical propagator driven only by the CR3BP force model.

    Args:
        initial_state: SpacecraftState the propagator starts from.
        cr3bp_system: CR3BPSystem handed to ``CR3BPForceModel``.

    Returns:
        A ``NumericalPropagator`` backed by a ``DormandPrince853Integrator``
        with ``initial_state`` already set.
    """
    # Step-size bounds (minimum, maximum) given to the integrator.
    # NOTE(review): presumably expressed in normalized CR3BP time — confirm.
    step_bounds = (1E-10, 1E-3)

    # Absolute tolerances: three position components, three velocity
    # components, then mass. Relative tolerances are all zero, so only the
    # absolute values drive the step-size control.
    pos_tol = 1.0E-16
    vel_tol = 1.0E-16
    mass_tol = 1.0E-12
    abs_tol = [pos_tol] * 3 + [vel_tol] * 3 + [mass_tol]
    rel_tol = [0.0 for _ in abs_tol]

    cr3bp_propagator = NumericalPropagator(
        DormandPrince853Integrator(step_bounds[0], step_bounds[1], abs_tol, rel_tol)
    )

    # No orbit type: the state is propagated as absolute PV coordinates.
    cr3bp_propagator.setOrbitType(None)
    # The CR3BP force model is the only contribution; skip the central attraction.
    cr3bp_propagator.setIgnoreCentralAttraction(True)
    cr3bp_propagator.addForceModel(CR3BPForceModel(cr3bp_system))

    cr3bp_propagator.setInitialState(initial_state)
    return cr3bp_propagator

def simulate_halo_orbit(system_name, cr3bp_system, libration_point, azimuth_amplitude, n_samples=30):
    """Build, correct, propagate and plot a northern halo orbit of a CR3BP system.

    A first guess obtained from ``RichardsonExpansion`` is refined with
    ``applyDifferentialCorrection()``, propagated over one corrected period in
    the system's rotating frame, reported on stdout and finally shown in a 3D
    matplotlib figure (``plt.show()`` blocks until the window is closed).

    Relies on the module-level ``initial_date`` and ``setup_propagator``.

    Args:
        system_name: Label printed in the header line of the report.
        cr3bp_system: CR3BPSystem providing the rotating frame and the
            characteristic distance/time used for dimensionalization.
        libration_point: LagrangianPoints value the halo orbit is built around.
        azimuth_amplitude: Amplitude forwarded to ``HaloOrbit``.
            NOTE(review): presumably in meters — confirm against the Orekit docs.
        n_samples: Number of equal sub-intervals the period is split into for
            the plotted trajectory (``n_samples + 1`` points are plotted).

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    print(f"\n--- {system_name} ---")

    d_dim = cr3bp_system.getDdim()
    t_dim = cr3bp_system.getTdim()
    mu_ratio = cr3bp_system.getMassRatio()

    print(f"d_dim: {d_dim / 1e3:.3f} km")
    print(f"t_dim: {t_dim:.3f} s")
    print(f"mass_ratio: {mu_ratio:.6e}")
    print("-" * 40)

    rot_frame = cr3bp_system.getRotatingFrame()

    richardson = RichardsonExpansion(cr3bp_system, libration_point)
    halo_orbit_approx = HaloOrbit(richardson, azimuth_amplitude, LibrationOrbitFamily.NORTHERN)

    # Report the analytical first guess before it is overwritten by the correction.
    _print_halo_summary('--- Before differential correction ---',
                        halo_orbit_approx.getOrbitalPeriod(), t_dim,
                        halo_orbit_approx.getInitialPV())

    halo_orbit_approx.applyDifferentialCorrection()
    T_orbit_non_dim = halo_orbit_approx.getOrbitalPeriod()
    pv_halo_non_dim_corrected = halo_orbit_approx.getInitialPV()

    _print_halo_summary('--- After differential correction ---',
                        T_orbit_non_dim, t_dim, pv_halo_non_dim_corrected)

    initial_absolute_pv = AbsolutePVCoordinates(rot_frame, initial_date,
                                                pv_halo_non_dim_corrected.getPosition(),
                                                pv_halo_non_dim_corrected.getVelocity())
    initial_state = SpacecraftState(initial_absolute_pv)
    propagator = setup_propagator(initial_state, cr3bp_system)

    # The propagation runs in normalized units, so dates are shifted by
    # non-dimensional durations: one period is simply T_orbit_non_dim.
    final_date = initial_date.shiftedBy(T_orbit_non_dim)

    # Sample strictly forward from the initial state. Propagating to the final
    # date first and then asking for earlier dates would make the propagator
    # integrate backwards over the whole period before sampling.
    # Index-based dates avoid the rounding drift of repeatedly adding a step,
    # and the last sample is exactly the final date.
    states = [initial_state]
    for k in range(1, n_samples):
        sample_date = initial_date.shiftedBy(k * T_orbit_non_dim / n_samples)
        states.append(propagator.propagate(sample_date))
    final_state = propagator.propagate(final_date)
    states.append(final_state)

    positions_non_dim = [state.getPVCoordinates(rot_frame).getPosition() for state in states]
    trajectory_x_rot_km = [p.getX() * d_dim / 1e3 for p in positions_non_dim]
    trajectory_y_rot_km = [p.getY() * d_dim / 1e3 for p in positions_non_dim]
    trajectory_z_rot_km = [p.getZ() * d_dim / 1e3 for p in positions_non_dim]

    print("-" * 40)

    pv_final_non_dim = final_state.getPVCoordinates(rot_frame)

    # The second argument of getRealAPV is the date at which the propagation
    # started, for the initial and for the final state alike.
    real_pv_final = cr3bp_system.getRealAPV(final_state.getAbsPVA(), initial_date, rot_frame).getPVCoordinates()
    real_pv_initial = cr3bp_system.getRealAPV(initial_absolute_pv, initial_date, rot_frame).getPVCoordinates()

    print('Retrieving the real (dimensional) initial and final position and velocity with getRealAPV...')
    print('Final Real Position: ', real_pv_final.getPosition())
    print('Final Real Velocity: ', real_pv_final.getVelocity())

    print('Initial Real Position: ', real_pv_initial.getPosition())
    print('Initial Real Velocity: ', real_pv_initial.getVelocity())

    print('Real Delta Position: ', real_pv_final.getPosition().subtract(real_pv_initial.getPosition()).getNorm())

    print("-" * 40)

    print('Retrieving the real (dimensional) initial and final position and velocity with scalarMultiply...')
    # Characteristic velocity: characteristic distance over (t_dim / 2*pi).
    v_scale = d_dim / (t_dim / (2*math.pi))

    pos_initial_real = pv_halo_non_dim_corrected.getPosition().scalarMultiply(d_dim)
    vel_initial_real = pv_halo_non_dim_corrected.getVelocity().scalarMultiply(v_scale)
    print(f"Initial position (rotating frame): X={pos_initial_real.getX():.3f} m, Y={pos_initial_real.getY():.3f} m, Z={pos_initial_real.getZ():.3f} m")
    print(f"Initial velocity (rotating frame): X={vel_initial_real.getX():.3f} m/s, Y={vel_initial_real.getY():.3f} m/s, Z={vel_initial_real.getZ():.3f} m/s")

    pos_final_real = pv_final_non_dim.getPosition().scalarMultiply(d_dim)
    vel_final_real = pv_final_non_dim.getVelocity().scalarMultiply(v_scale)
    print(f"Final position (rotating frame): X={pos_final_real.getX():.3f} m, Y={pos_final_real.getY():.3f} m, Z={pos_final_real.getZ():.3f} m")
    print(f"Final velocity (rotating frame): X={vel_final_real.getX():.3f} m/s, Y={vel_final_real.getY():.3f} m/s, Z={vel_final_real.getZ():.3f} m/s")

    # Closure error after one period, in dimensional units.
    delta_pos_real = pos_final_real.subtract(pos_initial_real).getNorm()
    delta_vel_real = vel_final_real.subtract(vel_initial_real).getNorm()
    print(f"Δ Position (dimensional): {delta_pos_real:.3f} m")
    print(f"Δ Velocity (dimensional): {delta_vel_real:.6f} m/s")
    print("-" * 40)

    _plot_halo_trajectory(cr3bp_system, libration_point,
                          trajectory_x_rot_km, trajectory_y_rot_km, trajectory_z_rot_km)


def _print_halo_summary(title, period_non_dim, t_dim, pv_non_dim):
    """Print the period (normalized and dimensional) and initial PV of a halo orbit."""
    # Same conversion as used throughout the script: t_dim spans 2*pi normalized time units.
    period_real = period_non_dim * t_dim / (2 * math.pi)

    print(title)
    print(f"Non-dimensional period: {period_non_dim:.10f} units")
    print(f"Orbital period (dimensional): {period_real:.2f} s ({period_real / Constants.JULIAN_DAY:.6f} days)")
    print(f"Initial position (non-dimensional): {pv_non_dim.getPosition()}")
    print(f"Initial velocity (non-dimensional): {pv_non_dim.getVelocity()}")
    print("-" * 40)


def _plot_halo_trajectory(cr3bp_system, libration_point, x_km, y_km, z_km):
    """Show the sampled trajectory, both primaries and the libration point in 3D (km)."""
    d_dim = cr3bp_system.getDdim()
    mu_ratio = cr3bp_system.getMassRatio()

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Scatter plot of the trajectory
    ax.scatter(x_km, y_km, z_km, c='blue', marker='o', s=5, label='Halo Trajectory')

    # Primaries sit on the x axis of the rotating frame at -mu and 1 - mu.
    primary_pos_km = [(-mu_ratio * d_dim) / 1e3, 0, 0]
    secondary_pos_km = [((1 - mu_ratio) * d_dim) / 1e3, 0, 0]

    # Mark the libration point actually requested rather than a fixed one.
    lp_pos_nd = cr3bp_system.getLPosition(libration_point)
    lp_pos_km = [lp_pos_nd.getX() * d_dim / 1e3,
                 lp_pos_nd.getY() * d_dim / 1e3,
                 lp_pos_nd.getZ() * d_dim / 1e3]

    ax.scatter(primary_pos_km[0], primary_pos_km[1], primary_pos_km[2],
               color='yellow', marker='o', s=100, label=cr3bp_system.getPrimary().getName())
    ax.scatter(secondary_pos_km[0], secondary_pos_km[1], secondary_pos_km[2],
               color='blue', marker='o', s=50, label=cr3bp_system.getSecondary().getName())
    ax.scatter(lp_pos_km[0], lp_pos_km[1], lp_pos_km[2],
               color='red', marker='x', s=100, label=f'{libration_point} Point')

    ax.set_xlabel('X (km)')
    ax.set_ylabel('Y (km)')
    ax.set_zlabel('Z (km)')
    ax.set_title(f'{libration_point} Halo Orbit in Rotating Frame')
    ax.legend()
    plt.show()



# Amplitude handed to the halo-orbit builder for the Sun-Earth case
# (presumably 120 000 km written out in metres — confirm).
azimuth_amplitude_se = 120_000 * 1e3

# Simulate Sun-Earth L1 Orbit
cr3bp_system = CR3BPFactory.getSunEarthCR3BP(initial_date, TimeScalesFactory.getUTC())
simulate_halo_orbit(
    system_name="Sun-Earth L1 Orbit",
    cr3bp_system=cr3bp_system,
    libration_point=LagrangianPoints.L1,
    azimuth_amplitude=azimuth_amplitude_se,
)




