Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions src/naima/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,12 +623,11 @@ def plot_samples(
Unit in which to plot energy axis.
e_range : list of `~astropy.units.Quantity`, length 2, optional
Limits in energy for the computation of the model samples and ML model.
Note that setting this parameter will mean that the samples for the
model are recomputed and depending on the model speed might be quite
slow.
Note that model re-evaluation may be slow depending on the model
complexity.
e_npoints : int, optional
How many points to compute for the model samples and ML model if
`e_range` is set.
How many points to compute for the model samples and ML model over
the energy range. Default is 100.
threads : int, optional
How many parallel processing threads to use when computing the samples.
Defaults to the number of available cores.
Expand Down Expand Up @@ -805,12 +804,14 @@ def plot_fit(
of the energy array of the observed data.
e_range : list of `~astropy.units.Quantity`, length 2, optional
Limits in energy for the computation of the model samples and ML model.
Note that setting this parameter will mean that the samples for the
model are recomputed and depending on the model speed might be quite
slow.
When ``modelfn`` is available on the sampler and ``e_range`` is not
provided, it defaults to a range extending a factor of 3 beyond the
data energy limits so that the model is always re-evaluated over a
dense grid. Note that model re-evaluation may be slow depending on
the model complexity.
e_npoints : int, optional
How many points to compute for the model samples and ML model if
`e_range` is set.
How many points to compute for the model samples and ML model over
the energy range. Default is 100.
threads : int, optional
How many parallel processing threads to use when computing the samples.
Defaults to the number of available cores.
Expand All @@ -828,7 +829,14 @@ def plot_fit(
"""
import matplotlib.pyplot as plt

ML, MLp, MLerr, model_ML = find_ML(sampler, modelidx)
data = sampler.data

if e_range is None and hasattr(sampler, "modelfn") and sampler.modelfn is not None:
e_range = data["energy"][[0, -1]] * np.array((1.0 / 3.0, 3.0))

ML, MLp, MLerr, model_ML = _calc_ML(
sampler, modelidx, e_range=e_range, e_npoints=e_npoints
)
infostr = "Maximum log probability: {0:.3g}\n".format(ML)
infostr += "Maximum Likelihood values:\n"
maxlen = np.max([len(ilabel) for ilabel in sampler.labels])
Expand All @@ -838,11 +846,6 @@ def plot_fit(

# log.info(infostr)

data = sampler.data

if e_range is None and not hasattr(sampler, "blobs"):
e_range = data["energy"][[0, -1]] * np.array((1.0 / 3.0, 3.0))

if plotdata is None and len(model_ML[0]) == len(data["energy"]):
model_unit, _ = sed_conversion(model_ML[0], model_ML[1].unit, sed)
data_unit, _ = sed_conversion(data["energy"], data["flux"].unit, sed)
Expand Down
25 changes: 25 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,28 @@ def test_diagnostic_plots(sampler):
def test_diagnostic_plots_noblobs(noblob_sampler):
# Diagnostic plots
save_diagnostic_plots("test_function_noblob", noblob_sampler)


@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE")
def test_calc_ML_uses_dense_energy_grid(sampler):
"""Regression test for #240: _calc_ML should evaluate over a dense grid."""
from naima.plot import _calc_ML

data = sampler.data
e_range = data["energy"][[0, -1]] * np.array((1.0 / 3.0, 3.0))

ML, MLp, MLerr, (modelx, model_ML) = _calc_ML(sampler, 0, e_range=e_range)

# Model should be evaluated over a dense logspace grid, not the sparse
# data energy points
assert len(modelx) == 100
assert len(modelx) != len(data["energy"])

# Re-evaluate directly over the same grid and check consistency
eval_data = {
"energy": modelx,
"flux": np.zeros(modelx.shape) * data["flux"].unit,
}
direct_out = sampler.modelfn(MLp, eval_data)
direct_flux = direct_out[0] if isinstance(direct_out, (tuple, list)) else direct_out
np.testing.assert_allclose(model_ML.value, direct_flux.value, rtol=1e-10)
Loading