import dataclasses
import datetime as dt
import numpy as np

import orekit_jpype

# Start the JVM before the `java.*` / `org.*` imports below; presumably those
# JPype-backed imports cannot be resolved until the VM is running — confirm.
orekit_jpype.initVM()

from orekit_jpype import pyhelpers

from java import util as java_util

from org.hipparchus import linear as hipp_linear
from org.hipparchus.analysis import polynomials
from org.hipparchus.geometry.euclidean import threed

from org.orekit import frames, orbits, propagation, time, utils


# Point Orekit at its data directory before any frame / time-scale lookups
# (EME2000 and UTC are fetched just below).
# NOTE(review): the path is relative, so it presumably resolves against the
# current working directory — run the script from the repo root or confirm.
pyhelpers.setup_orekit_curdir("data/orekit-data")


# Constants
# Reference frame and time scale shared by every function in this script.
# Note: `time` here is `org.orekit.time` (imported above), not the stdlib module.
EME2000 = frames.FramesFactory.getEME2000()
UTC = time.TimeScalesFactory.getUTC()


# Data
@dataclasses.dataclass(frozen=True)
class Ephemeris:
    """One ephemeris sample: epoch, Cartesian state vector and its covariance.

    Attributes:
        epoch: Epoch of the sample (the bundled data uses timezone-aware
            UTC datetimes).
        state_vector: 6-element vector ``[x, y, z, vx, vy, vz]``; the first
            three components are position, the last three velocity.
            Units are presumably metres and metres/second — TODO confirm
            with the data source.
        covariance_matrix: 6x6 covariance associated with ``state_vector``.

    Raises:
        ValueError: If ``state_vector`` does not have shape ``(6,)`` or
            ``covariance_matrix`` does not have shape ``(6, 6)``.

    Note:
        The dataclass-generated ``__hash__`` hashes the ndarray fields, which
        are unhashable, so ``hash(instance)`` raises ``TypeError``. Likewise,
        ``==`` between instances holding distinct array objects may raise
        ``ValueError`` (ambiguous array truth value).
    """

    epoch: dt.datetime
    state_vector: np.ndarray
    covariance_matrix: np.ndarray

    def __post_init__(self):
        # Fail at construction with a clear message rather than later, deep
        # inside orbit/covariance construction, with an IndexError or a
        # dimension mismatch.
        sv_shape = np.shape(self.state_vector)
        if sv_shape != (6,):
            raise ValueError(
                f"state_vector must have shape (6,), got {sv_shape}"
            )
        cov_shape = np.shape(self.covariance_matrix)
        if cov_shape != (6, 6):
            raise ValueError(
                f"covariance_matrix must have shape (6, 6), got {cov_shape}"
            )


# Three consecutive samples, 90 minutes apart. Every covariance is diagonal,
# so it is written as np.diag(...) of its variances; this yields the same 6x6
# float64 arrays as spelling out the zero off-diagonal terms.
EPHEMERIS = [
    Ephemeris(
        epoch=dt.datetime.fromisoformat("2026-03-04T13:08:27.172000+00:00"),
        state_vector=np.array(
            [
                4.81955102e06, -5.32656111e06, 4.13613788e04,  # position
                -8.04180942e02, -7.96482651e02, -7.36437750e03,  # velocity
            ]
        ),
        covariance_matrix=np.diag(
            [
                2.19003854e-01, 2.29333216e01, 2.13371605e-01,
                2.14954180e-05, 2.35942558e-07, 3.14765697e-07,
            ]
        ),
    ),
    Ephemeris(
        epoch=dt.datetime.fromisoformat("2026-03-04T14:38:27.172000+00:00"),
        state_vector=np.array(
            [
                4.22891921e06, -3.63775409e06, 4.51588650e06,  # position
                2.54330829e03, -4.10799610e03, -5.67568978e03,  # velocity
            ]
        ),
        covariance_matrix=np.diag(
            [
                1.35770307e00, 1.25959336e02, 1.99415313e-01,
                1.24599024e-04, 1.51905612e-06, 2.06758044e-07,
            ]
        ),
    ),
    Ephemeris(
        epoch=dt.datetime.fromisoformat("2026-03-04T16:08:27.172000+00:00"),
        state_vector=np.array(
            [
                1.72889099e06, -3.15359930e05, 6.95416591e06,  # position
                4.74950732e03, -5.56266612e03, -1.43035538e03,  # velocity
            ]
        ),
        covariance_matrix=np.diag(
            [
                2.15396014e00, 1.56165962e02, 2.00698604e-01,
                1.56968228e-04, 2.30953827e-06, 1.93678593e-07,
            ]
        ),
    ),
]


def array_to_hipparchus_realmatrix(array):
    """Copy a 2-D NumPy array into a newly allocated Hipparchus RealMatrix.

    Args:
        array: 2-D ``np.ndarray``; its ``shape`` sets the matrix dimensions.

    Returns:
        A RealMatrix from ``MatrixUtils.createRealMatrix`` holding the same
        values, filled one row at a time.
    """
    n_rows, n_cols = array.shape
    matrix = hipp_linear.MatrixUtils.createRealMatrix(n_rows, n_cols)
    # Iterating a 2-D ndarray yields its rows; each is handed to Java as a list.
    for row_index, row_values in enumerate(array):
        matrix.setRow(row_index, row_values.tolist())
    return matrix


def build_covariance_blender():
    """Create the StateCovarianceBlender used to interpolate the samples.

    The blender uses a quadratic smooth-step blending function and a
    2-point Hermite orbit interpolator. It outputs covariances in EME2000,
    Cartesian orbit type, MEAN position angle.

    Returns:
        propagation.StateCovarianceBlender
    """
    quadratic_step = polynomials.SmoothStepFactory.getQuadratic()
    orbit_interpolator = orbits.OrbitHermiteInterpolator(2, EME2000)
    return propagation.StateCovarianceBlender(
        quadratic_step,
        orbit_interpolator,
        EME2000,
        orbits.OrbitType.CARTESIAN,
        orbits.PositionAngleType.MEAN,
    )


def create_orbit(state):
    """Build an Orekit CartesianOrbit in EME2000 from one ephemeris sample.

    Args:
        state: Sample with an ``epoch`` (datetime) and a ``state_vector``
            whose components 0-2 are position and 3-5 are velocity.

    Returns:
        orbits.CartesianOrbit dated at ``state.epoch``, expressed in EME2000,
        using ``utils.Constants.EIGEN5C_EARTH_MU`` as gravitational parameter.
    """
    sv = state.state_vector

    # Explicit float() so the Java constructor receives plain Python floats.
    pos = threed.Vector3D(float(sv[0]), float(sv[1]), float(sv[2]))
    vel = threed.Vector3D(float(sv[3]), float(sv[4]), float(sv[5]))

    # Convert the epoch with the same helper the rest of the script uses for
    # the covariance epochs and the interpolation epoch. Building the date via
    # AbsoluteDate(year, ..., second=27.172 (float), UTC) added a small offset,
    # so the orbit date no longer matched the covariance/interpolation dates.
    abs_date = pyhelpers.datetime_to_absolutedate(state.epoch)
    abs_pv = utils.AbsolutePVCoordinates(EME2000, abs_date, pos, vel)

    return orbits.CartesianOrbit(
        abs_pv, EME2000, utils.Constants.EIGEN5C_EARTH_MU
    )


def prepare_time_stamped_pair_list():
    """Build the interpolation samples, one (orbit, covariance) pair per sample.

    Returns:
        java.util.ArrayList of ``time.TimeStampedPair(orbit, covariance)`` in
        EPHEMERIS order. Each covariance is declared in EME2000, Cartesian
        orbit type, MEAN position angle.
    """
    pair_list = java_util.ArrayList()

    for state in EPHEMERIS:
        orbit = create_orbit(state)

        # Date the covariance with the orbit's own AbsoluteDate instead of
        # converting state.epoch a second time. Both halves of the pair then
        # share exactly the same date, whatever conversion create_orbit uses.
        # NOTE(review): TimeStampedPair is understood to reject pairs whose
        # dates differ beyond a small threshold — confirm for the Orekit
        # version in use.
        orekit_cov = propagation.StateCovariance(
            array_to_hipparchus_realmatrix(state.covariance_matrix),
            orbit.getDate(),
            EME2000,
            orbits.OrbitType.CARTESIAN,
            orbits.PositionAngleType.MEAN,
        )

        pair_list.add(time.TimeStampedPair(orbit, orekit_cov))

    return pair_list


if __name__ == "__main__":
    blender = build_covariance_blender()
    samples = prepare_time_stamped_pair_list()

    # Interpolate at the epoch of the first (earliest) sample, converted with
    # pyhelpers.datetime_to_absolutedate.
    target_date = pyhelpers.datetime_to_absolutedate(EPHEMERIS[0].epoch)

    # Presumably a TimeStampedPair of (interpolated orbit, blended covariance);
    # this script does not use it further.
    interpolated_pair = blender.interpolate(target_date, samples)
