## composites.jl
const SupervisedNetwork = Union{DeterministicNetwork,ProbabilisticNetwork}

# to suppress inclusion in models():
MLJBase.is_wrapper(::Type{DeterministicNetwork}) = true
MLJBase.is_wrapper(::Type{ProbabilisticNetwork}) = true

# fall-back for updating learning networks exported as models:
function MLJBase.update(model::Union{SupervisedNetwork,UnsupervisedNetwork},
                        verbosity, fitresult, cache, args...)
    fit!(fitresult; verbosity=verbosity)
    return fitresult, cache, nothing
end

# fall-back for predicting on learning networks exported as models:
MLJBase.predict(composite::SupervisedNetwork, fitresult, Xnew) =
    fitresult(Xnew)
"""
MLJ.tree(N::Node)
Return a description of the tree defined by the learning network
terminating at node `N`.
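
A minimal usage sketch, assuming hypothetical training data `X` and `y`:

```julia
Xs = source(X)
ys = source(y)
model_ = DeterministicConstantRegressor()
mach = machine(model_, Xs, ys)
yhat = predict(mach, Xs)
tree(yhat)
# returns something like:
# (operation = predict, model = DeterministicConstantRegressor(),
#  arg1 = (source = Xs,), train_arg1 = (source = Xs,), train_arg2 = (source = ys,))
```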
"""
tree(s::MLJ.Source) = (source = s,)
function tree(W::MLJ.Node)
    mach = W.machine
    if mach === nothing
        value2 = nothing
        endkeys = []
        endvalues = []
    else
        value2 = mach.model
        endkeys = [Symbol("train_arg", i) for i in eachindex(mach.args)]
        endvalues = [tree(arg) for arg in mach.args]
    end
    keys = tuple(:operation, :model,
                 [Symbol("arg", i) for i in eachindex(W.args)]...,
                 endkeys...)
    values = tuple(W.operation, value2,
                   [tree(arg) for arg in W.args]...,
                   endvalues...)
    return NamedTuple{keys}(values)
end
# get the top-level args of the tree of some node:
function args(tree)
    keys_ = filter(keys(tree) |> collect) do key
        match(r"^arg[0-9]*", string(key)) !== nothing
    end
    return [getproperty(tree, key) for key in keys_]
end

# get the top-level train_args of the tree of some node:
function train_args(tree)
    keys_ = filter(keys(tree) |> collect) do key
        match(r"^train_arg[0-9]*", string(key)) !== nothing
    end
    return [getproperty(tree, key) for key in keys_]
end
"""
models(N::AbstractNode)
A vector of all models referenced by node `N`, each model
appearing exactly once.
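
Continuing the hypothetical network sketched in the `tree` docstring
above:

```julia
models(yhat)  # one-element vector containing `model_`
```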
"""
function models(W::MLJ.AbstractNode)
    models_ = filter(flat_values(tree(W)) |> collect) do model
        model isa MLJ.Model
    end
    return unique(models_)
end
"""
sources(N::AbstractNode)
A vector of all sources referenced by calls `N()` and `fit!(N)`. These
are the sources of the directed acyclic graph associated with the
learning network terminating at `N`.
Not to be confused with `origins(N)` which refers to the same graph with edges corresponding to training arguments deleted.
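
For the hypothetical network sketched in the `tree` docstring above:

```julia
sources(yhat)  # two-element vector containing the sources `Xs` and `ys`
```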
See also: `origins`, `source`
"""
function sources(W::MLJ.AbstractNode)
    sources_ = filter(MLJ.flat_values(tree(W)) |> collect) do value
        value isa MLJ.Source
    end
    return unique(sources_)
end
"""
machines(N)
List all machines in the learning network terminating at node `N`.
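
For the hypothetical network sketched in the `tree` docstring above:

```julia
machines(yhat)  # one-element list containing the machine wrapping `model_`
```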
"""
machines(W::MLJ.Source) = Any[]
function machines(W::MLJ.Node)
    if W.machine === nothing
        return vcat([machines(arg) for arg in W.args]...) |> unique
    else
        return vcat(Any[W.machine, ],
                    [machines(arg) for arg in W.args]...,
                    [machines(arg) for arg in W.machine.args]...) |> unique
    end
end
"""
replace(W::MLJ.Node, a1=>b1, a2=>b2, ....)
Create a deep copy of a node `W`, and thereby replicate the learning
network terminating at `W`, but replacing any specified sources and
models `a1, a2, ...` of the original network with the specified targets
`b1, b2, ...`.
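
A sketch, building on the hypothetical network in the `tree` docstring
above (`Xnew` is hypothetical fresh data):

```julia
Xs_new = source(Xnew)
model_new = DeterministicConstantRegressor()
yhat_new = replace(yhat, Xs => Xs_new, model_ => model_new)
fit!(yhat_new)  # trains the copy; the original network is untouched
```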
"""
function Base.replace(W::Node, pairs::Pair...)

    # Note: We construct nodes of the new network as values of a
    # dictionary keyed on the nodes of the old network. Additionally,
    # there are dictionaries of models keyed on old models and
    # machines keyed on old machines. The node and machine
    # dictionaries must be built simultaneously.

    # build model dict:
    model_pairs = filter(collect(pairs)) do pair
        first(pair) isa Model
    end
    models_ = models(W)
    models_to_copy = setdiff(models_, first.(model_pairs))
    model_copy_pairs = [model => deepcopy(model) for model in models_to_copy]
    newmodel_given_old = IdDict(vcat(model_pairs, model_copy_pairs))

    # build complete source replacement pairs:
    source_pairs = filter(collect(pairs)) do pair
        first(pair) isa Source
    end
    sources_ = sources(W)
    sources_to_copy = setdiff(sources_, first.(source_pairs))
    source_copy_pairs = [source => deepcopy(source) for source in sources_to_copy]
    all_source_pairs = vcat(source_pairs, source_copy_pairs)

    # drop source nodes from all nodes of network terminating at W:
    nodes_ = filter(nodes(W)) do N
        !(N isa Source)
    end

    # instantiate node and machine dictionaries:
    newnode_given_old =
        IdDict{AbstractNode,AbstractNode}(all_source_pairs)
    newmach_given_old = IdDict{NodalMachine,NodalMachine}()

    # build the new network:
    for N in nodes_
        args = [newnode_given_old[arg] for arg in N.args]
        if N.machine === nothing
            newnode_given_old[N] = node(N.operation, args...)
        else
            if haskey(newmach_given_old, N.machine)
                mach = newmach_given_old[N.machine]
            else
                train_args = [newnode_given_old[arg] for arg in N.machine.args]
                mach = machine(newmodel_given_old[N.machine.model], train_args...)
                newmach_given_old[N.machine] = mach
            end
            newnode_given_old[N] = N.operation(mach, args...)
        end
    end

    return newnode_given_old[nodes_[end]]
end
"""
reset!(N::Node)
Place the learning network terminating at node `N` into a state in
which `fit!(N)` will retrain from scratch all machines in its
dependency tape. Does not actually train any machine or alter
fit-results. (The method simply resets `m.state` to zero, for every
machine `m` in the network.)
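
A sketch, for the hypothetical network in the `tree` docstring above:

```julia
fit!(yhat)    # trains all machines in the network
fit!(yhat)    # does nothing, as no machine's state has changed
reset!(yhat)
fit!(yhat)    # every machine is retrained from scratch
```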
"""
function reset!(W::Node)
    for mach in machines(W)
        mach.state = 0 # to do: replace with dagger object
    end
end
# closures for the fit methods of models exported from learning
# networks (used by `@from_network` below):
function supervised_fit_method(network_Xs, network_ys, network_N,
                               network_models...)

    function fit(model::M, verbosity, X, y) where M<:Supervised
        Xs = source(X)
        ys = source(y)
        replacement_models = [getproperty(model, fld)
                              for fld in fieldnames(M)]
        model_replacements = [network_models[j] => replacement_models[j]
                              for j in eachindex(network_models)]
        source_replacements = [network_Xs => Xs, network_ys => ys]
        replacements = vcat(model_replacements, source_replacements)
        yhat = replace(network_N, replacements...)
        fit!(yhat, verbosity=verbosity)
        cache = nothing
        report = nothing
        return yhat, cache, report
    end

    return fit
end
function unsupervised_fit_method(network_Xs, network_N,
                                 network_models...)

    function fit(model::M, verbosity, X) where M<:Unsupervised
        Xs = source(X)
        replacement_models = [getproperty(model, fld)
                              for fld in fieldnames(M)]
        model_replacements = [network_models[j] => replacement_models[j]
                              for j in eachindex(network_models)]
        source_replacements = [network_Xs => Xs,]
        replacements = vcat(model_replacements, source_replacements)
        Xout = replace(network_N, replacements...)
        fit!(Xout, verbosity=verbosity)
        cache = nothing
        report = nothing
        return Xout, cache, report
    end

    return fit
end
"""
@from_network NewCompositeModel(fld1=model1, fld2=model2, ...) <= (Xs, N)
@from_network NewCompositeModel(fld1=model1, fld2=model2, ...) <= (Xs, ys, N)
Create, respectively, a new stand-alone unsupervised or superivsed
model type `NewCompositeModel` using a learning network as a
blueprint. Here `Xs`, `ys` and `N` refer to the input source, node,
target source node and terminating source node of the network. The
model type `NewCompositeModel` is equipped with fields named `:fld1`,
`:fld2`, ..., which correspond to component models `model1`, `model2`
appearing in the network (which must therefore be elements of
`models(N)`). Deep copies of the specified component models are used
as default values in an automatically generated keyword constructor
for `NewCompositeModel`.
Return value: A new `NewCompositeModel` instance, with default
field values.
For details and examples refer to the "Learning Networks" section of
the documentation.
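
A sketch, assuming the hypothetical supervised network from the `tree`
docstring above, with sources `Xs`, `ys`, terminating node `yhat` and
component model `model_`:

```julia
composite = @from_network WrappedRegressor(regressor=model_) <= (Xs, ys, yhat)
# `WrappedRegressor` is now a stand-alone model type with field `:regressor`:
mach = machine(composite, X, y)
fit!(mach)
predict(mach, X)
```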
"""
macro from_network(ex)
    modeltype_ex = ex.args[2].args[1]
    kw_exs = ex.args[2].args[2:end]
    fieldname_exs = [k.args[1] for k in kw_exs]
    model_exs = [k.args[2] for k in kw_exs]
    Xs_ex = ex.args[3].args[1]   # input source node
    N_ex = ex.args[3].args[end]  # terminating node

    # TODO: add more type and syntax checks here:
    N = __module__.eval(N_ex)
    N isa Node ||
        error("$(typeof(N)) given where Node was expected. ")
    models_ = [__module__.eval(e) for e in model_exs]
    issubset(models_, models(N)) ||
        error("One or more specified models not in the learning network "*
              "terminating at $N_ex.\n Use models($N_ex) to inspect models. ")
    nodes_ = nodes(N)
    Xs = __module__.eval(Xs_ex)
    Xs in nodes_ ||
        error("Specified input source $Xs_ex is not a source of $N_ex.")

    if length(ex.args[3].args) == 3
        ys_ex = ex.args[3].args[2] # target source node
        ys = __module__.eval(ys_ex)
        ys in nodes_ ||
            error("Specified target source $ys_ex is not a source of $N_ex.")
        from_network_(__module__, modeltype_ex, fieldname_exs, model_exs,
                      Xs_ex, ys_ex, N_ex)
    else
        from_network_(__module__, modeltype_ex, fieldname_exs, model_exs,
                      Xs_ex, N_ex)
    end

    esc(quote
        $modeltype_ex()
    end)
end
# supervised case:
function from_network_(mod, modeltype_ex, fieldname_exs, model_exs,
                       Xs_ex, ys_ex, N_ex)
    N = mod.eval(N_ex)

    # use the first model in the network to decide whether the
    # exported model type is probabilistic or deterministic:
    if MLJBase.is_probabilistic(typeof(models(N)[1]))
        subtype_ex = :ProbabilisticNetwork
    else
        subtype_ex = :DeterministicNetwork
    end

    # code defining the composite model struct and fit method:
    program1 = quote
        import MLJBase
        mutable struct $modeltype_ex <: MLJ.$subtype_ex
            $(fieldname_exs...)
        end
        MLJBase.fit(model::$modeltype_ex, verbosity::Integer, X, y) =
            MLJ.supervised_fit_method($Xs_ex, $ys_ex, $N_ex,
                                      $(model_exs...))(model, verbosity, X, y)
    end

    # code defining the keyword constructor with default field values:
    program2 = quote
        defaults =
            MLJBase.@set_defaults $modeltype_ex deepcopy.([$(model_exs...)])
        # MLJBase.target_scitype_union($modeltype) =
    end

    mod.eval(program1)
    mod.eval(program2)
end
# unsupervised case:
function from_network_(mod, modeltype_ex, fieldname_exs, model_exs,
                       Xs_ex, N_ex)
    subtype_ex = :UnsupervisedNetwork

    # code defining the composite model struct and fit method:
    program1 = quote
        import MLJBase
        mutable struct $modeltype_ex <: MLJ.$subtype_ex
            $(fieldname_exs...)
        end
        MLJBase.fit(model::$modeltype_ex, verbosity::Integer, X) =
            MLJ.unsupervised_fit_method($Xs_ex, $N_ex,
                                        $(model_exs...))(model, verbosity, X)
    end

    # code defining the keyword constructor with default field values:
    program2 = quote
        defaults =
            MLJBase.@set_defaults $modeltype_ex deepcopy.([$(model_exs...)])
    end

    mod.eval(program1)
    mod.eval(program2)
end
## A COMPOSITE FOR TESTING PURPOSES
"""
SimpleDeterministicCompositeModel(;regressor=ConstantRegressor(),
transformer=FeatureSelector())
Construct a composite model consisting of a transformer
(`Unsupervised` model) followed by a `Deterministic` model. Mainly
intended for internal testing .
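
A sketch, assuming hypothetical training data `X` (a table) and `y`
(a vector):

```julia
composite = SimpleDeterministicCompositeModel()  # default component models
mach = machine(composite, X, y)
fit!(mach)
predict(mach, X)
```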
"""
mutable struct SimpleDeterministicCompositeModel{L<:Deterministic,
                                                 T<:Unsupervised} <: DeterministicNetwork
    model::L
    transformer::T
end

function SimpleDeterministicCompositeModel(; model=DeterministicConstantRegressor(),
                                             transformer=FeatureSelector())
    composite = SimpleDeterministicCompositeModel(model, transformer)
    message = MLJ.clean!(composite)
    isempty(message) || @warn message
    return composite
end
MLJBase.is_wrapper(::Type{<:SimpleDeterministicCompositeModel}) = true
function MLJBase.fit(composite::SimpleDeterministicCompositeModel,
                     verbosity::Int, Xtrain, ytrain)
    X = source(Xtrain) # instantiates a source node
    y = source(ytrain)
    t = machine(composite.transformer, X)
    Xt = transform(t, X)
    l = machine(composite.model, Xt, y)
    yhat = predict(l, Xt)
    fit!(yhat, verbosity=verbosity)
    fitresult = yhat
    report = l.report
    cache = l
    return fitresult, cache, report
end
# MLJBase.predict(composite::SimpleDeterministicCompositeModel, fitresult, Xnew) = fitresult(Xnew)
MLJBase.load_path(::Type{<:SimpleDeterministicCompositeModel}) = "MLJ.SimpleDeterministicCompositeModel"
MLJBase.package_name(::Type{<:SimpleDeterministicCompositeModel}) = "MLJ"
MLJBase.package_uuid(::Type{<:SimpleDeterministicCompositeModel}) = ""
MLJBase.package_url(::Type{<:SimpleDeterministicCompositeModel}) = "https://github.com/alan-turing-institute/MLJ.jl"
MLJBase.is_pure_julia(::Type{<:SimpleDeterministicCompositeModel}) = true
# MLJBase.input_scitype_union(::Type{<:SimpleDeterministicCompositeModel}) =
# MLJBase.target_scitype_union(::Type{<:SimpleDeterministicCompositeModel}) =