Skip to content

Commit 7c111a7

Browse files
Added test for testing plotting
1 parent 4ec2e0b commit 7c111a7

2 files changed

Lines changed: 79 additions & 2 deletions

File tree

pypfopt/plotting.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,24 @@ def plot_covariance(
103103
cax = ax.imshow(matrix)
104104
fig.colorbar(cax)
105105

106+
# if show_tickers:
107+
# ax.set_xticks(np.arange(0, matrix.shape[0], 1))
108+
# ax.set_xticklabels(matrix.index)
109+
# ax.set_yticks(np.arange(0, matrix.shape[0], 1))
110+
# ax.set_yticklabels(matrix.index)
111+
# plt.xticks(rotation=90)
106112
if show_tickers:
107113
ax.set_xticks(np.arange(0, matrix.shape[0], 1))
108-
ax.set_xticklabels(matrix.index)
114+
# Handle both DataFrame and ndarray for tick labels
115+
if hasattr(matrix, "index"):
116+
labels = matrix.index
117+
else:
118+
# For numpy array, create generic labels
119+
labels = [f"Asset {i + 1}" for i in range(matrix.shape[0])]
120+
121+
ax.set_xticklabels(labels)
109122
ax.set_yticks(np.arange(0, matrix.shape[0], 1))
110-
ax.set_yticklabels(matrix.index)
123+
ax.set_yticklabels(labels)
111124
plt.xticks(rotation=90)
112125

113126
# Optional: overlay numeric values on each cell

tests/test_plotting.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,67 @@ def test_plot_efficient_frontier():
416416
ef = setup_efficient_frontier()
417417
ef.min_volatility()
418418
optimal_ret, optimal_risk, _ = ef.portfolio_performance(risk_free_rate=0.02)
419+
420+
421+
@pytest.mark.skipif(
422+
not _check_soft_dependencies(["matplotlib"], severity="none"),
423+
reason="skip test if matplotlib is not installed in environment",
424+
)
425+
def test_plot_covariance_show_values():
426+
import matplotlib.pyplot as plt
427+
import numpy as np
428+
import pandas as pd
429+
430+
# Simple 3x3 covariance matrix
431+
cov_data = np.array(
432+
[[0.04, 0.01, 0.002], [0.01, 0.09, 0.003], [0.002, 0.003, 0.16]]
433+
)
434+
tickers = ["A", "B", "C"]
435+
df = pd.DataFrame(cov_data, index=tickers, columns=tickers)
436+
437+
def count_texts(ax):
438+
return len([obj for obj in ax.findobj() if obj.__class__.__name__ == "Text"])
439+
440+
# Test with ndarray input, show_values=False (baseline)
441+
plt.figure()
442+
ax = plotting.plot_covariance(cov_data, showfig=False)
443+
baseline_texts = count_texts(ax)
444+
plt.clf()
445+
plt.close()
446+
447+
# Test with ndarray input, show_values=True
448+
plt.figure()
449+
ax = plotting.plot_covariance(cov_data, show_values=True, showfig=False)
450+
with_values_texts = count_texts(ax)
451+
plt.clf()
452+
plt.close()
453+
454+
# Expect more text annotations when show_values=True
455+
assert with_values_texts > baseline_texts
456+
457+
# Test with DataFrame input, show_values=False
458+
plt.figure()
459+
ax = plotting.plot_covariance(df, showfig=False)
460+
baseline_texts_df = count_texts(ax)
461+
plt.clf()
462+
plt.close()
463+
464+
# Test with DataFrame input, show_values=True
465+
plt.figure()
466+
ax = plotting.plot_covariance(df, show_values=True, showfig=False)
467+
with_values_texts_df = count_texts(ax)
468+
plt.clf()
469+
plt.close()
470+
471+
assert with_values_texts_df > baseline_texts_df
472+
473+
# Ensure saving still works
474+
with tempfile.TemporaryDirectory() as tmpdir:
475+
fname = f"{tmpdir}/cov_plot.png"
476+
ax = plotting.plot_covariance(
477+
df, show_values=True, filename=fname, showfig=False
478+
)
479+
assert os.path.exists(fname)
480+
assert os.path.getsize(fname) > 0
481+
plt.clf()
482+
plt.close()

0 commit comments

Comments
 (0)