Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
aec99f0
add hvp
pierrenodet Feb 2, 2026
ad013e5
remove trace estimation algorithm (do that in a second PR)
pierrenodet Feb 3, 2026
23364ee
implicit mmap for hessian
pierrenodet Feb 4, 2026
d5db078
Merge remote-tracking branch 'origin' into block-cg
pierrenodet Feb 4, 2026
c8397de
block_cg
pierrenodet Feb 6, 2026
e59da60
func jacobian backend
pierrenodet Feb 6, 2026
6fba6f8
better docstring for block_cg
pierrenodet Feb 6, 2026
43ce6f3
update api func jacobian backend
pierrenodet Feb 8, 2026
ca8fb84
detach model weights for reduced memory footprint (don't allow to tak…
pierrenodet Feb 8, 2026
28324eb
even more tests
pierrenodet Feb 9, 2026
03189d8
Merge remote-tracking branch 'origin' into hess-imp-mmap
pierrenodet Feb 9, 2026
9a35678
remove compile option as i don't really know what it does
pierrenodet Feb 9, 2026
a99ceba
Merge branch 'hess-imp-mmap' into block-cg
pierrenodet Feb 10, 2026
f6f0109
more tests
pierrenodet Feb 10, 2026
9d3b5d0
Merge branch 'block-cg' into func-jac-backend
pierrenodet Feb 10, 2026
7af6799
add .DS_Store to .gitignore
pierrenodet Feb 18, 2026
fcd64a5
small fixes in block_cg
pierrenodet Feb 18, 2026
d9687e3
backport some changes from torch-func-jac-backend
pierrenodet Feb 18, 2026
e3557e9
Merge branch 'main' into block-cg
pierrenodet Mar 31, 2026
ef0f8f1
remove regul in solve_fmat
pierrenodet Mar 31, 2026
64dcbc7
check unsupported ops in pmat implicit
pierrenodet Mar 31, 2026
cf8db0c
ops
pierrenodet Mar 31, 2026
9da8066
code cov
pierrenodet Mar 31, 2026
da250e6
codecov padding
pierrenodet Mar 31, 2026
9bcc022
codecov padding continued
pierrenodet Mar 31, 2026
68c04c6
codecov padding again
pierrenodet Apr 1, 2026
f57c2dd
Merge branch 'block-cg' into func-jac-backend
pierrenodet Apr 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
__pycache__
*.swp
env
uv.lock
uv.lock
.DS_Store
.vscode
10 changes: 8 additions & 2 deletions nngeometry/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from .dummy import DummyGenerator
from .torch_hooks.torch_hooks import TorchHooksJacobianBackend
from .torch_func_hessian import TorchFuncHessianBackend
from .torch_func_jacobian import TorchFuncJacobianBackend
from .torch_hooks.torch_hooks import TorchHooksJacobianBackend

__all__ = ["TorchHooksJacobianBackend", "DummyGenerator", "TorchFuncHessianBackend"]
__all__ = [
"TorchHooksJacobianBackend",
"DummyGenerator",
"TorchFuncHessianBackend",
"TorchFuncJacobianBackend",
]
26 changes: 13 additions & 13 deletions nngeometry/backend/torch_func_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,7 @@ def hvp(func, primals, tangents):


def batched_hvp(func, primals, batched_tangents):
return torch.vmap(
lambda tangents: torch.func.jvp(
lambda p: torch.func.grad(func)(p),
primals=(primals,),
tangents=(tangents,),
)[1]
)(batched_tangents)
return torch.vmap(lambda tangents: hvp(func, primals, tangents))(batched_tangents)


class TorchFuncHessianBackend(AbstractBackend):
Expand All @@ -41,18 +35,20 @@ def get_covariance_matrix(self, examples, layer_collection):
n_parameters = layer_collection.numel()
H = torch.zeros((n_parameters, n_parameters), device=device, dtype=dtype)

def compute_loss(params, X, y):
prediction = torch.func.functional_call(self.model, params, (X,))
return self.function(prediction, y)
def compute_loss(params, inputs, targets):
prediction = torch.func.functional_call(self.model, params, (inputs,))
return self.function(prediction, targets)

params_dict = dict(layer_collection.named_parameters(layerid_to_mod))
params_dict = {k: v.detach() for k, v in params_dict.items()}

for d in self._get_iter_loader(loader):
inputs = d[0].to(device)
targets = d[1].to(device)

H_mb = torch.func.hessian(compute_loss)(
params_dict, inputs, d[1].to(device)
)
H_mb = torch.func.hessian(
partial(compute_loss, inputs=inputs, targets=targets),
)(params_dict)

for layer_id_x, layer_x in layer_collection.layers.items():
start_x = layer_collection.p_pos[layer_id_x]
Expand Down Expand Up @@ -123,6 +119,8 @@ def compute_loss(params, inputs, targets):
return self.function(prediction, targets)

params_dict = dict(layer_collection.named_parameters(layerid_to_mod))
params_dict = {k: v.detach() for k, v in params_dict.items()}

v_dict = {} # replace with function in PVector ?
for key, value in v.to_dict().items():
if len(value) > 1:
Expand Down Expand Up @@ -171,6 +169,8 @@ def compute_loss(params, inputs, targets):
so, sb, *_ = pfmap.size()

params_dict = dict(layer_collection.named_parameters(layerid_to_mod))
params_dict = {k: v.detach() for k, v in params_dict.items()}

pfmap_dict = {}
for layer_id, layer in layer_collection.layers.items():
d = pfmap.to_torch_layer(layer_id)
Expand Down
153 changes: 153 additions & 0 deletions nngeometry/backend/torch_func_jacobian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from functools import partial

import torch

from nngeometry.object.map import PFMapDense
from nngeometry.object.vector import PVector

from ._backend import AbstractBackend


def fvp(func, primals, tangents):
    """Fisher/Gauss-Newton vector product J^T (J v) for ``func`` at ``primals``.

    The tangent vector is first pushed forward through the Jacobian of
    ``func`` (a jvp), and the resulting output-space vector is pulled back
    through the transposed Jacobian (a vjp) — so J^T J v is computed
    without ever materializing J.
    """
    _, pushed_forward = torch.func.jvp(
        func,
        primals=(primals,),
        tangents=(tangents,),
    )
    _, pullback = torch.func.vjp(func, primals)
    return pullback(pushed_forward)[0]


def batched_fvp(func, primals, batched_tangents):
    """Apply :func:`fvp` over a leading batch axis of ``batched_tangents``
    using ``torch.vmap``."""

    def single_fvp(tangent):
        return fvp(func, primals, tangent)

    return torch.vmap(single_fvp)(batched_tangents)


class TorchFuncJacobianBackend(AbstractBackend):
    """Backend computing implicit products with the Jacobian-based metric
    G = 1/n * sum_i J_i^T J_i via ``torch.func`` transforms (jvp/vjp),
    without ever materializing G.

    ``function`` maps model predictions (and optionally targets) to the
    quantity whose Jacobian defines the metric; it is called either as
    ``function(predictions)`` or ``function(predictions, targets)``
    depending on whether the dataloader yields targets.
    """

    def __init__(self, model, function, verbose=False):
        # model: the torch.nn.Module whose parameters define the metric space.
        # function: callable applied to the model output (see class docstring).
        # verbose: stored but not read anywhere in this class as written.
        self.model = model
        self.function = function
        self.verbose = verbose

    def implicit_mv(self, v, examples, layer_collection):
        """Implicit matrix-vector product G v averaged over ``examples``.

        v: a PVector over ``layer_collection``.
        examples: a dataset/dataloader accepted by ``self._get_dataloader``.
        Returns a PVector with the same layer structure as ``v``.
        """
        layerid_to_mod = layer_collection.get_layerid_module_map(self.model)
        device = self._check_same_device(layerid_to_mod.values())

        loader = self._get_dataloader(examples)
        # NOTE(review): assumes loader.sampler exposes len() — the true
        # number of examples, not the number of minibatches.
        n_examples = len(loader.sampler)

        def function(params, inputs, targets=None):
            # Functional forward pass: evaluates the model with `params`
            # substituted for its registered parameters.
            predictions = torch.func.functional_call(self.model, params, (inputs,))
            if targets is None:
                return self.function(predictions)
            else:
                return self.function(predictions, targets)

        params_dict = dict(layer_collection.named_parameters(layerid_to_mod))
        # Detach so no grad graph is kept through the parameters themselves.
        # NOTE(review): the comprehension variable `v` shadows the argument
        # `v` only inside the comprehension's own scope — the outer PVector
        # is untouched — but a rename would be safer against refactors.
        params_dict = {k: v.detach() for k, v in params_dict.items()}

        v_dict = {}  # replace with function in PVector ?
        for key, value in v.to_dict().items():
            # Per-layer tuples: (weight, bias) when a bias exists, else (weight,).
            if len(value) > 1:
                v_dict[key + ".weight"] = value[0]
                v_dict[key + ".bias"] = value[1]
            else:
                v_dict[key + ".weight"] = value[0]

        # Accumulator for the running sum of per-minibatch products.
        fvp_dict = {k: torch.zeros_like(p) for k, p in params_dict.items()}

        for d in self._get_iter_loader(loader):
            inputs = d[0].to(device)
            if len(d) > 1:
                targets = d[1].to(device)
            else:
                targets = None

            # J^T J v on this minibatch; inputs/targets are bound so that
            # fvp differentiates w.r.t. the params dict only.
            fvp_mb = fvp(
                partial(function, inputs=inputs, targets=targets), params_dict, v_dict
            )

            for k in fvp_mb:
                fvp_dict[k] += fvp_mb[k].detach()

        # Sum over minibatches -> mean over examples.
        # NOTE(review): this is exact only if `function` reduces over the
        # batch with a sum (not a mean) — confirm against callers.
        for k in fvp_dict:
            fvp_dict[k] /= n_examples

        # Repack flat "<layer>.weight"/"<layer>.bias" entries into the
        # per-layer tuple representation expected by PVector.
        output_dict = dict()
        for layer_id, layer in layer_collection.layers.items():
            if layer.has_bias():
                output_dict[layer_id] = (
                    fvp_dict[layer_id + ".weight"],
                    fvp_dict[layer_id + ".bias"],
                )
            else:
                output_dict[layer_id] = (fvp_dict[layer_id + ".weight"],)

        return PVector(layer_collection, dict_repr=output_dict)

    def implicit_mmap(self, pfmap, examples, layer_collection):
        """Implicit matrix-matrix product G M for a batch of vectors.

        pfmap: a PFMap whose leading dims (so, sb) index the so*sb columns;
        each column is treated as an independent tangent and the products
        are vectorized with ``torch.vmap`` (via ``batched_fvp``).
        Returns a PFMapDense with the same (so, sb) leading layout.
        """
        layerid_to_mod = layer_collection.get_layerid_module_map(self.model)
        device = self._check_same_device(layerid_to_mod.values())

        loader = self._get_dataloader(examples)
        n_examples = len(loader.sampler)

        def function(params, inputs, targets):
            # Same functional forward as in implicit_mv, but `targets` has
            # no default here — it is always bound through partial below.
            predictions = torch.func.functional_call(self.model, params, (inputs,))
            if targets is None:
                return self.function(predictions)
            else:
                return self.function(predictions, targets)

        so, sb, *_ = pfmap.size()

        params_dict = dict(layer_collection.named_parameters(layerid_to_mod))
        params_dict = {k: v.detach() for k, v in params_dict.items()}

        # Flatten the map to per-parameter batched tangents of shape
        # (so*sb, *param.shape) so vmap can map over the leading axis.
        pfmap_dict = {}
        for layer_id, layer in layer_collection.layers.items():
            d = pfmap.to_torch_layer(layer_id)
            if layer.has_bias():
                pfmap_dict[layer_id + ".weight"] = d[0].view(-1, *layer.weight.size)
                pfmap_dict[layer_id + ".bias"] = d[1].view(-1, *layer.bias.size)
            else:
                pfmap_dict[layer_id + ".weight"] = d[0].view(-1, *layer.weight.size)

        # Batched accumulator, one (so*sb, ...) buffer per parameter tensor.
        b_fvp_dict = {
            k: torch.zeros((so * sb, *p.shape), dtype=p.dtype, device=p.device)
            for k, p in params_dict.items()
        }

        for d in self._get_iter_loader(loader):
            inputs = d[0].to(device)
            if len(d) > 1:
                targets = d[1].to(device)
            else:
                targets = None

            b_fvp_mb = batched_fvp(
                partial(function, inputs=inputs, targets=targets),
                params_dict,
                pfmap_dict,
            )

            for k in b_fvp_mb:
                b_fvp_dict[k] += b_fvp_mb[k].detach()

        # Sum over minibatches -> mean over examples (see note in implicit_mv).
        for k in b_fvp_dict:
            b_fvp_dict[k] /= n_examples

        # Restore the (so, sb, flattened-param) layout expected by PFMapDense.
        output_dict = dict()
        for layer_id, layer in layer_collection.layers.items():
            if layer.has_bias():
                output_dict[layer_id] = (
                    b_fvp_dict[layer_id + ".weight"].view(so, sb, -1),
                    b_fvp_dict[layer_id + ".bias"].view(so, sb, -1),
                )
            else:
                output_dict[layer_id] = (
                    b_fvp_dict[layer_id + ".weight"].view(so, sb, -1),
                )

        return PFMapDense.from_dict(
            generator=None, data_dict=output_dict, layer_collection=layer_collection
        )
31 changes: 20 additions & 11 deletions nngeometry/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import torch
from torch.distributions.multivariate_normal import MultivariateNormal

from .backend import TorchHooksJacobianBackend
from .backend import TorchFuncJacobianBackend, TorchHooksJacobianBackend
from .layercollection import LayerCollection
from .object.pspace import PMatImplicit


def FIM_MonteCarlo(
Expand Down Expand Up @@ -94,7 +95,6 @@ def fim_function(*d):
return trials**-0.5 * sampled_targets

elif variant == "regression":

if "covariance" in kwargs:
sigma_2 = kwargs["covariance"]
else:
Expand Down Expand Up @@ -173,19 +173,28 @@ def FIM(
An optional layer collection
"""

if function is None:

def function(*d):
return model(d[0].to(device))

if layer_collection is None:
layer_collection = LayerCollection.from_model(model)

function_fim = partial(SQRT_VAR[variant], function)
if representation == PMatImplicit:

generator = TorchHooksJacobianBackend(
model=model, function=function_fim, verbose=verbose
)
def function_fim(*d):
return SQRT_VAR[variant](lambda predictions, _: predictions, *d)

generator = TorchFuncJacobianBackend(
model=model, function=function_fim, verbose=verbose
)

else:
if function is None:

def function(*d):
return model(d[0].to(device))

function_fim = partial(SQRT_VAR[variant], function)
generator = TorchHooksJacobianBackend(
model=model, function=function_fim, verbose=verbose
)

return representation(
generator=generator,
Expand Down
69 changes: 62 additions & 7 deletions nngeometry/object/fspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,30 @@

import torch

from .map import PFMap, PFMapDense
from .vector import FVector, PVector


class FMatAbstract(ABC):
@abstractmethod
def __init__(self, generator):
return NotImplementedError
pass

def __matmul__(self, other):
    """Overload ``self @ other``: dispatch to ``mv`` for an FVector operand
    and to ``mmap`` for a PFMap operand; any other type is unsupported."""
    if isinstance(other, FVector):
        return self.mv(other)
    if isinstance(other, PFMap):
        return self.mmap(other)
    return NotImplemented

def __rmatmul__(self, other):
    """Overload ``other @ self``: implemented by applying the transpose
    (``self.T``) to the left operand — ``mv`` for FVector, ``mmap`` for
    PFMap; any other type is unsupported."""
    if isinstance(other, FVector):
        return self.T.mv(other)
    if isinstance(other, PFMap):
        return self.T.mmap(other)
    return NotImplemented


class FMatDense(FMatAbstract):
Expand All @@ -32,10 +49,30 @@ def compute_eigendecomposition(self, impl="eigh"):
raise NotImplementedError

def mv(self, v):
# TODO: test
v_flat = torch.mv(self.data, v.to_torch())
v_flat = v.to_torch().view(-1)
s = self.data.size()
v_flat = torch.mv(self.data.view(s[0] * s[1], s[2] * s[3]), v_flat)
return FVector(vector_repr=v_flat)

def mmap(self, pfmap):
    """Multiply this dense F-matrix by a parameter-function map.

    Flattens the 4-D data tensor into a (s0*s1, s2*s3) matrix, multiplies
    it by the flattened map, and reshapes the result back to the
    (s0, s1, columns) layout of a PFMapDense.
    """
    shape = self.data.size()
    lhs = self.data.view(shape[0] * shape[1], shape[2] * shape[3])
    rhs = pfmap.to_torch().view(shape[2] * shape[3], -1)
    product = torch.mm(lhs, rhs).view(shape[0], shape[1], -1)
    return PFMapDense(
        self.layer_collection,
        self.generator,
        data=product,
    )

@property
def T(self):
    # Transpose of the dense F-matrix: the 4-D data tensor's two leading
    # axes are swapped with its two trailing axes (permute (2, 3, 0, 1)).
    # NOTE(review): assumes data is laid out as (so, sb, so', sb') pairs —
    # confirm against FMatDense's constructor/usage. The permuted tensor is
    # a view; no copy is made.
    return FMatDense(
        layer_collection=self.layer_collection,
        generator=self.generator,
        data=self.data.permute(2, 3, 0, 1),
    )

def vTMv(self, v):
v_flat = v.to_torch().view(-1)
sd = self.data.size()
Expand Down Expand Up @@ -82,11 +119,29 @@ def to_torch(self):
return self.data

def __add__(self, other):
# TODO: test
sum_data = self.data + other.data
return FMatDense(generator=self.generator, data=sum_data)
return FMatDense(
layer_collection=self.layer_collection,
generator=self.generator,
data=sum_data,
)

def __sub__(self, other):
# TODO: test
sub_data = self.data - other.data
return FMatDense(generator=self.generator, data=sub_data)
return FMatDense(
layer_collection=self.layer_collection,
generator=self.generator,
data=sub_data,
)

def __rmul__(self, other):
    """Left scalar multiplication (``a * M``): returns a new FMatDense
    whose data is the elementwise product ``other * self.data``."""
    scaled = other * self.data
    return FMatDense(
        layer_collection=self.layer_collection,
        generator=self.generator,
        data=scaled,
    )

def __imul__(self, other):
    # In-place scaling (``M *= a``). Deliberately mutates self.data's
    # underlying storage, so any alias of the tensor observes the change;
    # returns self per the in-place operator protocol.
    self.data *= other
    return self
Loading
Loading