Skip to content

Commit e47af96

Browse files
committed
Local backend support
1 parent 2a571a8 commit e47af96

File tree

6 files changed

+82
-6
lines changed

6 files changed

+82
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ LLVM = "9.4.1"
4040
LinearAlgebra = "1.6"
4141
MacroTools = "0.5"
4242
PrecompileTools = "1"
43-
SPIRVIntrinsics = "0.5"
43+
SPIRVIntrinsics = "0.5.7"
4444
SPIRV_LLVM_Backend_jll = "20"
4545
SPIRV_Tools_jll = "2024.4, 2025.1"
4646
SparseArrays = "<0.0.1, 1.6"

src/pocl/backend.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ using ..POCL: device, clconvert, clfunction
77
import KernelAbstractions as KA
88
import KernelAbstractions.KernelIntrinsics as KI
99

10+
import SPIRVIntrinsics
11+
1012
import StaticArrays
1113

1214
import Adapt
@@ -174,10 +176,36 @@ end
174176
# Host-side query: the `max_work_group_size` property of the current device,
# converted to `Int`.
function KI.max_work_group_size(::POCLBackend)::Int
    dev = device()
    return Int(dev.max_work_group_size)
end
179+
# Host-side query: pick the sub-group size to use on the current device.
#
# The device's reported `sub_group_sizes` are searched for 32, then 64, then 16,
# and the first one present is returned. If none of them is reported, 1 is returned.
function KI.sub_group_size(::POCLBackend)::Int
    # Use `device()` like `KI.max_work_group_size` and `KI.multiprocessor_count`
    # do; it is the accessor this module imports from `..POCL`.
    sg_sizes = device().sub_group_sizes

    # Candidates in order of preference.
    for candidate in (32, 64, 16)
        candidate in sg_sizes && return candidate
    end
    return 1
end
177191
# Host-side query: the `max_compute_units` property of the current device,
# converted to `Int`.
function KI.multiprocessor_count(::POCLBackend)::Int
    dev = device()
    return Int(dev.max_compute_units)
end
180194

195+
# Host-side query: the element types reported as usable with `KI.shfl_down`.
#
# Starts from a copy of `SPIRVIntrinsics.gentypes` and removes `Float64` / `Float16`
# when the device does not list the `cl_khr_fp64` / `cl_khr_fp16` extension.
function KI.shfl_down_types(::POCLBackend)
    # Work on a copy so the shared `gentypes` collection is never modified.
    res = copy(SPIRVIntrinsics.gentypes)

    # Use `device()` like the other host-side queries in this file.
    backend_extensions = device().extensions

    # A type is dropped only when its extension is *missing* from the device.
    if !("cl_khr_fp64" in backend_extensions)
        res = setdiff(res, [Float64])
    end
    if !("cl_khr_fp16" in backend_extensions)
        res = setdiff(res, [Float16])
    end

    return res
end
208+
181209
## Indexing Functions
182210

183211
@device_override @inline function KI.get_local_id()
@@ -204,6 +232,16 @@ end
204232
return (; x = Int(get_global_size(1)), y = Int(get_global_size(2)), z = Int(get_global_size(3)))
205233
end
206234

235+
# Device-side sub-group queries. Each `KI` entry point forwards, without
# arguments, to the unqualified function of the same name visible in this module.
@device_override function KI.get_sub_group_size()
    return get_sub_group_size()
end

@device_override function KI.get_max_sub_group_size()
    return get_max_sub_group_size()
end

@device_override function KI.get_num_sub_groups()
    return get_num_sub_groups()
end

@device_override function KI.get_sub_group_id()
    return get_sub_group_id()
end

@device_override function KI.get_sub_group_local_id()
    return get_sub_group_local_id()
end
244+
207245
@device_override @inline function KA.__validindex(ctx)
208246
if KA.__dynamic_checkbounds(ctx)
209247
I = @inbounds KA.expand(KA.__iterspace(ctx), get_group_id(1), get_local_id(1))
@@ -232,6 +270,14 @@ end
232270
work_group_barrier(POCL.LOCAL_MEM_FENCE | POCL.GLOBAL_MEM_FENCE)
233271
end
234272

273+
# Device-side override of `KI.sub_group_barrier`: calls `sub_group_barrier` with
# both the local and the global memory-fence flags combined.
@device_override @inline function KI.sub_group_barrier()
    fence_flags = POCL.LOCAL_MEM_FENCE | POCL.GLOBAL_MEM_FENCE
    return sub_group_barrier(fence_flags)
end
276+
277+
# Device-side override of `KI.shfl_down`: forwards `val` to `sub_group_shuffle`,
# using this work-item's sub-group local id plus `offset` as the lane argument.
# NOTE(review): the lane index is not clamped to the sub-group size here;
# confirm what `sub_group_shuffle` does for an index past the last lane.
@device_override function KI.shfl_down(val::T, offset::Integer) where {T}
    source_lane = get_sub_group_local_id() + offset
    return sub_group_shuffle(val, source_lane)
end
280+
235281
# Device-side override of `KI._print`: hands all arguments to `POCL._print` unchanged.
@device_override @inline function KI._print(args...)
    return POCL._print(args...)
end

src/pocl/compiler/compilation.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
## gpucompiler interface
22

3-
struct OpenCLCompilerParams <: AbstractCompilerParams end
3+
# Compiler parameters carried by OpenCL compiler jobs.
#
# Fields (keyword-constructed via `Base.@kwdef`; no defaults, so every field is required):
#   sub_group_size  -- sub-group size requested for the compiled kernel; read back
#                      as `job.config.params.sub_group_size` when the module is finished.
Base.@kwdef struct OpenCLCompilerParams <: AbstractCompilerParams
    sub_group_size::Int
end
6+
47
const OpenCLCompilerConfig = CompilerConfig{SPIRVCompilerTarget, OpenCLCompilerParams}
58
const OpenCLCompilerJob = CompilerJob{SPIRVCompilerTarget, OpenCLCompilerParams}
69

@@ -19,7 +22,21 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
1922
in(fn, known_intrinsics) ||
2023
contains(fn, "__spirv_")
2124

25+
# Finish the LLVM module for an OpenCL job.
#
# First runs the `finish_module!` method for the less specific
# `CompilerJob{SPIRVCompilerTarget}` signature, then attaches the requested
# sub-group size to the entry function it returns. Returns that entry function.
function GPUCompiler.finish_module!(
        @nospecialize(job::OpenCLCompilerJob),
        mod::LLVM.Module, entry::LLVM.Function
    )
    # `invoke` selects the method by the given signature, so this method does not
    # call itself.
    parent_signature = Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function}
    entry = invoke(GPUCompiler.finish_module!, parent_signature, job, mod, entry)

    # Set the subgroup size
    requested_size = Int32(job.config.params.sub_group_size)
    metadata(entry)["intel_reqd_sub_group_size"] = MDNode([ConstantInt(requested_size)])

    return entry
end
2340
## compiler implementation (cache, configure, compile, and link)
2441

2542
# cache of compilation caches, per context
@@ -45,14 +62,17 @@ function compiler_config(dev::cl.Device; kwargs...)
4562
end
4663
return config
4764
end
48-
@noinline function _compiler_config(dev; kernel = true, name = nothing, always_inline = false, kwargs...)
65+
# Build the `CompilerConfig` for device `dev`.
#
# Keywords:
#   kernel, name, always_inline -- forwarded to `CompilerConfig`.
#   sub_group_size              -- stored in `OpenCLCompilerParams` (default 32).
#   kwargs...                   -- forwarded to `SPIRVCompilerTarget`.
@noinline function _compiler_config(dev; kernel = true, name = nothing, always_inline = false, sub_group_size = 32, kwargs...)
    supports_fp16 = "cl_khr_fp16" in dev.extensions
    supports_fp64 = "cl_khr_fp64" in dev.extensions

    # Report a sub-group size the device does not list. `@error` only logs, so the
    # configuration is still built with the requested value.
    if !(sub_group_size in dev.sub_group_sizes)
        @error("$sub_group_size is not a valid sub-group size for this device.")
    end

    # create GPUCompiler objects
    target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, kwargs...)
    params = OpenCLCompilerParams(; sub_group_size)
    return CompilerConfig(target, params; kernel, name, always_inline)
end
5878

src/pocl/compiler/execution.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ export @opencl, clfunction, clconvert
44
## high-level @opencl interface
55

66
const MACRO_KWARGS = [:launch]
7-
const COMPILER_KWARGS = [:kernel, :name, :always_inline]
7+
const COMPILER_KWARGS = [:kernel, :name, :always_inline, :sub_group_size]
88
const LAUNCH_KWARGS = [:global_size, :local_size, :queue]
99

1010
macro opencl(ex...)

src/pocl/nanoOpenCL.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ const CL_KERNEL_EXEC_INFO_SVM_PTRS = 0x11b6
390390

391391
const CL_KERNEL_EXEC_INFO_SVM_FINE_GRAIN_SYSTEM = 0x11b7
392392

393+
const CL_DEVICE_SUB_GROUP_SIZES_INTEL = 0x4108
394+
393395
struct CLError <: Exception
394396
code::Cint
395397
end
@@ -935,6 +937,14 @@ devices(p::Platform) = devices(p, CL_DEVICE_TYPE_ALL)
935937
return tuple([Int(r) for r in result]...)
936938
end
937939

940+
if s == :sub_group_sizes
941+
res_size = Ref{Csize_t}()
942+
clGetDeviceInfo(d, CL_DEVICE_SUB_GROUP_SIZES_INTEL, C_NULL, C_NULL, res_size)
943+
result = Vector{Csize_t}(undef, res_size[] ÷ sizeof(Csize_t))
944+
clGetDeviceInfo(d, CL_DEVICE_SUB_GROUP_SIZES_INTEL, sizeof(result), result, C_NULL)
945+
return tuple([Int(r) for r in result]...)
946+
end
947+
938948
if s == :max_image2d_shape
939949
width = Ref{Csize_t}()
940950
height = Ref{Csize_t}()

src/pocl/pocl.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function queue()
4141
end
4242

4343
using GPUCompiler
44-
import LLVM
44+
import LLVM: LLVM, MDNode, ConstantInt, metadata
4545
using Adapt
4646

4747
## device overrides

0 commit comments

Comments
 (0)