Type Error in draw_lines When Generating Point Cloud with generate_point_cloud_figure_mvsplat #100
Description
I encountered a jaxtyping type-check error while generating 3D point clouds with the generate_point_cloud_figure_mvsplat script. The check fails on the start and end arguments passed to draw_lines, which are 2D points.
Steps to Reproduce
I ran the following command:
python -m src.paper.generate_point_cloud_figure_mvsplat +experiment=re10k checkpointing.load=checkpoints/re10k.ckpt mode=test dataset/view_sampler=evaluation
Error Message
Error executing job with overrides: ['+experiment=re10k', 'checkpointing.load=checkpoints/re10k.ckpt', 'mode=test', 'dataset/view_sampler=evaluation']
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of src.visualization.drawing.lines.draw_lines.
The problem arose whilst typechecking parameter 'start'.
Actual value: f32[1,2](torch)
Expected type: typing.Union[float, int, Iterable[Union[float, int]], Shaped[Tensor, '3'], Shaped[Tensor, 'batch 3']].
----------------------
Called with parameters: {
  'image': f32[3,1024,1024](torch),
  'start': f32[1,2](torch),
  'end': f32[1,2](torch),
  'color': (1, 1, 1),
  'width': 1.8,
  'cap': 'round',
  'num_msaa_passes': 1,
  'x_range': (0, 1),
  'y_range': (0, 1)
}
Relevant Code
The error occurs at line 339 in generate_point_cloud_figure_mvsplat.py:
alpha = draw_lines(
    image=image,
    start=start,  # This is a tensor with shape [1,2]
    end=end,  # This is a tensor with shape [1,2]
    color=(1, 1, 1),
    width=1.8,
    cap='round',
    num_msaa_passes=1,
    x_range=(0, 1),
    y_range=(0, 1)
)
According to the expected type in the error message, the start and end parameters of draw_lines must be one of:
- A single float/int
- An iterable of floats/ints
- A tensor with shape [3] (presumably a single 3D point)
- A tensor with shape [batch, 3] (presumably a batch of 3D points)
The point cloud script, however, passes float32 tensors with shape [1, 2], i.e. a batch of one 2D point. The last dimension is 2 rather than 3, so none of the tensor options match and the type check fails.
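To make the mismatch concrete, here is a simplified, stdlib-only model of the tensor members of the expected type. It is not jaxtyping's implementation: it only compares a shape tuple against the dimension strings "3" and "batch 3" from the error message (treating a named dimension such as batch as any size), and it ignores dtype as well as the float/int and Iterable members. The helper names matches_dims and accepts_vector are hypothetical.

```python
from typing import Tuple

# Dimension strings copied from the tensor members of the expected type in the
# error message: Shaped[Tensor, '3'] and Shaped[Tensor, 'batch 3'].
VECTOR_TENSOR_PATTERNS = ("3", "batch 3")


def matches_dims(shape: Tuple[int, ...], dims: str) -> bool:
    """Return True if `shape` fits a space-separated dimension string.

    Integer tokens must match exactly; named tokens (e.g. 'batch') accept any
    size. Simplified model: no variadic ('*batch') or broadcastable ('#batch')
    dimensions, and named dimensions are not bound across arguments.
    """
    tokens = dims.split()
    if len(tokens) != len(shape):
        return False
    for token, size in zip(tokens, shape):
        if token.isdigit() and int(token) != size:
            return False
    return True


def accepts_vector(shape: Tuple[int, ...]) -> bool:
    """Return True if a tensor of this shape matches any tensor member of the union."""
    return any(matches_dims(shape, dims) for dims in VECTOR_TENSOR_PATTERNS)
```

For example, accepts_vector((1, 3)) returns True, while accepts_vector((1, 2)), the shape reported for start and end, returns False: it has the right rank for "batch 3", but its last dimension is 2.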
Questions
- Is there a specific reason why the start and end annotations of draw_lines only accept 3-element vectors, even though this script passes 2D points?
- What's the recommended approach to fix this issue? Should I:
  - Modify the point cloud generation code to provide 3D points (by adding a Z coordinate)?
  - Update the type annotations in draw_lines to accept 2D points as well?
  - Is there another function I should be using for 2D line drawing in this context?
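For comparison, the first two options can be sketched on plain Python lists rather than torch tensors. pad_z and accepts_points are hypothetical helpers, not functions from the mvsplat repository. One caveat: the call passes x_range and y_range but no Z range, which suggests draw_lines works in 2D image coordinates. If so, padding a Z coordinate may satisfy the annotation without matching what the function body expects, and widening the annotation (in tensor terms, accepting Shaped[Tensor, '2'] and Shaped[Tensor, 'batch 2'] for start and end) would be the more natural fix. This is an inference from the call site, not something verified against draw_lines itself.

```python
from typing import List, Sequence, Tuple


def pad_z(points: Sequence[Sequence[float]], z: float = 0.0) -> List[List[float]]:
    """Option 1 (hypothetical helper): lift [batch, 2] points to [batch, 3]."""
    padded = []
    for point in points:
        if len(point) != 2:
            raise ValueError(f"expected 2D points, got length {len(point)}")
        # Append a constant Z coordinate to each (x, y) point.
        padded.append([float(point[0]), float(point[1]), z])
    return padded


def accepts_points(shape: Tuple[int, ...], dim: int) -> bool:
    """Option 2 (hypothetical): a shape check parameterized by point dimension.

    Equivalent to accepting both a single point of shape [dim] and a batch of
    points of shape [batch, dim], instead of hard-coding dim = 3.
    """
    if len(shape) == 1:
        return shape[0] == dim
    if len(shape) == 2:
        return shape[1] == dim
    return False
```

With these helpers, pad_z([[0.25, 0.75]]) returns [[0.25, 0.75, 0.0]], and accepts_points((1, 2), dim=2) returns True while accepts_points((1, 2), dim=3) returns False.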
Environment
Python 3.10
Using the mvsplat environment as specified in the repository
Thank you for your assistance!