# %%
# Orbit determination from simulated measurements

'''
This program estimates the orbit of a satellite from measurements taken at a ground station.
Range, azimuth and elevation measurements (with their configured noise sigmas) are read from a CSV file,
then processed with a batch least-squares estimator followed by a Kalman filter.
'''



import orekit
import pandas as pd
from org.orekit.orbits import CartesianOrbit

# Initialize the Orekit VM and report the Java / Orekit versions in use.
print('================= Orekit Initialization ================')
vm = orekit.initVM()
print ('Java version:',vm.java_version)
print ('Orekit version:', orekit.VERSION)
print('========================================================')

# Set up Orekit from the data directory configured below.
from orekit.pyhelpers import setup_orekit_curdir
# Orekit data directory path.
# NOTE(review): this is a machine-specific absolute Windows path; it has to be
# edited before the script can run on another machine.
orekit_dir = "C:\\Users\\radhakrishna\\Documents\\orekit-files\\orekit-data-main"
setup_orekit_curdir(orekit_dir)


# Three-line element set of the tracked satellite. Line 0 is the name line;
# only lines 1 and 2 are passed to the TLE constructor below.
tle='''0 SENTINEL-1C                
1 62261U 24235A   25218.58972421 -.00001515  00000-0 -31181-3 0  9990
2 62261  98.1819 225.4143 0001316  85.5830 274.5520 14.59197729 35531'''

from org.orekit.propagation.analytical.tle import TLE
from org.orekit.propagation.analytical.tle import TLEPropagator
from org.orekit.time import AbsoluteDate, TimeScalesFactory
from org.orekit.frames import FramesFactory, TopocentricFrame


# NOTE: `tle` is rebound here from the raw text to the parsed TLE object.
tle=TLE(tle.splitlines()[1], tle.splitlines()[2])
propagator = TLEPropagator.selectExtrapolator(tle)
epoch=tle.getDate()
# Propagating to the TLE's own date gives the satellite state at epoch.
satelliteState=propagator.propagate(epoch)
# The orbit carries its own frame, so the frame is queried from the orbit rather
# than assumed (author's note: TLE computations are done in the TEME frame).
initial_orbit=satelliteState.getOrbit()
initial_frame = initial_orbit.getFrame()

# Re-express the epoch position/velocity in GCRF ...
gcrf = FramesFactory.getGCRF()
transform = initial_frame.getTransformTo(gcrf, epoch)

pvIngcrf = transform.transformPVCoordinates(initial_orbit.getPVCoordinates())

# ... and rebuild `initial_orbit` as a CartesianOrbit in GCRF, keeping the same mu.
mu=initial_orbit.getMu()
initial_orbit = CartesianOrbit(pvIngcrf, gcrf, epoch, mu)
print('Initial Orbit at epoch:', initial_orbit)

# Ground station
# Builds an Earth ellipsoid in ITRF and a reference ground station on it.
# `earth` is reused later when per-measurement stations are created; `station_frame`
# and `ground_station` are only printed here and not referenced again in the visible code.

from org.orekit.frames import FramesFactory
from org.orekit.utils import IERSConventions
from org.orekit.estimation.measurements import GroundStation
from org.orekit.bodies import OneAxisEllipsoid, GeodeticPoint
from org.orekit.frames import TopocentricFrame
from math import radians,degrees
# Define ITRF and Earth model (equatorial radius in meters and flattening given as literals)
itrf = FramesFactory.getITRF(IERSConventions.IERS_2010, True)
earth = OneAxisEllipsoid(6378137.0, 1.0 / 298.257223563, itrf)

# Define a geodetic point for the ground station
latitude_deg = 17.92 # degrees (converted to radians when the GeodeticPoint is built)
longitude_deg = 78.75 # degrees (converted to radians when the GeodeticPoint is built)
altitude = 500.0 # meters
point = GeodeticPoint(radians(latitude_deg), radians(longitude_deg), altitude)

# Create a TopocentricFrame for the ground station
station_frame = TopocentricFrame(earth, point, "LRDE_KOLAR")

# Create the GroundStation object
ground_station = GroundStation(station_frame)
print(ground_station)






# %%
def read_measurements(csv_file):
    '''
    Load measurements from a CSV file into a list of dictionaries.

    Expected CSV columns:
    Range,Azimuth_deg,Elevation_deg,Time,Lat,Lon,Alt,Az_sigma_deg,El_sigma_deg,Range_sigma

    Azimuth_deg, Elevation_deg, Lat, Lon, Az_sigma_deg and El_sigma_deg are
    converted with radians(); Range, Time, Alt and Range_sigma are copied as read.

    Returns one dict per CSV row, in file order, with the keys
    range, azimuth_rad, elevation_rad, time, lat, lon, alt,
    az_sigma_rad, el_sigma_rad, range_sigma.
    '''

    def _as_radians(record, column):
        # Force a float first so the conversion does not depend on the parsed cell type.
        return radians(float(record[column]))

    table = pd.read_csv(csv_file)
    return [
        {
            'range': record['Range'],
            'azimuth_rad': _as_radians(record, 'Azimuth_deg'),
            'elevation_rad': _as_radians(record, 'Elevation_deg'),
            'time': record['Time'],
            'lat': _as_radians(record, 'Lat'),
            'lon': _as_radians(record, 'Lon'),
            'alt': record['Alt'],
            'az_sigma_rad': _as_radians(record, 'Az_sigma_deg'),
            'el_sigma_rad': _as_radians(record, 'El_sigma_deg'),
            'range_sigma': record['Range_sigma'],
        }
        for _, record in table.iterrows()
    ]

# Candidate input files. Each assignment overrides the previous one, so only the
# last path is actually read; reorder or comment out lines to switch data sets.
csv_path='simulated_measurements/generated_OREKIT_RADAR_angsig_1deg_rangesig_1000.0m_TIME_20250925_170755.csv'
csv_path='simulated_measurements/generated_OREKIT_RADAR_angsig_1deg_rangesig_1000.0m_TIME_20250925_192958.csv'
csv_path='simulated_measurements/generated_OREKIT_RADAR_angsig_1deg_rangesig_1000.0m_TIME_20250925_220145.csv'
measurements=read_measurements(csv_path)
# Show the first parsed measurement (raises IndexError if the CSV has no data rows).
print(measurements[0])

# %%

'''
sample measurements :
{'range': 2781496.771666476, 'azimuth_rad': -2.0306491189391744, 'elevation_rad': 0.033666378601059886, 'time': '2025-08-06T14:10:00.000000Z', 'lat': 0.22926994120047914, 'lon': 1.3636100366199015, 'alt': 500.0, 'az_sigma_rad': 0.017453292519943295, 'el_sigma_rad': 0.017453292519943295, 'range_sigma': 1000.0}
'''
from orekit.pyhelpers import absolutedate_to_datetime,datetime_to_absolutedate
from org.orekit.estimation.measurements import AngularAzEl,AngularRaDec , Range , MultiplexedMeasurement
from org.orekit.estimation.measurements import  ObservedMeasurement , ObservableSatellite
from java.util import ArrayList

# Convert every parsed CSV row into one Orekit MultiplexedMeasurement that
# bundles an AngularAzEl and a Range measurement taken at the same date.
# All measurements refer to the same observable satellite (index 0).
satellite = ObservableSatellite(0)
true_measurements=[]
for m in measurements[:]:
    print(m['time'])
    py_time_str=m['time']
    py_time=pd.to_datetime(py_time_str)
    az_rad=m['azimuth_rad']
    el_rad=m['elevation_rad']
    range_val=m['range']
    lat_rad=m['lat']
    lon_rad=m['lon']
    alt=m['alt']
    std_az=m['az_sigma_rad']
    std_el=m['el_sigma_rad']
    std_range=m['range_sigma']


    # A new station is built for every row from that row's own lat/lon/alt;
    # the frame name embeds the latitude and longitude in radians.
    gs=GroundStation(TopocentricFrame(earth,GeodeticPoint(lat_rad,lon_rad,alt),'GS_'+str(lat_rad)+'_'+str(lon_rad)))

    measurement_angle=AngularAzEl(
        gs,
        datetime_to_absolutedate(py_time),
        [az_rad, el_rad],
        [std_az, std_el], #sigma
        [1.0,1.0],  # weights: constant 1.0 for both angles (not derived from the sigmas)
        satellite
    )


    measurement_range=Range(
        gs,
        False,  # NOTE(review): per the Orekit Range constructor this is the two-way flag - confirm one-way is intended
        datetime_to_absolutedate(py_time),
        range_val,
        std_range, # sigma
        1.0,  # weight: constant 1.0 (not derived from the sigma)
        satellite
    )

    # Alternative kept by the author: feed the two measurements separately
    # instead of multiplexing them.
    # true_measurements.append(measurement_angle)
    # true_measurements.append(measurement_range)

    # Order matters downstream: angles first (azimuth, elevation), then range,
    # so each multiplexed measurement has three components in that order.
    mx_list=ArrayList(2)
    mx_list.add(measurement_angle)
    mx_list.add(measurement_range)
    measurement_mx=MultiplexedMeasurement(mx_list)
    true_measurements.append(measurement_mx)

print(true_measurements[0])
print(f"Number of measurements: {len(true_measurements)}")

# %%
from orekit import JArray_double
from orekit.pyhelpers import JArray_double2D
from org.orekit.propagation.numerical import NumericalPropagator
from org.hipparchus.ode.nonstiff import DormandPrince853Integrator
from org.orekit.orbits import OrbitType
from org.orekit.propagation import SpacecraftState
from org.orekit.utils import Constants
from org.orekit.forces.gravity import HolmesFeatherstoneAttractionModel
from org.orekit.forces.gravity.potential import GravityFieldFactory
from org.orekit.forces.radiation import SolarRadiationPressure, IsotropicRadiationSingleCoefficient
from org.orekit.forces.drag import DragForce, IsotropicDrag
from org.orekit.models.earth.atmosphere import DTM2000
from org.orekit.models.earth.atmosphere.data import CssiSpaceWeatherData
from org.orekit.forces.gravity import ThirdBodyAttraction
from org.orekit.bodies import CelestialBodyFactory
from org.orekit.propagation.conversion import NumericalPropagatorBuilder, DormandPrince853IntegratorBuilder
from org.orekit.orbits import CartesianOrbit, PositionAngleType, OrbitType
from org.hipparchus.optim.nonlinear.vector.leastsquares import LevenbergMarquardtOptimizer
from org.orekit.estimation.leastsquares import BatchLSEstimator, PythonBatchLSObserver



def GetNumericalPropagatorBuilder(inital_orbit):
    '''
    Build a NumericalPropagatorBuilder seeded with the given orbit.

    The builder uses a Dormand-Prince 8(5,3) integrator builder, a MEAN position
    angle type and a 120x120 Holmes-Featherstone gravity field in ITRF.
    Third-body, solar radiation pressure and drag models are present only as
    commented-out options. Mass and attitude provider are not set here.

    Parameters:
        inital_orbit: reference orbit handed to the builder. The misspelled
            parameter name is kept so existing keyword callers keep working.

    Returns:
        The configured NumericalPropagatorBuilder.
    '''

    # Integrator builder settings, passed explicitly by key so the call does not
    # depend on dictionary ordering.
    PROP_PARAMS = {'min_step': 0.001, 'max_step': 1000.0, 'pos_error': 0.01}
    POS_SCALE = 10.0

    # Create the propagator builder from the orbit passed in.
    # (Previously the body read the module-level `initial_orbit`, so the argument
    # was silently ignored.)
    prop_builder = NumericalPropagatorBuilder(
        inital_orbit,
        DormandPrince853IntegratorBuilder(
            PROP_PARAMS['min_step'],
            PROP_PARAMS['max_step'],
            PROP_PARAMS['pos_error'],
        ),
        PositionAngleType.MEAN,
        POS_SCALE
    )

    # Define force models
    itrf = FramesFactory.getITRF(IERSConventions.IERS_2010, True)
    # `earth` is only needed by the optional SRP / drag models listed below.
    earth = OneAxisEllipsoid(
        Constants.WGS84_EARTH_EQUATORIAL_RADIUS, Constants.WGS84_EARTH_FLATTENING, itrf
    )
    gravity_provider = GravityFieldFactory.getNormalizedProvider(120, 120)
    force_models = [
        HolmesFeatherstoneAttractionModel(itrf, gravity_provider),
        # ThirdBodyAttraction(CelestialBodyFactory.getMoon()),
        # ThirdBodyAttraction(CelestialBodyFactory.getSun()),
        # SolarRadiationPressure(
        #     CelestialBodyFactory.getSun(), earth, IsotropicRadiationSingleCoefficient(10.0, 1.2)
        # ),
        # DragForce(
        #     DTM2000(CssiSpaceWeatherData("SpaceWeather-All-v1.2.txt"),
        #             CelestialBodyFactory.getSun(), earth),
        #     IsotropicDrag(10.0, 1.2)
        # )
    ]
    for force_model in force_models:
        prop_builder.addForceModel(force_model)

    # Bound the global drag / radiation factors to [1.0, 1000.0] when such force
    # models are enabled. The drivers are not selected for estimation here.
    bounded_driver_names = ("global radiation factor", "global drag factor")
    for force_model in force_models:
        if not isinstance(force_model, (DragForce, SolarRadiationPressure)):
            continue
        for driver in force_model.getParametersDrivers():
            if driver.getName() in bounded_driver_names:
                driver.setMaxValue(1000.0)
                driver.setMinValue(1.0)

    return prop_builder



# %%

# Observer
class ODObserver(PythonBatchLSObserver):
    '''
    Batch least-squares observer that logs every evaluation.

    Attributes:
        residuals_data: list of dicts, one per group of three residual entries and
            per evaluation, with keys iteration, measurement, az_residual,
            el_residual and range_residual. It grows across evaluations.
        global_drag_factor: last seen value of the "global drag factor" driver
            (1.0 until such a driver is reported).
        global_radiation_factor: last seen value of the "global radiation factor"
            driver (1.0 until such a driver is reported).
    '''

    def __init__(self):
        super().__init__()
        self.residuals_data = []
        self.global_drag_factor=1.0
        self.global_radiation_factor=1.0

    def evaluationPerformed(self, itCounts, evCounts, orbits, orbParams, propParams, measParams, provider, lspEval):
        '''
        Print the RMS, the propagation parameter values, the latest residual record
        and the covariances of this evaluation, and store the residuals.
        '''
        rms = lspEval.getRMS()
        print(f'Iteration {itCounts}: Eval {evCounts}, RMS: {rms:.6f}')

        # Driver name -> attribute that remembers its most recent value.
        remembered = {
            'global radiation factor': 'global_radiation_factor',
            'global drag factor': 'global_drag_factor',
        }
        driver_list = propParams.getDrivers()
        for position in range(int(driver_list.size())):
            driver = driver_list.get(position)
            name = driver.getName()
            value = driver.getValue()
            print(f"Estimated {name}:{value:.6f}")
            target_attribute = remembered.get(name)
            if target_attribute is not None:
                setattr(self, target_attribute, value)
            print(name,value)

        # The residual vector is walked in groups of three entries, read as
        # (azimuth, elevation, range); the first two are passed through degrees().
        # NOTE(review): assumes every measurement contributes exactly three entries
        # in that order and that the angular entries are in radians - confirm.
        residual_vector = lspEval.getResiduals()
        for start in range(0, residual_vector.getDimension(), 3):
            az_entry, el_entry, range_entry = (
                residual_vector.getEntry(start + offset) for offset in range(3)
            )
            self.residuals_data.append({
                'iteration': itCounts,
                'measurement': start // 3 + 1,
                'az_residual': degrees(az_entry),
                'el_residual': degrees(el_entry),
                'range_residual': range_entry,
            })

        print('----------------------------------')
        print(self.residuals_data[-1])
        print('=-=-=-=-=')

        print(lspEval.getCovariances(1e-10))
        

# %%
# Batch least-squares orbit determination seeded with the TLE-derived initial orbit.
numerical_prop_builder=GetNumericalPropagatorBuilder(initial_orbit)
# Estimator Setup
EST_PARAMS = { 'conv_thres': 0.001, 'max_iter': 100, 'max_eval': 100}

# NOTE(review): positional arguments presumably follow the Hipparchus
# LevenbergMarquardtOptimizer constructor order (initial step bound factor,
# cost / parameter / orthogonality tolerances, QR ranking threshold) - confirm.
optimizer= LevenbergMarquardtOptimizer(100.0, 1e-10, 1e-10, 1e-10, 1e-11)


# Alternative optimizer kept by the author (Gauss-Newton with QR decomposition):
# from org.hipparchus.linear import QRDecomposer
# from org.hipparchus.optim.nonlinear.vector.leastsquares import GaussNewtonOptimizer
# from org.orekit.estimation.leastsquares import BatchLSEstimator
# matrixDecomposer = QRDecomposer(1e-11)
# optimizer = GaussNewtonOptimizer(matrixDecomposer, False)


estimator = BatchLSEstimator(optimizer, numerical_prop_builder)
estimator.setParametersConvergenceThreshold(EST_PARAMS['conv_thres'])
estimator.setMaxIterations(EST_PARAMS['max_iter'])
estimator.setMaxEvaluations(EST_PARAMS['max_eval'])

# Feed every multiplexed (az/el + range) measurement to the estimator.
for m in true_measurements:
    estimator.addMeasurement(m)
# Create and set observer (prints progress and records residuals on every evaluation)
observer = ODObserver()
estimator.setObserver(observer)
# Run the estimation; the orbit is taken from the initial state of the first
# returned propagator.
estimated_props = estimator.estimate()
estimated_orbit = estimated_props[0].getInitialState().getOrbit()
print("===================================")
print("Estimated Orbit:", estimated_orbit)

# %%
# Re-print the batch least-squares result in its own cell.
print("Estimated Orbit:", estimated_orbit)

# %%

# Physical covariance of the batch least-squares solution; it is reused below
# when the Kalman process noise is configured.
initial_cov = estimator.getPhysicalCovariances(1e-10)
print("Initial Covariance Matrix:")
print(initial_cov)
# TODO: remove measurements whose residuals exceed 3 sigma (the original note
# mentions ra/dec, while the measurements used here are azimuth/elevation/range).




# No outlier rejection is implemented yet: this is the same list object as
# true_measurements, not a filtered copy.
filtered_measurements=true_measurements

# %%
from org.hipparchus.linear import BlockRealMatrix
from org.orekit.estimation.sequential import ConstantProcessNoise
from org.orekit.estimation.sequential import KalmanEstimatorBuilder , KalmanModel
from org.hipparchus.linear import QRDecomposer
from org.orekit.orbits import OrbitType
from org.orekit.orbits import PositionAngleType
from org.orekit.estimation.sequential import KalmanObserver
from org.orekit.estimation.measurements import AngularAzEl
from orekit.pyhelpers import datetime_to_absolutedate
#
import numpy as np

# Constant process noise
# The batch least-squares covariance is passed for both constructor arguments.
# NOTE(review): per the Orekit API these are the initial covariance and the
# process-noise matrix - confirm that reusing the same matrix is intended.
constant_process_noise = ConstantProcessNoise(initial_cov, initial_cov)

# Alternative kept by the author: a hand-written diagonal 7x7 covariance.
# initial_covariance_matrix = np.diag([
#     1.0,        # Position X
#     1.0,        # Position Y
#     1.0,        # Position Z
#     0.01,       # Velocity X
#     0.01,       # Velocity Y
#     0.01,        # Velocity Z
#     0.001
# ])

# # Convert the numpy array to a Hipparchus BlockRealMatrix
# covariance_matrix = BlockRealMatrix(7, 7)
# for i in range(initial_covariance_matrix.shape[0]):
#     for j in range(initial_covariance_matrix.shape[1]):
#         covariance_matrix.setEntry(i, j, float(initial_covariance_matrix[i, j]))

# # Constant process noise
# constant_process_noise = ConstantProcessNoise(covariance_matrix, covariance_matrix)

#print(initial_cov)

# Propagator builder for the Kalman filter, requested for the batch LS estimated orbit.
kalman_propagator_builder=GetNumericalPropagatorBuilder(estimated_orbit)
# Kalman estimator builder
kalman_builder = KalmanEstimatorBuilder()
# Alternative kept by the author: unscented Kalman filter.
# from org.orekit.estimation.sequential import UnscentedKalmanEstimatorBuilder
# kalman_builder=UnscentedKalmanEstimatorBuilder()

kalman_builder.addPropagationConfiguration(kalman_propagator_builder, constant_process_noise)
# Build Kalman filter
kalman_estimator = kalman_builder.build()

## smoother
# The smoother is constructed here but never attached as observer in the visible
# code (the corresponding setObserver call further down is commented out).
from org.orekit.estimation.sequential import RtsSmoother
rts_smoother=RtsSmoother(kalman_estimator)

from org.orekit.estimation.measurements import ObservedMeasurement
# Module-level accumulators filled by the Kalman observer on every estimation step:
# "initial" = predicted minus observed, "corrected" = corrected estimate minus observed.
# x / y hold the first two measurement components (converted to degrees), r the third.
all_initial_x_errors=[]
all_initial_y_errors=[]
all_initial_r_errors=[]
all_corrected_x_errors=[]
all_corrected_y_errors=[]
all_corrected_r_errors=[]



#My Kalman Observer class
#https://forum.orekit.org/t/regarding-kalman-estimator-usage/1855/3

from org.orekit.estimation.sequential import PythonKalmanObserver
from math import degrees

class MyObserver(PythonKalmanObserver):
    '''
    Kalman observer that prints each step's measurement values and appends the
    prediction / correction errors to the module-level all_*_errors lists.
    '''

    def __init__(self):
        super().__init__()

    def evaluationPerformed(self, estimation):
        '''
        Record predicted-minus-observed and corrected-minus-observed differences.

        Each measurement value is unpacked into exactly three components; the
        first two differences are converted with degrees(), the third is stored as is.
        '''
        print('------------------------------')
        #print(estimation.getObservedMeasurement())

        predicted_value = estimation.getPredictedMeasurement().getEstimatedValue()
        print(predicted_value)
        corrected_measurement = estimation.getCorrectedMeasurement()
        observed_value = corrected_measurement.getObservedValue()
        print(observed_value)
        corrected_value = corrected_measurement.getEstimatedValue()
        print(corrected_value)

        predicted_x, predicted_y, predicted_r = predicted_value
        observed_x, observed_y, observed_r = observed_value
        corrected_x, corrected_y, corrected_r = corrected_value

        # Errors before the Kalman correction (prediction vs observation).
        all_initial_x_errors.append(degrees(predicted_x - observed_x))
        all_initial_y_errors.append(degrees(predicted_y - observed_y))
        all_initial_r_errors.append(predicted_r - observed_r)
        # Errors after the Kalman correction (corrected estimate vs observation).
        all_corrected_x_errors.append(degrees(corrected_x - observed_x))
        all_corrected_y_errors.append(degrees(corrected_y - observed_y))
        all_corrected_r_errors.append(corrected_r - observed_r)


# Attach the logging observer, then run the Kalman filter one measurement at a time.
kalman_observer = MyObserver()
kalman_estimator.setObserver(kalman_observer)
#kalman_estimator.setObserver(rts_smoother)

# Initialize lists to store position and velocity differences
# NOTE: all_position_distances and all_velocity_differences are never filled in
# the visible code; only all_times and all_covariances are appended to below.
all_position_distances = []
all_velocity_differences = []
all_times=[]
all_covariances=[]

# make a java iterable of filtered measurements

# Alternative kept by the author: process all measurements in a single call.
#java array list
# from java.util import ArrayList
# filtered_measurements_java = ArrayList()
# for measurement in filtered_measurements:
#     filtered_measurements_java.add(measurement)
# # Process the measurements
# kalman_estimator.processMeasurements(filtered_measurements_java)






for measurement in filtered_measurements:
    # Estimation step
    kalman_propagator = kalman_estimator.estimationStep(measurement)

    # Get the time from the measurement (cast to ObservedMeasurement to read its date)
    time_measurement=absolutedate_to_datetime(
        ObservedMeasurement.cast_(measurement).getDate()
        )
    all_times.append(time_measurement)
    all_covariances.append(kalman_estimator.getPhysicalEstimatedCovarianceMatrix())

    # Extract the state and orbit from the estimation result.
    # Both names are overwritten on every iteration, so after the loop they hold
    # the values from the last processed measurement.
    kalman_state = kalman_propagator[0].getInitialState()
    filtered_orbit = kalman_state.getOrbit()


from matplotlib import pyplot as plt

# Plot the corrected-minus-observed angular errors recorded by the Kalman
# observer against measurement time, together with their +/- 3 sigma bands.
plt.figure(figsize=(10, 6))
#plt.scatter(all_times[:500],all_initial_x_errors[:500], label='PredX-ObsX',color='red',s=10)

plt.scatter(all_times,all_corrected_x_errors, label='EstX-ObsX (Kalman)',color='green',s=10)
plt.scatter(all_times,all_corrected_y_errors, label='EstY-ObsY (Kalman)',color='red',s=10)
#plt.scatter(all_times, [0]*len(all_times), color='red')
plt.xlabel('Time')
plt.ylabel('Error in degrees')
plt.title('Residuals during kalman')

# Calculate the 3 sigma lines from the sample standard deviation of the plotted errors.
sigma_x_errors=np.std(all_corrected_x_errors)
sigma_y_errors=np.std(all_corrected_y_errors)

print(f"3 sigma x errors: {3*sigma_x_errors}")
print(f"3 sigma y errors: {3*sigma_y_errors}")

plt.axhline(y=3*sigma_x_errors, color='blue', linestyle='--', label='3 Sigma X Error')
plt.axhline(y=-3*sigma_x_errors, color='blue', linestyle='--')

plt.axhline(y=3*sigma_y_errors, color='orange', linestyle='--', label='3 Sigma Y Error')
plt.axhline(y=-3*sigma_y_errors, color='orange', linestyle='--')

# Build the legend only after every labelled artist exists. Calling it before
# the axhline() calls left the 3 sigma labels out of the legend.
plt.legend()





# %%



