-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy pathGPUArraysCore.jl
231 lines (186 loc) · 7.11 KB
/
GPUArraysCore.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
module GPUArraysCore
using Adapt
## essential types

# Public API of this module; every name is listed exactly once.
export AbstractGPUArray, AbstractGPUVector, AbstractGPUMatrix, AbstractGPUVecOrMat,
       WrappedGPUArray, AnyGPUArray, AnyGPUVector, AnyGPUMatrix,
       AbstractGPUArrayStyle

"""
    AbstractGPUArray{T, N} <: DenseArray{T, N}

Supertype for `N`-dimensional GPU arrays (or array-like types) with elements of type `T`.
Instances of this type are expected to live on the host, see [`AbstractDeviceArray`](@ref)
for device-side objects.
"""
abstract type AbstractGPUArray{T, N} <: DenseArray{T, N} end

# Dimension-specific aliases, named after Base's `AbstractVector`, `AbstractMatrix`
# and `AbstractVecOrMat`.
const AbstractGPUVector{T} = AbstractGPUArray{T, 1}
const AbstractGPUMatrix{T} = AbstractGPUArray{T, 2}
const AbstractGPUVecOrMat{T} = Union{AbstractGPUArray{T, 1}, AbstractGPUArray{T, 2}}
# Convenience aliases for working with wrapped arrays.
# `WrappedArray` comes from Adapt (brought in by `using Adapt` at the top of the module).
# NOTE(review): presumably `WrappedArray{T,N,Src,Dst}` matches Base array wrappers whose
# parent is a `Src`; confirm the exact meaning of the last two parameters against the
# Adapt documentation before changing this alias.
const WrappedGPUArray{T,N} = WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{T,N}}
# Either a GPU array itself or a wrapper around one, so that a single method signature
# can accept both.
const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}}
# One- and two-dimensional specializations of `AnyGPUArray`.
const AnyGPUVector{T} = AnyGPUArray{T, 1}
const AnyGPUMatrix{T} = AnyGPUArray{T, 2}
## broadcasting

"""
    AbstractGPUArrayStyle{N} <: Broadcast.AbstractArrayStyle{N}

Abstract broadcast style shared by GPU arrays, where `N` is the dimensionality.
Each downstream implementation should define its own concrete style type as a subtype
of this one.
"""
abstract type AbstractGPUArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
## scalar iteration
export allowscalar, @allowscalar, assertscalar

# Scalar-indexing behavior of a task, consulted by `assertscalar`:
# - `ScalarAllowed`:    scalar indexing proceeds without any check.
# - `ScalarWarn`:       the next scalar indexing logs a warning, after which the task's
#                       setting becomes `ScalarWarned`.
# - `ScalarWarned`:     a warning was already emitted on this task; nothing more happens.
# - `ScalarDisallowed`: scalar indexing throws an error.
@enum ScalarIndexing ScalarAllowed ScalarWarn ScalarWarned ScalarDisallowed

# Setting requested through `allowscalar(::Bool)`. `assertscalar` uses it for any task
# that has no task-local setting yet; `nothing` means no explicit request was made, so
# such tasks fall back to `default_scalar_indexing()`.
# XXX: use context variables to inherit the parent task's setting, once available.
const requested_scalar_indexing = Ref{Union{Nothing,ScalarIndexing}}(nothing)
# Cache for `repl_frontend_task()`. It starts unassigned (tested with `isassigned`) and
# holds a value produced by `get_repl_frontend_task()`: a `Task`, `missing`, or `nothing`.
const _repl_frontend_task = Ref{Union{Nothing,Missing,Task}}()
# Return the REPL's frontend task as reported by `get_repl_frontend_task()`, caching the
# answer in `_repl_frontend_task`.
#
# Only definitive answers are cached: a `Task`, or `nothing` (this Julia version cannot
# report the frontend task). `missing` only means that no REPL was found *at the time of
# the call*, for example when scalar indexing happens before the REPL has been set up.
# Caching it would stop a REPL started later in the session from ever being recognized,
# so it is recomputed on the next call instead. The lookup is cheap, and callers only
# query this once per task.
function repl_frontend_task()
    isassigned(_repl_frontend_task) && return _repl_frontend_task[]
    frontend = get_repl_frontend_task()
    if !ismissing(frontend)
        _repl_frontend_task[] = frontend
    end
    return frontend
end
# Query the running session for the REPL's frontend task.
#
# Returns:
# - the frontend `Task`, when it can be determined;
# - `missing`, when no REPL is active (`Base.active_repl` is not defined) or the REPL
#   object has no assigned `frontend_task` field;
# - `nothing`, on Julia versions that do not expose the REPL frontend task.
function get_repl_frontend_task()
    @static if VERSION >= v"1.10.0-DEV.444" || v"1.9-beta4" <= VERSION < v"1.10-"
        isdefined(Base, :active_repl) || return missing
        repl = Base.active_repl
        # `isdefined(obj, :field)` is false both for an absent field and for one that is not
        # assigned yet, so this never throws while the REPL is still being set up.
        # NOTE(review): presumably `frontend_task` can be unassigned during REPL start-up
        # (e.g. in `atreplinit` hooks); confirm against the REPL stdlib.
        return isdefined(repl, :frontend_task) ? repl.frontend_task : missing
    else
        return nothing
    end
end
# Choose the scalar-indexing behavior for a task that has neither a task-local setting
# nor a globally requested one.
@noinline function default_scalar_indexing()
    # Non-interactive sessions (scripts, tests, ...) never allow scalar iteration.
    isinteractive() || return ScalarDisallowed

    # Interactive session: try to detect the REPL.
    frontend = repl_frontend_task()
    if !(frontend isa Task)
        # No REPL could be detected in this interactive session: only warn.
        return ScalarWarn
    end

    # Scalar iteration is always allowed on the REPL's own frontend task, where displaying
    # GPU objects often triggers it; every other task is disallowed.
    return frontend === current_task() ? ScalarAllowed : ScalarDisallowed
end
"""
assertscalar(op::String)
Assert that a certain operation `op` performs scalar indexing. If this is not allowed, an
error will be thrown ([`allowscalar`](@ref)).
"""
function assertscalar(op::String)
behavior = get(task_local_storage(), :ScalarIndexing, nothing)
if behavior === nothing
behavior = requested_scalar_indexing[]
if behavior === nothing
behavior = default_scalar_indexing()
end
task_local_storage(:ScalarIndexing, behavior)
end
behavior = behavior::ScalarIndexing
if behavior === ScalarAllowed
# fast path
return
end
_assertscalar(op, behavior)
end
# Slow path of `assertscalar`, taken whenever the task's setting is not `ScalarAllowed`.
# `@noinline` keeps it out of the caller's inlined fast path. The message text is built
# only inside `errorscalar`/`warnscalar`, so a task that is already `ScalarWarned` does
# no string work here.
@noinline function _assertscalar(op, behavior)
    if behavior == ScalarDisallowed
        errorscalar(op)
    elseif behavior == ScalarWarn
        warnscalar(op)
        # Warn only once per task: later scalar indexing on this task stays silent.
        task_local_storage(:ScalarIndexing, ScalarWarned)
    end
    # `ScalarWarned` falls through without any action.
    return
end
# Explanatory text shared by the scalar-indexing warning and error messages.
# (Triple-quoted strings are dedented, so the indentation below is not part of the text.)
scalardesc(op) = """Invocation of $op resulted in scalar indexing of a GPU array.
    This is typically caused by calling an iterating implementation of a method.
    Such implementations *do not* execute on the GPU, but very slowly on the CPU,
    and therefore should be avoided.
    If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
    to enable scalar iteration globally or for the operations in question."""
# Log a warning that `op` performed scalar indexing, naming the current task.
@noinline function warnscalar(op)
    message = string("Performing scalar indexing on task ", current_task(), ".\n",
                     scalardesc(op))
    @warn message
end
# Throw an `ErrorException` explaining that `op` performed disallowed scalar indexing.
@noinline function errorscalar(op)
    message = string("Scalar indexing is disallowed.\n", scalardesc(op))
    throw(ErrorException(message))
end
# Like a try-finally block, except without introducing the try scope: `ex` runs and `fin`
# always runs afterwards, even if `ex` throws. The value of `ex` is the result.
# NOTE: This is deprecated and should not be used from user logic. A proper solution to
# this problem will be introduced in https://github.com/JuliaLang/julia/pull/39217
macro __tryfinally(ex, fin)
    body = esc(ex)
    cleanup = esc(fin)
    return Expr(:tryfinally, body, cleanup)
end
"""
allowscalar([true])
allowscalar([true]) do
...
end
Use this function to allow or disallow scalar indexing, either globall or for the
duration of the do block.
See also: [`@allowscalar`](@ref).
"""
allowscalar
function allowscalar(f::Base.Callable)
task_local_storage(f, :ScalarIndexing, ScalarAllowed)
end
function allowscalar(allow::Bool=true)
if allow
@warn """It's not recommended to use allowscalar([true]) to allow scalar indexing.
Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.""" maxlog=1
end
setting = allow ? ScalarAllowed : ScalarDisallowed
task_local_storage(:ScalarIndexing, setting)
requested_scalar_indexing[] = setting
return
end
"""
@allowscalar() begin
# code that can use scalar indexing
end
Denote which operations can use scalar indexing.
See also: [`allowscalar`](@ref).
"""
macro allowscalar(ex)
quote
local tls_value = get(task_local_storage(), :ScalarIndexing, nothing)
task_local_storage(:ScalarIndexing, ScalarAllowed)
@__tryfinally($(esc(ex)),
isnothing(tls_value) ? delete!(task_local_storage(), :ScalarIndexing)
: task_local_storage(:ScalarIndexing, tls_value))
end
end
## other

"""
    backend(x)
    backend(T::Type)

Gets the GPUArrays back-end responsible for managing arrays of type `T`. For a value `x`
this is the back-end of `typeof(x)`. Types without a more specific method reach a generic
fallback that throws an error.
"""
function backend end

# Values are answered by looking at their type.
backend(x) = backend(typeof(x))

# Fallback for every type that has no more specific `backend` method.
backend(::Type) = error("This object is not a GPU array") # COV_EXCL_LINE
# Base array wrappers, matched through Adapt's `WrappedArray`, have no back-end of their
# own. Strip the wrapper type with Adapt's `unwrap_type` and ask the unwrapped type instead.
# NOTE(review): presumably `unwrap_type` yields the parent array type; confirm in Adapt.
backend(::Type{WA}) where WA<:WrappedArray = backend(unwrap_type(WA))
end # module GPUArraysCore