Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/array.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
buf::DataRef{Managed{B}}
dims::Dims{N}
offset::Int # Offset is in number of elements (not bytes).
offset::Int # offset of the data in memory, in bytes

function ROCArray{T, N, B}(::UndefInitializer, dims::Dims{N}) where {T, N, B <: Mem.AbstractAMDBuffer}
check_eltype("ROCArray", T)
Expand Down Expand Up @@ -92,7 +92,7 @@ GPUArrays.storage(a::ROCArray) = a.buf

# Build a `ROCArray{T, N}` of size `dims` on top of the buffer reference held by `x`.
#
# `offset` is the additional displacement requested by the caller; it is scaled by
# `aligned_sizeof(T)` below, i.e. it is treated as a count of `T` elements.
# `x.offset` is already stored in bytes, so it is added unscaled. Keeping the
# result in bytes means a displacement that is not a multiple of the new element
# size is preserved exactly instead of being truncated by an integer division.
function GPUArrays.derive(::Type{T}, x::ROCArray, dims::Dims{N}, offset::Int) where {N, T}
    ref = copy(x.buf)
    # bytes already skipped by `x` + bytes for `offset` elements of `T`
    offset = x.offset + offset * aligned_sizeof(T)
    return ROCArray{T, N}(ref, dims; offset)
end

Expand Down Expand Up @@ -298,7 +298,7 @@ Adapt.adapt_storage(::Float32Adaptor, xs::AbstractArray{Float16}) =
# Adapt `xs` through a freshly constructed `Float32Adaptor`.
function roc(xs)
    adaptor = Float32Adaptor()
    return adapt(adaptor, xs)
end

# Raw pointer to the start of `x`'s data: the buffer's base pointer (converted to
# `Ptr{T}`) advanced by `x.offset`. The offset is stored in bytes and `Ptr + Integer`
# advances by bytes, so it must not be multiplied by the element size here.
Base.unsafe_convert(typ::Type{Ptr{T}}, x::ROCArray{T}) where T =
    convert(typ, x.buf[]) + x.offset

# some nice utilities

Expand Down Expand Up @@ -351,7 +351,7 @@ function Base.convert(
buf = convert(Mem.AbstractAMDBuffer, a.buf[])
ptr = convert(Ptr{T}, typeof(buf) <: Mem.HIPBuffer ?
buf : buf.dev_ptr)
llvm_ptr = AMDGPU.LLVMPtr{T,AS.Global}(ptr + a.offset * aligned_sizeof(T))
llvm_ptr = AMDGPU.LLVMPtr{T,AS.Global}(ptr + a.offset)
ROCDeviceArray{T, N, AS.Global}(a.dims, llvm_ptr)
end

Expand Down
9 changes: 9 additions & 0 deletions test/core/rocarray_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ end
@test collect(c) == 3:4
end

@testset "reinterpret of view with non-aligned offset" begin
    # Reinterpret a view as a wider element type when the view's byte offset
    # is not a multiple of the wider element's size.
    src = ROCArray(Int32[1,2,3,4,5,6,7,8,9])
    window = view(src, 2:7)                # starts 1 Int32 (4 bytes) into `src`
    widened = reinterpret(Int64, window)   # 8-byte elements; 4 is not a multiple of 8
    # The host-side reinterpretation of the same slice is the reference result.
    @test Array(widened) == reinterpret(Int64, view(Array(src), 2:7))
end

@testset "resize!" begin
a_h = Array(range(1, 10))
a_d = a_h |> roc
Expand Down