Extending scatter! to work with CUDA sparse arrays (retry of #648)#672
Extending scatter! to work with CUDA sparse arrays (retry of #648)#672alonsoC1s wants to merge 11 commits intoFluxML:masterfrom
Conversation
|
Quick update. The test failures were significant, and I underestimated what needed to be done. I'll mark the PR as draft while I work on JuliaGPU/GPUArrays.jl#694, which will hopefully make finishing this PR for all of the different sparse formats more feasible |
|
The implementation of This PR is heavily dependant on a related PR to GPUArrays.jl being merged, and shouldn't make its way into main until the other one is succesfully merged and a new release is made |
|
Looking at the CI failures, I can't tell if they are related. They seem to happen only when doing gradients with Enzyme, and some don't even fail for |
This is a new attempt at PR #648
It extends the signature of scatter! to work with AbstractCuSparseArray, a CUDA array type notably excluded by the original method. With the proposed patch, calling scatter! with sparse arrays from CUDA.CUSPARSE will correctly call the CUDA-specialized method instead of calling the generic CPU method, which triggered a scalar indexing error.