
Deprecate special handling of shared variables in OpFromGraph #2035

@ricardoV94

Description

Under `strict=False`, users can leave shared variables out of the `OpFromGraph` inputs; only `strict=True` rejects an inner graph that depends on them implicitly:

```python
if strict and self.shared_inputs:
    raise ValueError(
        "All variables needed to compute inner-graph must be provided as inputs under strict=True. "
        f"The inner-graph implicitly depends on the following shared variables {self.shared_inputs}"
    )
```
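To make "implicitly depends on shared variables" concrete, here is a toy, pure-Python model of what the check above guards against. It is not PyTensor code: `Var`, `implicit_shared_inputs` and `ToyOpFromGraph` are hypothetical stand-ins, where a `Var` with `shared=True` plays the role of a `SharedVariable` and `owner_inputs` plays the role of an apply node's inputs.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # eq=False keeps identity-based hashing, like graph variables
class Var:
    name: str
    shared: bool = False  # stands in for a SharedVariable
    owner_inputs: tuple = ()  # non-empty means "computed from these variables"


def implicit_shared_inputs(inputs, outputs):
    """Shared leaves reachable from `outputs` that are not declared in `inputs`."""
    declared = set(inputs)
    found, seen, stack = [], set(), list(outputs)
    while stack:
        var = stack.pop()
        if var in seen or var in declared:
            continue
        seen.add(var)
        if var.owner_inputs:
            stack.extend(var.owner_inputs)
        elif var.shared:
            found.append(var)
    return found


class ToyOpFromGraph:
    def __init__(self, inputs, outputs, strict=False):
        self.shared_inputs = implicit_shared_inputs(inputs, outputs)
        if strict and self.shared_inputs:
            names = [v.name for v in self.shared_inputs]
            raise ValueError(f"inner graph implicitly depends on shared variables {names}")
        # Under strict=False the captured outer variables become hidden trailing inputs
        self.inner_inputs = list(inputs) + self.shared_inputs
        self.outputs = list(outputs)


x = Var("x")
w = Var("w", shared=True)
y = Var("y", owner_inputs=(x, w))  # think of y as x * w

op = ToyOpFromGraph([x], [y])  # accepted: w is captured silently
print([v.name for v in op.inner_inputs])  # ['x', 'w']

try:
    ToyOpFromGraph([x], [y], strict=True)
except ValueError as err:
    print(err)

w_in = Var("w_in")
y_explicit = Var("y", owner_inputs=(x, w_in))
ToyOpFromGraph([x, w_in], [y_explicit], strict=True)  # fine: nothing implicit
```

The last construction is the pattern `strict=True` asks for: the would-be shared value becomes an ordinary input, and the caller passes the actual shared variable at call time.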

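This puts references to outer-graph variables inside the Op, which is a recipe for disaster. `__call__` has to append the captured shared variables when the caller leaves them out, and `make_node` has to rebuild the whole Op whenever different variables are passed in their place: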

```python
def __call__(self, *inputs, **kwargs):
    # The user interface doesn't expect the shared variable inputs of the
    # inner-graph, but, since `Op.make_node` does (and `Op.__call__`
    # dispatches to `Op.make_node`), we need to compensate here
    num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)
    if len(inputs) == num_expected_inps:
        actual_inputs = inputs + tuple(self.shared_inputs)
        return super().__call__(*actual_inputs, **kwargs)
    elif len(inputs) == len(self.inner_inputs):
        return super().__call__(*inputs, **kwargs)
    else:
        raise ValueError(f"Expected at least {num_expected_inps} input(s)")
```

```python
# Excerpt from `OpFromGraph.make_node`
if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
    # The shared variables are not equal to the original shared
    # variables, so we construct a new `Op` that uses the new shared
    # variables instead.
    replace = dict(
        zip(
            self.inner_inputs[num_expected_inps:],
            new_shared_inputs,
            strict=True,
        )
    )
    # If the new shared variables are inconsistent with the inner-graph,
    # such errors should arise in this step
    new_inner_outputs = clone_replace(
        self.inner_outputs, replace=replace, copy_inputs_over=True
    )
    # It's possible that the new shared variable inputs aren't actually
    # shared variables. When they aren't we need to add them as new
    # inputs.
    unshared_inputs = [
        inp for inp in new_shared_inputs if not isinstance(inp, SharedVariable)
    ]
    new_inner_inputs = self.inner_inputs[:num_expected_inps] + unshared_inputs
    new_op = type(self)(
        inputs=new_inner_inputs,
        outputs=new_inner_outputs,
        inline=self.is_inline,
        lop_overrides=self.lop_overrides,
        grad_overrides=self.grad_overrides,
        rop_overrides=self.rop_overrides,
        connection_pattern=self._connection_pattern,
        name=self.name,
        destroy_map=self.destroy_map,
        **self.kwargs,
    )
    new_inputs = (
        list(non_shared_inputs) + unshared_inputs + new_op.shared_inputs
    )
```
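The following toy mirrors the control flow of the two excerpts above in plain Python, to show what the captured references cost. `ImplicitSharedOp` is hypothetical and uses strings in place of graph variables; the real code also clones the inner graph and handles non-shared replacements, which this sketch only indicates with a comment.

```python
class ImplicitSharedOp:
    """Toy Op that remembers outer 'shared' variables, as strict=False does."""

    def __init__(self, n_explicit, shared_inputs):
        self.n_explicit = n_explicit
        self.shared_inputs = list(shared_inputs)  # references into the outer graph

    def __call__(self, *inputs):
        n_total = self.n_explicit + len(self.shared_inputs)
        if len(inputs) == self.n_explicit:
            # Caller omitted the shared inputs: silently append the captured ones
            inputs = inputs + tuple(self.shared_inputs)
        elif len(inputs) != n_total:
            raise ValueError(f"Expected {self.n_explicit} or {n_total} input(s)")
        return self.make_node(*inputs)

    def make_node(self, *inputs):
        new_shared = list(inputs[self.n_explicit:])
        op = self
        if new_shared != self.shared_inputs:
            # Different variables in the shared slots: the inner graph no longer
            # matches, so a whole new Op has to be built (real code clones the graph)
            op = ImplicitSharedOp(self.n_explicit, new_shared)
        return op, inputs  # stands in for an Apply node


op = ImplicitSharedOp(1, ["w"])
node_op, node_inputs = op("x")  # shared input appended behind the caller's back
print(node_op is op, node_inputs)  # True ('x', 'w')
node_op, node_inputs = op("x", "w2")  # replacing w forces a new Op
print(node_op is op, node_inputs)  # False ('x', 'w2')
```

With `strict=True` none of this is needed: the Op has a fixed arity, every input arrives explicitly, and swapping one shared variable for another is just a different argument to the same Op.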

We should emit a `FutureWarning` telling users to switch to `strict=True`, and eventually disallow the old behavior entirely.
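A minimal sketch of the proposed deprecation step, assuming it runs where the existing `strict` check does. `check_implicit_shared_inputs` is a hypothetical helper and the warning text is only a suggestion; the error message is the one quoted above.

```python
import warnings


def check_implicit_shared_inputs(shared_inputs, strict):
    """Hypothetical helper: error under strict=True, FutureWarning otherwise."""
    if not shared_inputs:
        return  # nothing captured implicitly, nothing to report
    if strict:
        raise ValueError(
            "All variables needed to compute inner-graph must be provided as inputs under strict=True. "
            f"The inner-graph implicitly depends on the following shared variables {shared_inputs}"
        )
    warnings.warn(
        "OpFromGraph inner-graph implicitly depends on the shared variables "
        f"{shared_inputs}. This is deprecated and will raise in a future release: "
        "pass them explicitly as inputs and use strict=True.",
        FutureWarning,
        stacklevel=2,  # attribute the warning to the caller of this helper
    )


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_implicit_shared_inputs(["w"], strict=False)
print(caught[0].category.__name__)  # FutureWarning
```

At the end of the deprecation period the warning branch would become the error, which is what would allow the compensation logic quoted above to be removed.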
