import codecs
import numpy as np
import pandas as pd

from astropy import coordinates
from astropy.time import Time

from datetime import datetime, timedelta

import orekit
from orekit import JArray

from org.hipparchus.optim.nonlinear.vector.leastsquares import LevenbergMarquardtOptimizer
from org.hipparchus.geometry.euclidean.threed import Vector3D

from org.orekit.utils import IERSConventions, Constants as orekit_constants, PVCoordinates

from org.orekit.orbits import CartesianOrbit, PositionAngle, OrbitType, KeplerianOrbit, CircularOrbit

from org.orekit.frames import FramesFactory, TopocentricFrame

from org.orekit.models.earth import ReferenceEllipsoid
from org.orekit.models.earth.atmosphere.data import MarshallSolarActivityFutureEstimation
from org.orekit.models.earth.atmosphere import NRLMSISE00

from org.orekit.time import AbsoluteDate, TimeScalesFactory

from orekit.pyhelpers import setup_orekit_curdir, datetime_to_absolutedate, download_orekit_data_curdir
from org.orekit.attitudes import NadirPointing

from org.orekit.forces.gravity.potential import GravityFieldFactory
from org.orekit.forces.gravity import HolmesFeatherstoneAttractionModel, ThirdBodyAttraction, Relativity
from org.orekit.forces.radiation import IsotropicRadiationSingleCoefficient, SolarRadiationPressure
from org.orekit.forces.drag import IsotropicDrag, DragForce

from org.orekit.bodies import CelestialBodyFactory, GeodeticPoint

# Start the Java VM behind the Orekit wrapper; the returned handle is kept in
# `vm` and is used further down to print the Java version.
vm = orekit.initVM()

from org.orekit.estimation.iod import IodGibbs
from org.orekit.estimation.leastsquares import BatchLSEstimator
from org.orekit.estimation.measurements import (GroundStation, AngularRaDec, ObservableSatellite)

from org.orekit.propagation import SpacecraftState
from org.orekit.propagation.conversion import DormandPrince853IntegratorBuilder, NumericalPropagatorBuilder

from org.orekit.utils import ParameterDriver, ParameterObserver, ParameterDriversList
from org.orekit.forces.radiation import RadiationSensitive

# NOTE(review): initVM() has already been called earlier in this file; this
# second call is kept unchanged — confirm it is redundant before removing it.
vm = orekit.initVM()

# Point Orekit at the data located in the current directory.
setup_orekit_curdir('./')

# Fetching the UTC time scale is used as the check that the Orekit data can be
# loaded; if it fails, the data is downloaded and the call is retried once.
try:
    utc = TimeScalesFactory.getUTC()
except Exception:
    # `Exception` instead of a bare `except` so that Ctrl-C / SystemExit are
    # not swallowed and turned into an unwanted download.
    # NOTE(review): the download helper receives the same './' argument as
    # setup_orekit_curdir and the data provider is not registered again
    # afterwards — confirm that this fallback path actually works.
    download_orekit_data_curdir('./')
    utc = TimeScalesFactory.getUTC()

print('Java version:', vm.java_version)
print('Orekit version:', orekit.VERSION)

#%%
fn = 'new_20230823_10122_20744.RES'


def ReadObsData(datafile = fn):
    """Read an observation file and return its rows in ISO-like form.

    The first and the last line of the file are skipped; every line in
    between is split on whitespace and must provide at least four columns:

    * date as ``DDMMYY`` (the year is taken as ``2000 + YY``),
    * time as ``HHMMSScc`` (``cc`` = hundredths of a second),
    * right ascension as ``HHMMSScc``,
    * declination as ``sDDMMSScc`` (sign character first).

    Any further columns (the fifth one holds the residual in the files this
    was written for) are ignored.

    Parameters
    ----------
    datafile : str
        Path of the file to read; it is decoded as cp1252.

    Returns
    -------
    tuple of four lists, one entry per observation row:
        ``RA``   – strings ``'HH:MM:SS.cc'``
        ``DEC``  – strings ``'sDD:MM:SS.cc'``
        ``Date`` – strings ``'YYYY-MM-DDTHH:MM:SS.cc'``
        ``measurementDate`` – naive ``datetime`` objects for the same instants
    """
    # Context manager so the file handle is closed even if reading fails.
    with codecs.open(datafile, encoding='cp1252') as f:
        datalines = f.readlines()

    # ndmin=2 keeps a single observation row as a (1, n_columns) array, so the
    # row-wise iteration below also works for one-row files.
    data = np.loadtxt(datalines[1:-1], dtype='str', ndmin=2)

    RA = []
    DEC = []
    Date = []
    measurementDate = []
    for date, t, ra, dec in data[:, :4]:
        year = 2000 + int(date[4:6])
        month = int(date[2:4])
        day = int(date[0:2])
        hours = int(t[0:2])
        minutes = int(t[2:4])
        seconds = int(t[4:6])
        centiseconds = int(t[6:8])

        Date.append(f'{year}-{month:02d}-{day:02d}'
                    f'T{hours:02d}:{minutes:02d}:{seconds:02d}.{centiseconds:02d}')
        RA.append(f'{ra[0:2]}:{ra[2:4]}:{ra[4:6]}.{ra[6:8]}')
        DEC.append(f'{dec[0:3]}:{dec[3:5]}:{dec[5:7]}.{dec[7:9]}')

        # Built as midnight + timedelta (rather than one datetime() call) so
        # out-of-range fields roll over exactly as they did before.
        measurementDate.append(
            datetime(year, month, day)
            + timedelta(hours=hours, minutes=minutes, seconds=seconds,
                        microseconds=centiseconds * 10000)
        )

    return RA, DEC, Date, measurementDate

#%%
def _sexagesimal_to_float(text):
    """Convert ``'[+-]AA:BB:CC.cc'`` into a signed float in units of ``AA``.

    The sign is taken from the leading character rather than from the numeric
    value of the first field, so ``'-00:30:00.00'`` stays negative.
    """
    fields = text.strip().split(':')
    magnitude = abs(float(fields[0])) + float(fields[1])/60 + float(fields[2])/3600
    return -magnitude if fields[0].startswith('-') else magnitude


def OrbitEstimation(RA, DEC, TimeObs):
    """Initial orbit determination from three RA/DEC observations.

    Line-of-sight directions and observer positions are combined in an
    angles-only scheme (8th-degree polynomial in the geocentric distance at
    the middle epoch, following the classical Gauss formulation) to obtain
    three position vectors, which are then passed to Orekit's ``IodGibbs``.

    The observer location is read from the module-level ``TSO_lat``,
    ``TSO_lon`` and ``TSO_alt``, which must be defined before the call.

    Parameters
    ----------
    RA : sequence of str
        Three right ascensions as ``'HH:MM:SS.ss'``.
    DEC : sequence of str
        Three declinations as ``'sDD:MM:SS.ss'``.
    TimeObs : sequence of str
        Three UTC epochs in ISO format (``'YYYY-MM-DDTHH:MM:SS.ss'``).

    Returns
    -------
    The orbit returned by ``IodGibbs.estimate`` in the GCRF frame.

    Raises
    ------
    ValueError
        If the range polynomial does not have exactly one positive real root,
        so no unambiguous middle-epoch distance exists.
    """
    epochs = [Time(TimeObs[k], format='isot', scale='utc') for k in range(3)]
    t1, t2, t3 = epochs
    obs_dates = [datetime_to_absolutedate(epoch.datetime) for epoch in epochs]

    # Degrees. Each declination carries its own sign: np.sign() of the first
    # degrees field returned 0 for "+00"/"-00" (zeroing all three values) and
    # forced the first observation's sign onto the other two.
    RA = np.array([15*_sexagesimal_to_float(RA[k]) for k in range(3)])
    DEC = np.array([_sexagesimal_to_float(DEC[k]) for k in range(3)])

    ### for TSO
    LAT = TSO_lat
    # NOTE(review): the original comment labelled this "West"; the value is
    # passed to GeodeticPoint unchanged — confirm the sign convention.
    LON = TSO_lon
    ALT = TSO_alt  # meters

    Mu = orekit_constants.WGS84_EARTH_MU

    # Time offsets of the outer observations relative to the middle one [s].
    tau1 = (t1.jd - t2.jd)*24*3600
    tau3 = (t3.jd - t2.jd)*24*3600

    # Line-of-sight unit vectors: rows are x, y, z; columns are observations.
    ra_rad = np.radians(RA)
    dec_rad = np.radians(DEC)
    Li_unit = np.array([np.cos(dec_rad)*np.cos(ra_rad),
                        np.cos(dec_rad)*np.sin(ra_rad),
                        np.sin(dec_rad)])

    # ECI - Earth-Centered Inertial
    ECI = FramesFactory.getGCRF()
    ECEF = FramesFactory.getITRF(IERSConventions.IERS_2010, True)

    geodeticPoint = GeodeticPoint(float(np.radians(LAT)), float(np.radians(LON)), float(ALT))
    wgs84Ellipsoid = ReferenceEllipsoid.getWgs84(ECEF)
    topocentricFrame = TopocentricFrame(wgs84Ellipsoid, geodeticPoint, 'OBS')

    # Observer position in ECI at each epoch, same layout as Li_unit.
    site_positions = [topocentricFrame.getPVCoordinates(obs_date, ECI).getPosition()
                      for obs_date in obs_dates]
    r_site = np.array([[site.getX() for site in site_positions],
                       [site.getY() for site in site_positions],
                       [site.getZ() for site in site_positions]])

    # Series coefficients relating the outer positions to the middle one.
    a1 = tau3 / (tau3 - tau1)
    a1u = tau3*((tau3-tau1)**2 - tau3**2) / (6*(tau3-tau1))
    a3 = -tau1 / (tau3 - tau1)
    a3u = -tau1*((tau3-tau1)**2 - tau1**2) / (6*(tau3-tau1))

    M = np.matmul(np.linalg.inv(Li_unit), r_site)

    d1 = M[1][0]*a1 - M[1][1] + M[1][2]*a3
    d2 = M[1][0]*a1u + M[1][2]*a3u
    los_dot_site = np.dot(Li_unit[:, 1], r_site[:, 1])

    # Roots of the 8th-degree polynomial  x**8 + A*x**6 + B*x**3 + C = 0,
    # where x is the geocentric distance at the middle epoch.
    P = np.zeros((9,))
    P[0] = 1
    P[2] = -(d1**2 + 2*los_dot_site*d1 + np.linalg.norm(r_site[:, 1])**2)
    P[5] = -2*Mu*(los_dot_site*d2 + d1*d2)
    P[8] = -Mu**2 * d2**2

    positive_roots = [root.real for root in np.roots(P)
                      if np.isreal(root) and root.real > 0]
    if len(positive_roots) != 1:
        # Previously this case fell through and failed later with a NameError.
        raise ValueError('OrbitEstimation: expected exactly one positive real root '
                         f'for the middle-epoch distance, found {len(positive_roots)}')
    r2 = positive_roots[0]

    u = Mu/r2**3

    c1 = a1 + a1u*u
    c2 = -1
    c3 = a3 + a3u*u

    # Slant ranges: diag(c1, c2, c3) * rho = M * (-c)
    Mc = np.matmul(M, np.array([[-c1], [-c2], [-c3]]))
    rho = np.linalg.solve(np.diag([c1, c2, c3]), Mc)

    # Satellite position = slant range * line of sight + observer position.
    position = [rho[k, 0]*Li_unit[:, k] + r_site[:, k] for k in range(3)]

    posR1 = Vector3D(float(position[0][0]), float(position[0][1]), float(position[0][2])) # Position of first observation.
    posR2 = Vector3D(float(position[1][0]), float(position[1][1]), float(position[1][2])) # Position of second observation.
    posR3 = Vector3D(float(position[2][0]), float(position[2][1]), float(position[2][2])) # Position of third observation.

    # Initialization of Gibbs IOD Method
    gibbs = IodGibbs(Mu)

    # Gibbs IOD orbit estimation
    estimated_orbit = gibbs.estimate(ECI,
                                     posR1, obs_dates[0],
                                     posR2, obs_dates[1],
                                     posR3, obs_dates[2])
    print('\033[1m'+'\nGIBBS Kepler elements:\n '+'\033[0m', estimated_orbit)

    return estimated_orbit

#%%   
def InitiatePropagator(sat_orbit, sat_cs, sat_cr, sat_cd, sat_mass):
    """Build the numerical propagator builder used by the orbit estimator.

    The builder is created from ``sat_orbit`` with a Dormand-Prince 8(5,3)
    integrator builder, gets the spacecraft mass and a nadir-pointing
    attitude provider, and has one force model added here: solar radiation
    pressure with an estimated reflection coefficient.

    Parameters
    ----------
    sat_orbit :
        Template orbit handed to ``NumericalPropagatorBuilder``.
    sat_cs : float
        First argument of ``IsotropicRadiationSingleCoefficient`` (the
        spacecraft cross-section).
    sat_cr : float
        Initial and reference value of the reflection coefficient. It used to
        be ignored in favour of a hard-coded 1.5; passing 1.5 reproduces the
        previous behaviour exactly.
    sat_cd : float
        Accepted for interface compatibility; not used, because no drag force
        is added here.
    sat_mass : float
        Spacecraft mass set on the builder.

    Returns
    -------
    The configured ``NumericalPropagatorBuilder``.
    """
    # Integrator builder arguments and the position scale given to the builder.
    minStep = 0.0001
    maxStep = 100.0
    pos_error = 10.0
    estimator_position_scale = 200.0

    ECI = FramesFactory.getGCRF()
    ECEF = FramesFactory.getITRF(IERSConventions.IERS_2010, True)

    integratorBuilder = DormandPrince853IntegratorBuilder(minStep, maxStep, pos_error)
    propagatorBuilder = NumericalPropagatorBuilder(sat_orbit, integratorBuilder, PositionAngle.TRUE,
                                                   estimator_position_scale)

    wgs84Ellipsoid = ReferenceEllipsoid.getWgs84(ECEF)
    nadirPointing = NadirPointing(ECI, wgs84Ellipsoid)

    propagatorBuilder.setMass(sat_mass)
    propagatorBuilder.setAttitudeProvider(nadirPointing)

    ##### Solar radiation pressure
    sun = CelestialBodyFactory.getSun()

    # Honour the caller-supplied reflection coefficient instead of a constant.
    initial_value = float(sat_cr)
    scale_factor = 0.1
    lower_bound = 1.0
    upper_bound = 2.0

    isotropicRadiationSingleCoeff = IsotropicRadiationSingleCoefficient(float(sat_cs), initial_value)
    solarRadiationPressure = SolarRadiationPressure(sun, wgs84Ellipsoid.getEquatorialRadius(),
                                                    isotropicRadiationSingleCoeff)

    # The first parameter driver of the force model is bounded, scaled and
    # marked as selected so that it takes part in the estimation.
    reflection_param = solarRadiationPressure.getParametersDrivers().get(0)
    reflection_param.setMinValue(lower_bound)
    reflection_param.setMaxValue(upper_bound)
    reflection_param.setReferenceValue(initial_value)
    reflection_param.setScale(scale_factor)
    reflection_param.setSelected(True)

    propagatorBuilder.addForceModel(solarRadiationPressure)

    return propagatorBuilder

#%%

# Observation strings and timestamps read from the input file.
RA_Obs, DEC_Obs, Date_Obs, DateTime_Obs = ReadObsData(datafile=fn)

### Site values that have been used all along
TSO_lat = 43.225278
TSO_lon = 77.870555
TSO_alt = 2658.3
TSO_CODE = '217'  # 'N42'

TSO_Data = pd.DataFrame(columns=['CODE', 'Latitude', 'Longitude', 'Altitude', 'OrekitGroundStation'])

wgs84Ellipsoid = ReferenceEllipsoid.getWgs84(FramesFactory.getITRF(IERSConventions.IERS_2010, True))
geodeticPoint = GeodeticPoint(float(np.deg2rad(TSO_lat)), float(np.deg2rad(TSO_lon)), TSO_alt)

# Topocentric frame attached to our observing site
topocentricFrame = TopocentricFrame(wgs84Ellipsoid, geodeticPoint, TSO_CODE)
groundStation = GroundStation(topocentricFrame)

# Reference date for the station's Earth-orientation offset drivers
# (these calls are part of the version that worked on June 22, 2022).
for _offset_driver in (groundStation.getPrimeMeridianOffsetDriver(),
                       groundStation.getPolarOffsetXDriver(),
                       groundStation.getPolarOffsetYDriver()):
    _offset_driver.setReferenceDate(AbsoluteDate.J2000_EPOCH)

TSO_Data.loc[0] = [TSO_CODE, TSO_lat, TSO_lon, TSO_alt, groundStation]

#%%

# Weights of the angular measurements relative to range measurements. They
# will matter once optical observations and RCSC data are used together.
ra_weight = 1.0
dec_weight = 1.0

# Rough estimate of the RA and DEC measurement accuracy; 1 arcsecond is
# assumed by default. A better value can be taken from the RES files
# (the result of the astrometric reduction by Apex).
ra_sigma = 1.0*orekit_constants.ARC_SECONDS_TO_RADIANS
dec_sigma = 1.0*orekit_constants.ARC_SECONDS_TO_RADIANS

# Parameters of the orbit estimation algorithm
estimator_position_scale = 100.0  # meters
estimator_convergence_thres = 1e-5
estimator_max_iterations = 100
estimator_max_evaluations = 100

## We work in the GCRF (ECI) frame
## Gibbs IOD orbit estimation from Double_r_Assy.py , for example

# First, middle and last observation feed the initial orbit determination.
_iod_picks = (0, len(RA_Obs)//2, -1)
estimated_orbit = OrbitEstimation([RA_Obs[k] for k in _iod_picks],
                                  [DEC_Obs[k] for k in _iod_picks],
                                  [Date_Obs[k] for k in _iod_picks])

#%%
# Inertial frame and the position/velocity of the initial orbit estimate.
ECI = FramesFactory.getGCRF()
PV_ECI = estimated_orbit.getPVCoordinates()
# Uncomment to hand the batch least-squares a Cartesian (PV) orbit instead:
#estimated_orbit = CartesianOrbit(PV_ECI, ECI, wgs84Ellipsoid.getGM())

#%%
# Spacecraft properties passed to the propagator builder.
sat_cd = 1.0
sat_cs = 10.0
sat_cr = 1.5
sat_mass = 1700.0
propagatorBuilder = InitiatePropagator(sat_orbit=estimated_orbit,
                                       sat_cs=sat_cs,
                                       sat_cr=sat_cr,
                                       sat_cd=sat_cd,
                                       sat_mass=sat_mass)

# Batch least-squares estimator driven by a Levenberg-Marquardt optimizer.
optimizer = LevenbergMarquardtOptimizer()
estimator = BatchLSEstimator(optimizer, propagatorBuilder)
estimator.setParametersConvergenceThreshold(estimator_convergence_thres)
estimator.setMaxIterations(estimator_max_iterations)
estimator.setMaxEvaluations(estimator_max_evaluations)

#%%

# Objects that are identical for every measurement are created once.
observableSatellite = ObservableSatellite(0)
_station = TSO_Data.loc[0, 'OrekitGroundStation']
# NOTE(review): the original author asked whether the angles should be
# converted to TEME at the observation epoch here — still an open question.
_radec_frame = FramesFactory.getEME2000()

# Turn every observation into an Orekit RA/DEC measurement. The sigmas and
# weights are the module-level settings (ra_sigma, dec_sigma, ra_weight,
# dec_weight) defined earlier in the file, which were previously shadowed by
# constants of the same value inside the loop.
for RA, DEC, measurementDate in zip(RA_Obs, DEC_Obs, DateTime_Obs):

    SatRaDec = coordinates.SkyCoord(RA, DEC, unit=('hour','deg'), frame='icrs')

    orekitRaDec = AngularRaDec(_station,
                               _radec_frame,
                               datetime_to_absolutedate(measurementDate),
                               JArray('double')([SatRaDec.ra.rad, SatRaDec.dec.rad]),
                               JArray('double')([ra_sigma, dec_sigma]),
                               JArray('double')([ra_weight, dec_weight]),
                               observableSatellite)

    estimator.addMeasurement(orekitRaDec)

# Lower bound on the first orbital parameter driver, intended to be the
# semi-major axis (presumably true when the template orbit is Keplerian —
# verify if the Cartesian variant is enabled).
parameter_driver = estimator.getOrbitalParametersDrivers(True).getDrivers().get(0)
min_semimajor_axis = float(orekit_constants.WGS84_EARTH_EQUATORIAL_RADIUS)
parameter_driver.setMinValue(min_semimajor_axis)

# Run the estimation and take the initial state of the first propagator.
estimatedPropagatorArray = estimator.estimate()

estimatedPropagator = estimatedPropagatorArray[0]
estimatedInitialState = estimatedPropagator.getInitialState()
Orbit = estimatedInitialState.getOrbit()

#%%
# Express the fitted state as Keplerian elements. The gravitational parameter
# is taken from the fitted orbit itself, so the conversion uses the same value
# as the estimation instead of a different constant (EGM96 was used before).
KeplerOrbit = KeplerianOrbit(estimatedInitialState.getPVCoordinates(), ECI,
                             estimatedInitialState.getDate(), Orbit.getMu())
print('\033[1m'+'\n Fitting Kepler elements:\n '+'\033[0m', KeplerOrbit)
    
