diff --git a/src/naima/plot.py b/src/naima/plot.py index 72c610d..b0f693f 100644 --- a/src/naima/plot.py +++ b/src/naima/plot.py @@ -623,12 +623,11 @@ def plot_samples( Unit in which to plot energy axis. e_range : list of `~astropy.units.Quantity`, length 2, optional Limits in energy for the computation of the model samples and ML model. - Note that setting this parameter will mean that the samples for the - model are recomputed and depending on the model speed might be quite - slow. + Note that model re-evaluation may be slow depending on the model + complexity. e_npoints : int, optional - How many points to compute for the model samples and ML model if - `e_range` is set. + How many points to compute for the model samples and ML model over + the energy range. Default is 100. threads : int, optional How many parallel processing threads to use when computing the samples. Defaults to the number of available cores. @@ -805,12 +804,14 @@ def plot_fit( of the energy array of the observed data. e_range : list of `~astropy.units.Quantity`, length 2, optional Limits in energy for the computation of the model samples and ML model. - Note that setting this parameter will mean that the samples for the - model are recomputed and depending on the model speed might be quite - slow. + When ``modelfn`` is available on the sampler and ``e_range`` is not + provided, it defaults to a range extending a factor of 3 beyond the + data energy limits so that the model is always re-evaluated over a + dense grid. Note that model re-evaluation may be slow depending on + the model complexity. e_npoints : int, optional - How many points to compute for the model samples and ML model if - `e_range` is set. + How many points to compute for the model samples and ML model over + the energy range. Default is 100. threads : int, optional How many parallel processing threads to use when computing the samples. Defaults to the number of available cores. @@ -828,7 +829,14 @@ def plot_fit( """ import matplotlib.pyplot as plt - ML, MLp, MLerr, model_ML = find_ML(sampler, modelidx) + data = sampler.data + + if e_range is None and hasattr(sampler, "modelfn") and sampler.modelfn is not None: + e_range = data["energy"][[0, -1]] * np.array((1.0 / 3.0, 3.0)) + + ML, MLp, MLerr, model_ML = _calc_ML( + sampler, modelidx, e_range=e_range, e_npoints=e_npoints + ) infostr = "Maximum log probability: {0:.3g}\n".format(ML) infostr += "Maximum Likelihood values:\n" maxlen = np.max([len(ilabel) for ilabel in sampler.labels]) @@ -838,11 +846,6 @@ def plot_fit( # log.info(infostr) - data = sampler.data - - if e_range is None and not hasattr(sampler, "blobs"): - e_range = data["energy"][[0, -1]] * np.array((1.0 / 3.0, 3.0)) - if plotdata is None and len(model_ML[0]) == len(data["energy"]): model_unit, _ = sed_conversion(model_ML[0], model_ML[1].unit, sed) data_unit, _ = sed_conversion(data["energy"], data["flux"].unit, sed) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index c7ef8a9..f01e0e9 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -156,3 +156,28 @@ def test_diagnostic_plots(sampler): def test_diagnostic_plots_noblobs(noblob_sampler): # Diagnostic plots save_diagnostic_plots("test_function_noblob", noblob_sampler) + + +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") +def test_calc_ML_uses_dense_energy_grid(sampler): + """Regression test for #240: _calc_ML should evaluate over a dense grid.""" + from naima.plot import _calc_ML + + data = sampler.data + e_range = data["energy"][[0, -1]] * np.array((1.0 / 3.0, 3.0)) + + ML, MLp, MLerr, (modelx, model_ML) = _calc_ML(sampler, 0, e_range=e_range) + + # Model should be evaluated over a dense logspace grid, not the sparse + # data energy points + assert len(modelx) == 100 + assert len(modelx) != len(data["energy"]) + + # Re-evaluate directly over the same grid and check consistency + eval_data = { + "energy": modelx, + "flux": np.zeros(modelx.shape) * data["flux"].unit, + } + direct_out = sampler.modelfn(MLp, eval_data) + direct_flux = direct_out[0] if isinstance(direct_out, (tuple, list)) else direct_out + np.testing.assert_allclose(model_ML.value, direct_flux.value, rtol=1e-10)