[Python Wrapper] - Kalman Filter slowdown

Hi,

I currently have an Orekit Kalman Filter setup and am successfully running it. Since it can take a couple of minutes to run the orbit determination, I added an update function to the KalmanObserver's evaluationPerformed method so that I can see the progress. Specifically, I am using a Python module called tqdm (https://github.com/tqdm/tqdm), which has very low overhead.

While running through ~1500 measurements spanning 1 week (Position, Range, RangeRate, & AngularAzEl types), I noticed that each step takes longer and longer to process. For example, it starts out at ~30 iterations per second, but by the end this drops to 2-3 iterations per second. When I processed 2 weeks' worth of data (~3700 measurements), it slowed to less than 1 iteration per second by the end and took ~40 minutes.

I have tried initializing the JVM with more memory, running without tqdm, deleting Orekit-class variables after their use, and using the toString method to output my state vectors. However, none of this has resulted in any noticeable speed increase. Is there anything I am missing, or are there ways to streamline the processing?

Happy to provide a code sample/measurements if it helps.

Thanks!

Hi,

Difficult to know what it can be. I find it unlikely that tqdm itself is the issue; I use it myself, it has low overhead, and it just displays the iterations.

But as you use tqdm, I understand that the iterating loop is in Python. Is each iteration doing comparable work?

I have not done a systematic performance assessment of the Python wrapper, but it is the calls and data transfer between Python and Java that carry a penalty. If things are performed within the Java engine, the performance should not be significantly different from pure Java. But if you have, for example, Python callbacks that are invoked very frequently, performance goes down.
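To give a feel for that penalty, here is a pure-Python analogy (not actual wrapper timings): the same reduction done in one call versus driven by a per-item callback, which is roughly what a Python observer does for every measurement. Note that this models a *constant* per-call overhead, which would slow every iteration uniformly rather than progressively.

```python
import timeit

data = list(range(100_000))

# Work done entirely "inside the engine": a single call, no per-item
# boundary crossing.
engine_side = timeit.timeit(lambda: sum(data), number=10)

def callback(acc, x):
    # Stands in for a Python callback invoked once per measurement.
    return acc + x

def with_callbacks():
    acc = 0
    for x in data:
        acc = callback(acc, x)  # one "crossing" per element
    return acc

callback_side = timeit.timeit(with_callbacks, number=10)
# callback_side is typically several times larger than engine_side,
# purely from per-call overhead.
```

The ratio between the two timings is the analogue of the Python/Java boundary cost; the actual JVM crossing is far more expensive than a Python function call.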

As you mentioned, not keeping old variables around can free up memory in the JVM.

Let us know how you proceed, and to help further I think some test code would be useful.

Regards
/P

Hi @aerow610 and @petrus.hyvonen ,

To my mind, the issue is not linked to the Python wrapper.

We recently found a big performance issue in the Kalman filter in Orekit. It was linked to the estimation of the central attraction coefficient: the estimation of this parameter was duplicated at each new measurement in the filter.
At the first measurement, one central attraction coefficient is estimated; at the 30th measurement, 30 central attraction coefficients are estimated; and at the nth measurement, n are estimated. That is why each step takes longer and longer.
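That growth pattern can be modeled with a tiny sketch (pure Python, no Orekit required, and purely illustrative): if one duplicate driver is appended at every measurement and each update costs time proportional to the number of estimated drivers, the total cost is quadratic in the number of measurements.

```python
def simulate(n_measurements):
    """Toy model of the duplicated-parameter bug: one extra copy of the
    driver is appended at each measurement, and each update's cost is
    proportional to the current driver count."""
    drivers = []
    work_per_step = []
    for _ in range(n_measurements):
        drivers.append("central attraction coefficient")  # duplicated each step
        work_per_step.append(len(drivers))                # cost grows linearly
    return work_per_step

# Per-step cost grows linearly (1, 2, 3, ...), so iterations/second fall
# steadily, and total work is quadratic: 1 + 2 + ... + n = n(n+1)/2.
total_work = sum(simulate(1500))
```

This matches the observed symptom exactly: throughput that starts high and degrades continuously as measurements are processed.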

This issue is now fixed (see #598) and the fix will be available with the upcoming 10.1 release of Orekit. Currently, the fix is only available in our development branch.

Kind regards,
Bryan

Thanks for the info @petrus.hyvonen and @bcazabonne - looks like it will be good to wait until 10.1 and then compare results.

For reference, I will attach the snippet of code that is processed on each step of the Kalman Filter, the KalmanObserver evaluationPerformed method:

filter_log = []
rej_count = []
class filt_step_observer(PythonKalmanObserver):
    def evaluationPerformed(self, est):
        kal_epoch = absolutedate_to_datetime(est.currentDate)
        kal_meas_corr = list(est.getCorrectedMeasurement().getEstimatedValue())
        obs_meas = list(est.getPredictedMeasurement().getObservedValue())

        # Residuals
        resids_corr = [kal-obs for kal, obs in zip(kal_meas_corr, obs_meas)]

        # Kalman process Corrected Measurement Status
        status = est.correctedMeasurement.getStatus().toString()

        # Cd and Cr
        coeff = {i.getName(): i.getValue() for i in est.getEstimatedPropagationParameters().getDrivers()}
        est_cd = coeff.get('drag coefficient', c_d)
        est_cr = coeff.get('reflection coefficient', c_r)

        # State Covariance
        est_state_cov_mat = est.getPhysicalEstimatedCovarianceMatrix()
        est_state_covar = [est_state_cov_mat.getEntry(x, x) for x in range(est_state_cov_mat.getRowDimension())]

        # Create the Output Dict
        output_dict = {'Epoch': kal_epoch, 'Status': status, 'Cd': est_cd, 'Cr': est_cr, 'State_Covariance': est_state_covar}

        # Measurement Type
        m_type = est.correctedMeasurement.observedMeasurement.getClass()
        s_type = ['Val1', 'Val2', 'Val3', 'Val4', 'Val5', 'Val6']
        if m_type in (Position.class_, PV.class_):
            output_dict['Tracker'] = 'GPS'
            output_dict['Meas_Type'] = 'NavSol'
            sigma = tle_meas_white_noise_sigma
        else:
            trkr = list(est.correctedMeasurement.observedMeasurement.getParametersDrivers())[0]
            output_dict['Tracker'] = trkr.toString().split('-offset')[0].upper()
            if m_type == Range.class_:
                output_dict['Meas_Type'] = 'Range'
                sigma = range_white_noise_sigma
            elif m_type == RangeRate.class_:
                output_dict['Meas_Type'] = 'Doppler'
                sigma = rangerate_white_noise_sigma
            elif m_type == AngularAzEl.class_:
                output_dict['Meas_Type'] = 'AzEl'
                sigma = angles_white_noise_sigma

        # Add Corrected residuals and residual ratios to output
        for name, resid in zip(s_type, resids_corr):
            output_dict[name + '_corr'] = resid
            output_dict[name + '_ratios'] = resid / sigma

        # Add to the log
        filter_log.append(output_dict)

        # Update the progressbar
        pbar.update()

        # Exit if filter is diverging
        if status == 'REJECTED':
            rej_count.append(status)
            if len(rej_count) > filter_diverge_threshold:
                print(f'\n\nFATAL ERROR: Filter has rejected {filter_diverge_threshold} measurements in a row - exiting\n', flush=True)
                raise ValueError
        else:
            rej_count.clear()

And the code that runs the Kalman Filter itself (meas = ArrayList().of_(ObservedMeasurement)):

# Run the filter
print(f'\nFilter starting at {meas.get(0).date.toString()}', flush=True)
with tqdm(total=meas.size(), desc='Running the Filter') as pbar:
    try:
        final_state = filt.processMeasurements(meas)
    except Exception:
        write_ephem = False
        filter_out_path = False
        predict_time = False
        show_rejected = True
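As a diagnostic aid for this kind of slowdown, a small helper (hypothetical, not part of Orekit or the wrapper) can record per-step wall-clock durations from inside evaluationPerformed, to distinguish a constant per-call overhead from a cost that grows with each measurement:

```python
import time

class IterationTimer:
    """Record the wall-clock duration of each filter step. A steadily
    growing per-step duration points at work that accumulates with each
    measurement (e.g. a growing set of estimated parameters)."""

    def __init__(self):
        self.durations = []
        self._last = None

    def tick(self):
        # Call once per evaluationPerformed; the first call only arms the timer.
        now = time.perf_counter()
        if self._last is not None:
            self.durations.append(now - self._last)
        self._last = now

    def trend(self, window=100):
        # Ratio of the mean duration of the last `window` steps to the
        # first `window` steps; ~1.0 means constant cost, >> 1.0 means growth.
        if len(self.durations) < 2 * window:
            return None
        head = sum(self.durations[:window]) / window
        tail = sum(self.durations[-window:]) / window
        return tail / head
```

Calling `timer.tick()` at the top of evaluationPerformed and inspecting `timer.trend()` afterwards would show the ratio climbing well above 1.0 in the buggy case described here.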

@bcazabonne - just upgraded to the 10.1 Python wrapper (thanks to @petrus.hyvonen for getting that out so fast) and the Kalman filter is significantly faster, with no more slowdown (just ran through 4000+ measurements).

Thanks!


Great to hear!