Skip to content

Extending scatter! to work with CUDA sparse arrays (retry of #648)#672

Open
alonsoC1s wants to merge 11 commits intoFluxML:masterfrom
alonsoC1s:master
Open

Extending scatter! to work with CUDA sparse arrays (retry of #648)#672
alonsoC1s wants to merge 11 commits intoFluxML:masterfrom
alonsoC1s:master

Conversation

@alonsoC1s
Copy link
Copy Markdown
Contributor

This is a new attempt at PR #648

It extends the signature of scatter! to work with AbstractCuSparseArray, a CUDA array type notably excluded by the original method. With the proposed patch, calling scatter! with sparse arrays from CUDA.CUSPARSE will correctly call the CUDA-specialized method instead of calling the generic CPU method, which triggered a scalar indexing error.

@alonsoC1s
Copy link
Copy Markdown
Contributor Author

Quick update. The test failures were significant, and I underestimated what needed to be done. I'll mark the PR as draft while I work on JuliaGPU/GPUArrays.jl#694, which will hopefully make finishing this PR for all of the different sparse formats more feasible

@alonsoC1s alonsoC1s marked this pull request as draft March 27, 2026 11:01
@alonsoC1s
Copy link
Copy Markdown
Contributor Author

alonsoC1s commented Apr 14, 2026

The implementation of scatter! and friends should now work with GPU sparse device arrays as a source, though not as a destination, since that would potentially require allocating in device code. I recreated the standard testing suite and specialized it so that it makes sense in the sparse case

This PR is heavily dependant on a related PR to GPUArrays.jl being merged, and shouldn't make its way into main until the other one is succesfully merged and a new release is made

@alonsoC1s alonsoC1s marked this pull request as ready for review April 15, 2026 07:16
Comment thread ext/NNlibCUDAExt/scatter.jl Outdated
@alonsoC1s
Copy link
Copy Markdown
Contributor Author

Looking at the CI failures, I can't tell if they are related. They seem to happen only when doing gradients with Enzyme, and some don't even fail for scatter!, but other functions. While testing locally, it all worked as expected

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants