forked from google-deepmind/alphafold3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_alphafold.py
674 lines (596 loc) · 22.1 KB
/
run_alphafold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""AlphaFold 3 structure prediction script.
AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
To request access to the AlphaFold 3 model parameters, follow the process set
out at https://github.com/google-deepmind/alphafold3. You may only use these
if received directly from Google. Use is subject to terms of use available at
https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""
from collections.abc import Callable, Iterable, Sequence
import csv
import dataclasses
import functools
import multiprocessing
import os
import pathlib
import shutil
import string
import textwrap
import time
import typing
from typing import Protocol, Self, TypeVar, overload
from absl import app
from absl import flags
from alphafold3.common import base_config
from alphafold3.common import folding_input
from alphafold3.common import resources
from alphafold3.constants import chemical_components
import alphafold3.cpp
from alphafold3.data import featurisation
from alphafold3.data import pipeline
from alphafold3.jax.attention import attention
from alphafold3.model import features
from alphafold3.model import params
from alphafold3.model import post_processing
from alphafold3.model.components import base_model
from alphafold3.model.components import utils
from alphafold3.model.diffusion import model as diffusion_model
import haiku as hk
import jax
from jax import numpy as jnp
import numpy as np
# Default model/database locations live under the user's home directory.
# $HOME is honoured whenever it is set. When it is missing, fall back to
# pathlib's own lookup instead of calling pathlib.Path(None), which raises a
# TypeError at import time (before any flag could override the defaults).
_HOME_DIR = (
    pathlib.Path(os.environ['HOME'])
    if 'HOME' in os.environ
    else pathlib.Path.home()
)
DEFAULT_MODEL_DIR = _HOME_DIR / 'models'
DEFAULT_DB_DIR = _HOME_DIR / 'public_databases'
# Flags selecting where inputs are read from, where results are written, and
# which model directory / attention kernel is used.
_JSON_PATH = flags.DEFINE_string(
    name='json_path',
    default=None,
    help='Path to the input JSON file.',
)
_INPUT_DIR = flags.DEFINE_string(
    name='input_dir',
    default=None,
    help='Path to the directory containing input JSON files.',
)
_OUTPUT_DIR = flags.DEFINE_string(
    name='output_dir',
    default=None,
    help='Path to a directory where the results will be saved.',
)
# The default is derived from DEFAULT_MODEL_DIR, rendered with forward slashes.
_MODEL_DIR = flags.DEFINE_string(
    name='model_dir',
    default=DEFAULT_MODEL_DIR.as_posix(),
    help='Path to the model to use for inference.',
)
_FLASH_ATTENTION_IMPLEMENTATION = flags.DEFINE_enum(
    name='flash_attention_implementation',
    default='triton',
    enum_values=['triton', 'cudnn', 'xla'],
    help=(
        "Flash attention implementation to use. 'triton' and 'cudnn' uses a"
        ' Triton and cuDNN flash attention implementation, respectively. The'
        ' Triton kernel is fastest and has been tested more thoroughly. The'
        " Triton and cuDNN kernels require Ampere GPUs or later. 'xla' uses an"
        ' XLA attention implementation (no flash attention) and is portable'
        ' across GPU devices.'
    ),
)
# Stage toggles: both default to True; main() rejects the combination where
# both are disabled.
_RUN_DATA_PIPELINE = flags.DEFINE_bool(
    name='run_data_pipeline',
    default=True,
    help='Whether to run the data pipeline on the fold inputs.',
)
_RUN_INFERENCE = flags.DEFINE_bool(
    name='run_inference',
    default=True,
    help='Whether to run inference on the fold inputs.',
)
# Paths to the external HMMER binaries. Each default is whatever
# shutil.which() finds on PATH when this module is imported (None if absent).
_JACKHMMER_BINARY_PATH = flags.DEFINE_string(
    name='jackhmmer_binary_path',
    default=shutil.which('jackhmmer'),
    help='Path to the Jackhmmer binary.',
)
_NHMMER_BINARY_PATH = flags.DEFINE_string(
    name='nhmmer_binary_path',
    default=shutil.which('nhmmer'),
    help='Path to the Nhmmer binary.',
)
_HMMALIGN_BINARY_PATH = flags.DEFINE_string(
    name='hmmalign_binary_path',
    default=shutil.which('hmmalign'),
    help='Path to the Hmmalign binary.',
)
_HMMSEARCH_BINARY_PATH = flags.DEFINE_string(
    name='hmmsearch_binary_path',
    default=shutil.which('hmmsearch'),
    help='Path to the Hmmsearch binary.',
)
_HMMBUILD_BINARY_PATH = flags.DEFINE_string(
    name='hmmbuild_binary_path',
    default=shutil.which('hmmbuild'),
    help='Path to the Hmmbuild binary.',
)
# Database locations. The per-database defaults contain a literal ${DB_DIR}
# placeholder; main() substitutes the value of --db_dir into it via
# string.Template before building the data pipeline config.
_DB_DIR = flags.DEFINE_string(
    name='db_dir',
    default=DEFAULT_DB_DIR.as_posix(),
    help='Path to the directory containing the databases.',
)
_SMALL_BFD_DATABASE_PATH = flags.DEFINE_string(
    name='small_bfd_database_path',
    default='${DB_DIR}/bfd-first_non_consensus_sequences.fasta',
    help='Small BFD database path, used for protein MSA search.',
)
_MGNIFY_DATABASE_PATH = flags.DEFINE_string(
    name='mgnify_database_path',
    default='${DB_DIR}/mgy_clusters_2022_05.fa',
    help='Mgnify database path, used for protein MSA search.',
)
_UNIPROT_CLUSTER_ANNOT_DATABASE_PATH = flags.DEFINE_string(
    name='uniprot_cluster_annot_database_path',
    default='${DB_DIR}/uniprot_all_2021_04.fa',
    help='UniProt database path, used for protein paired MSA search.',
)
_UNIREF90_DATABASE_PATH = flags.DEFINE_string(
    name='uniref90_database_path',
    default='${DB_DIR}/uniref90_2022_05.fa',
    help=(
        'UniRef90 database path, used for MSA search. The MSA obtained by '
        'searching it is used to construct the profile for template search.'
    ),
)
_NTRNA_DATABASE_PATH = flags.DEFINE_string(
    name='ntrna_database_path',
    default='${DB_DIR}/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta',
    help='NT-RNA database path, used for RNA MSA search.',
)
_RFAM_DATABASE_PATH = flags.DEFINE_string(
    name='rfam_database_path',
    default='${DB_DIR}/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta',
    help='Rfam database path, used for RNA MSA search.',
)
_RNA_CENTRAL_DATABASE_PATH = flags.DEFINE_string(
    name='rna_central_database_path',
    default='${DB_DIR}/rnacentral_active_seq_id_90_cov_80_linclust.fasta',
    help='RNAcentral database path, used for RNA MSA search.',
)
_PDB_DATABASE_PATH = flags.DEFINE_string(
    name='pdb_database_path',
    default='${DB_DIR}/pdb_2022_09_28_mmcif_files.tar',
    help=(
        'PDB database directory with mmCIF files path, used for template'
        ' search.'
    ),
)
_SEQRES_DATABASE_PATH = flags.DEFINE_string(
    name='seqres_database_path',
    default='${DB_DIR}/pdb_seqres_2022_09_28.fasta',
    help='PDB sequence database path, used for template search.',
)
# CPU budgets for the MSA search tools. Both defaults are capped at 8 and are
# computed from the machine's CPU count when this module is imported.
_JACKHMMER_N_CPU = flags.DEFINE_integer(
    name='jackhmmer_n_cpu',
    default=min(multiprocessing.cpu_count(), 8),
    help=(
        'Number of CPUs to use for Jackhmmer. Default to min(cpu_count, 8). Going'
        ' beyond 8 CPUs provides very little additional speedup.'
    ),
)
_NHMMER_N_CPU = flags.DEFINE_integer(
    name='nhmmer_n_cpu',
    default=min(multiprocessing.cpu_count(), 8),
    help=(
        'Number of CPUs to use for Nhmmer. Default to min(cpu_count, 8). Going'
        ' beyond 8 CPUs provides very little additional speedup.'
    ),
)
# Optional on-disk JAX compilation cache; main() only configures JAX with it
# when the flag is set.
_JAX_COMPILATION_CACHE_DIR = flags.DEFINE_string(
    name='jax_compilation_cache_dir',
    default=None,
    help='Path to a directory for the JAX compilation cache.',
)
# Token-count buckets. Values are kept as strings here and converted to ints
# in main() before being handed to process_fold_input.
_BUCKETS = flags.DEFINE_list(
    name='buckets',
    # pyformat: disable
    default=['256', '512', '768', '1024', '1280', '1536', '2048', '2560',
             '3072', '3584', '4096', '4608', '5120'],
    # pyformat: enable
    help=(
        'Strictly increasing order of token sizes for which to cache compilations.'
        ' For any input with more tokens than the largest bucket size, a new bucket'
        ' is created for exactly that number of tokens.'
    ),
)
class ConfigurableModel(Protocol):
  """Structural type for model classes that carry a nested `Config` class.

  ModelRunner relies on exactly these three members: the nested `Config`,
  construction from a config instance, and the `get_inference_result`
  classmethod.
  """

  class Config(base_config.BaseConfig):
    """Configuration type nested inside the model class."""

  def __call__(self, config: Config) -> Self:
    """Called with a `Config` instance; typed as returning the model."""

  @classmethod
  def get_inference_result(
      cls: Self,
      batch: features.BatchDict,
      result: base_model.ModelResult,
      target_name: str = '',
  ) -> Iterable[base_model.InferenceResult]:
    """Yields inference results for a featurised batch and its model result."""
ModelT = TypeVar('ModelT', bound=ConfigurableModel)


def make_model_config(
    *,
    model_class: type[ModelT] = diffusion_model.Diffuser,
    flash_attention_implementation: attention.Implementation = 'triton',
):
  """Creates a default config for `model_class` with the attention kernel set.

  Args:
    model_class: Model class whose nested `Config` is instantiated with no
      arguments.
    flash_attention_implementation: Value stored on
      `config.global_config.flash_attention_implementation` when the config
      has a `global_config` attribute.

  Returns:
    The freshly created config instance.
  """
  model_config = model_class.Config()
  # Configs without a `global_config` section are returned untouched.
  if not hasattr(model_config, 'global_config'):
    return model_config
  global_config = model_config.global_config
  global_config.flash_attention_implementation = flash_attention_implementation
  return model_config
class ModelRunner:
  """Runs the individual structure-prediction stages for one model.

  Holds a model class, its config, the target device and the directory the
  parameters are loaded from. Parameters and the jitted forward pass are
  built lazily on first access and cached on the instance.
  """

  def __init__(
      self,
      model_class: ConfigurableModel,
      config: base_config.BaseConfig,
      device: jax.Device,
      model_dir: pathlib.Path,
  ):
    self._model_dir = model_dir
    self._device = device
    self._model_config = config
    self._model_class = model_class

  @functools.cached_property
  def model_params(self) -> hk.Params:
    """Model parameters read from `model_dir` (loaded once, then cached)."""
    loaded_params = params.get_model_haiku_params(model_dir=self._model_dir)
    return loaded_params

  @functools.cached_property
  def _model(
      self,
  ) -> Callable[[jnp.ndarray, features.BatchDict], base_model.ModelResult]:
    """Jitted forward pass with the model parameters already bound.

    The returned callable takes `(rng_key, batch)`.
    """
    assert isinstance(self._model_config, self._model_class.Config)

    def forward_fn(batch):
      outputs = self._model_class(self._model_config)(batch)
      # Tag the outputs with the identifier stored alongside the parameters.
      outputs['__identifier__'] = self.model_params['__meta__']['__identifier__']
      return outputs

    transformed = hk.transform(forward_fn)
    jitted_apply = jax.jit(transformed.apply, device=self._device)
    return functools.partial(jitted_apply, self.model_params)

  def run_inference(
      self, featurised_example: features.BatchDict, rng_key: jnp.ndarray
  ) -> base_model.ModelResult:
    """Runs one forward pass on a featurised example.

    Args:
      featurised_example: Features for a single example; entries rejected by
        `utils.remove_invalidly_typed_feats` are dropped before the rest is
        placed on the runner's device.
      rng_key: Random key passed to the forward pass.

    Returns:
      The model result with every leaf converted to a NumPy array, bfloat16
      leaves cast to float32, and `'__identifier__'` converted to bytes.
    """
    valid_feats = utils.remove_invalidly_typed_feats(featurised_example)
    device_batch = jax.device_put(
        jax.tree_util.tree_map(jnp.asarray, valid_feats),
        self._device,
    )

    outputs = self._model(rng_key, device_batch)
    outputs = jax.tree.map(np.asarray, outputs)

    def _upcast_bfloat16(leaf):
      if leaf.dtype == jnp.bfloat16:
        return leaf.astype(jnp.float32)
      return leaf

    outputs = jax.tree.map(_upcast_bfloat16, outputs)
    outputs['__identifier__'] = outputs['__identifier__'].tobytes()
    return outputs

  def extract_structures(
      self,
      batch: features.BatchDict,
      result: base_model.ModelResult,
      target_name: str,
  ) -> list[base_model.InferenceResult]:
    """Materialises the model class's inference results into a list."""
    structures = self._model_class.get_inference_result(
        batch=batch, result=result, target_name=target_name
    )
    return list(structures)
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class ResultsForSeed:
  """Immutable bundle of everything produced for one random seed.

  Attributes:
    seed: Seed the samples were generated with.
    inference_results: One inference result per diffusion sample.
    full_fold_input: Fold input the results were computed from; expected to
      already carry the data pipeline outputs (MSA and templates).
  """

  seed: int
  inference_results: Sequence[base_model.InferenceResult]
  full_fold_input: folding_input.Input
def predict_structure(
    fold_input: folding_input.Input,
    model_runner: ModelRunner,
    buckets: Sequence[int] | None = None,
) -> Sequence[ResultsForSeed]:
  """Featurises a fold input and runs inference for each of its seeds.

  Progress and wall-clock timings for every stage are printed to stdout.

  Args:
    fold_input: Fold input to featurise; its `rng_seeds` drive the loop and
      its `user_ccd` is used to build the chemical components dictionary.
    model_runner: Runner used for the forward pass and structure extraction.
    buckets: Passed through unchanged to `featurisation.featurise_input`.

  Returns:
    One `ResultsForSeed` per (seed, featurised example) pair, in seed order.
  """
  print(f'Featurising data for seeds {fold_input.rng_seeds}...')
  featurisation_start_time = time.time()
  ccd = chemical_components.cached_ccd(user_ccd=fold_input.user_ccd)
  featurised_examples = featurisation.featurise_input(
      fold_input=fold_input, buckets=buckets, ccd=ccd, verbose=True
  )
  # The timing messages below use a single separator between the adjacent
  # string pieces; previously both pieces carried a space ("took  1.23").
  print(
      f'Featurising data for seeds {fold_input.rng_seeds} took'
      f' {time.time() - featurisation_start_time:.2f} seconds.'
  )

  all_inference_start_time = time.time()
  all_inference_results = []
  for seed, example in zip(fold_input.rng_seeds, featurised_examples):
    print(f'Running model inference for seed {seed}...')
    inference_start_time = time.time()
    rng_key = jax.random.PRNGKey(seed)
    result = model_runner.run_inference(example, rng_key)
    print(
        f'Running model inference for seed {seed} took'
        f' {time.time() - inference_start_time:.2f} seconds.'
    )

    print(f'Extracting output structures (one per sample) for seed {seed}...')
    extraction_start_time = time.time()
    inference_results = model_runner.extract_structures(
        batch=example, result=result, target_name=fold_input.name
    )
    print(
        f'Extracting output structures (one per sample) for seed {seed} took'
        f' {time.time() - extraction_start_time:.2f} seconds.'
    )

    all_inference_results.append(
        ResultsForSeed(
            seed=seed,
            inference_results=inference_results,
            full_fold_input=fold_input,
        )
    )
    # Covers both inference and extraction for this seed.
    print(
        'Running model inference and extracting output structures for seed'
        f' {seed} took {time.time() - inference_start_time:.2f} seconds.'
    )

  print(
      'Running model inference and extracting output structures for seeds'
      f' {fold_input.rng_seeds} took'
      f' {time.time() - all_inference_start_time:.2f} seconds.'
  )
  return all_inference_results
def write_fold_input_json(
    fold_input: folding_input.Input,
    output_dir: os.PathLike[str] | str,
) -> None:
  """Serialises the fold input to `<sanitised name>_data.json` in `output_dir`.

  The output directory is created first if it does not exist yet.
  """
  os.makedirs(output_dir, exist_ok=True)
  json_filename = f'{fold_input.sanitised_name()}_data.json'
  json_path = pathlib.Path(output_dir) / json_filename
  with json_path.open('wt') as json_file:
    json_file.write(fold_input.to_json())
def write_outputs(
    all_inference_results: Sequence[ResultsForSeed],
    output_dir: os.PathLike[str] | str,
    job_name: str,
) -> None:
  """Writes per-sample outputs, the top-ranked result and a ranking CSV.

  Every sample is written to `<output_dir>/seed-<seed>_sample-<idx>`. The
  sample with the highest `ranking_score` (first one wins on ties) is written
  again directly into `output_dir`, together with the output terms of use.
  Finally `ranking_scores.csv` lists (seed, sample, ranking_score) rows.

  Args:
    all_inference_results: Results for every seed.
    output_dir: Directory to write into; created if missing.
    job_name: Name passed to the writer for the top-ranked result.
  """
  ranking_scores = []
  max_ranking_score = None
  max_ranking_result = None

  # The terms of use file ships next to the compiled alphafold3.cpp module.
  output_terms = (
      pathlib.Path(alphafold3.cpp.__file__).parent / 'OUTPUT_TERMS_OF_USE.md'
  ).read_text()

  os.makedirs(output_dir, exist_ok=True)
  for results_for_seed in all_inference_results:
    seed = results_for_seed.seed
    for sample_idx, result in enumerate(results_for_seed.inference_results):
      sample_dir = os.path.join(output_dir, f'seed-{seed}_sample-{sample_idx}')
      os.makedirs(sample_dir, exist_ok=True)
      post_processing.write_output(
          inference_result=result, output_dir=sample_dir
      )
      ranking_score = float(result.metadata['ranking_score'])
      ranking_scores.append((seed, sample_idx, ranking_score))
      if max_ranking_score is None or ranking_score > max_ranking_score:
        max_ranking_score = ranking_score
        max_ranking_result = result

  if max_ranking_result is not None:  # True iff ranking_scores non-empty.
    post_processing.write_output(
        inference_result=max_ranking_result,
        output_dir=output_dir,
        # The output terms of use are the same for all seeds/samples.
        terms_of_use=output_terms,
        name=job_name,
    )
    # Save csv of ranking scores with seeds and sample indices, to allow easier
    # comparison of ranking scores across different runs. newline='' lets the
    # csv writer control line endings itself, as the csv module requires;
    # without it, platforms that translate '\n' emit '\r\r\n' row terminators.
    with open(
        os.path.join(output_dir, 'ranking_scores.csv'), 'wt', newline=''
    ) as f:
      writer = csv.writer(f)
      writer.writerow(['seed', 'sample', 'ranking_score'])
      writer.writerows(ranking_scores)
@overload
def process_fold_input(
    fold_input: folding_input.Input,
    data_pipeline_config: pipeline.DataPipelineConfig | None,
    model_runner: None,
    output_dir: os.PathLike[str] | str,
    buckets: Sequence[int] | None = None,
) -> folding_input.Input:
  ...


@overload
def process_fold_input(
    fold_input: folding_input.Input,
    data_pipeline_config: pipeline.DataPipelineConfig | None,
    model_runner: ModelRunner,
    output_dir: os.PathLike[str] | str,
    buckets: Sequence[int] | None = None,
) -> Sequence[ResultsForSeed]:
  ...


def process_fold_input(
    fold_input: folding_input.Input,
    data_pipeline_config: pipeline.DataPipelineConfig | None,
    model_runner: ModelRunner | None,
    output_dir: os.PathLike[str] | str,
    buckets: Sequence[int] | None = None,
) -> folding_input.Input | Sequence[ResultsForSeed]:
  """Processes one fold input: optional data pipeline, then optional inference.

  The (possibly pipeline-augmented) fold input is always written to
  `output_dir` as JSON. When a model runner is given, structures are predicted
  and written there as well.

  Args:
    fold_input: Fold input to process.
    data_pipeline_config: Config for the data pipeline; None skips the
      pipeline.
    model_runner: Runner used for inference; None skips inference.
    output_dir: Directory receiving the input JSON and any predictions.
    buckets: Optional bucket sizes, forwarded unchanged to predict_structure
      when inference runs.

  Returns:
    The processed fold input when inference is skipped, otherwise the
    inference results for each seed.

  Raises:
    ValueError: If the fold input has no chains.
  """
  print(f'Processing fold input {fold_input.name}')

  if not fold_input.chains:
    raise ValueError('Fold input has no chains.')

  inference_requested = model_runner is not None
  if inference_requested:
    # Touch the parameters up front so a broken model directory fails before
    # the (possibly long) data pipeline is started.
    print('Checking we can load the model parameters...')
    _ = model_runner.model_params

  if data_pipeline_config is not None:
    print('Running data pipeline...')
    data_pipeline = pipeline.DataPipeline(data_pipeline_config)
    fold_input = data_pipeline.process(fold_input)
  else:
    print('Skipping data pipeline...')

  print(f'Output directory: {output_dir}')
  print(f'Writing model input JSON to {output_dir}')
  write_fold_input_json(fold_input, output_dir)

  if not inference_requested:
    print('Skipping inference...')
    print(f'Done processing fold input {fold_input.name}.')
    return fold_input

  print(
      f'Predicting 3D structure for {fold_input.name} for seed(s)'
      f' {fold_input.rng_seeds}...'
  )
  results_per_seed = predict_structure(
      fold_input=fold_input,
      model_runner=model_runner,
      buckets=buckets,
  )
  print(
      f'Writing outputs for {fold_input.name} for seed(s)'
      f' {fold_input.rng_seeds}...'
  )
  write_outputs(
      all_inference_results=results_per_seed,
      output_dir=output_dir,
      job_name=fold_input.sanitised_name(),
  )

  print(f'Done processing fold input {fold_input.name}.')
  return results_per_seed
def main(_):
  """Validates flags, builds the data pipeline / model, and runs every input.

  Args:
    _: Unused positional arguments supplied by `app.run`.

  Raises:
    ValueError: If both or neither of --json_path and --input_dir are given,
      or if both --run_inference and --run_data_pipeline are false.
    OSError: If the output directory cannot be created.
  """
  if _JAX_COMPILATION_CACHE_DIR.value is not None:
    jax.config.update(
        'jax_compilation_cache_dir', _JAX_COMPILATION_CACHE_DIR.value
    )

  # Both identity checks must be parenthesised. Written bare, Python chains
  # `a is None == b is None` into `a is None and None == b and b is None`,
  # which is only true when neither flag is set - so passing both flags went
  # undetected and --json_path was silently ignored.
  if (_JSON_PATH.value is None) == (_INPUT_DIR.value is None):
    raise ValueError(
        'Exactly one of --json_path or --input_dir must be specified.'
    )

  if not _RUN_INFERENCE.value and not _RUN_DATA_PIPELINE.value:
    raise ValueError(
        'At least one of --run_inference or --run_data_pipeline must be'
        ' set to true.'
    )

  # The check above guarantees exactly one of the two flags is set.
  if _INPUT_DIR.value is not None:
    fold_inputs = folding_input.load_fold_inputs_from_dir(
        pathlib.Path(_INPUT_DIR.value)
    )
  else:
    fold_inputs = folding_input.load_fold_inputs_from_path(
        pathlib.Path(_JSON_PATH.value)
    )

  # Make sure we can create the output directory before running anything.
  try:
    os.makedirs(_OUTPUT_DIR.value, exist_ok=True)
  except OSError as e:
    print(f'Failed to create output directory {_OUTPUT_DIR.value}: {e}')
    raise

  notice = textwrap.wrap(
      'Running AlphaFold 3. Please note that standard AlphaFold 3 model'
      ' parameters are only available under terms of use provided at'
      ' https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.'
      ' If you do not agree to these terms and are using AlphaFold 3 derived'
      ' model parameters, cancel execution of AlphaFold 3 inference with'
      ' CTRL-C, and do not use the model parameters.',
      break_long_words=False,
      break_on_hyphens=False,
      width=80,
  )
  print('\n'.join(notice))

  if _RUN_DATA_PIPELINE.value:

    def replace_db_dir(path_template: str) -> str:
      """Expands the ${DB_DIR} placeholder with the value of --db_dir."""
      return string.Template(path_template).substitute(DB_DIR=_DB_DIR.value)

    data_pipeline_config = pipeline.DataPipelineConfig(
        jackhmmer_binary_path=_JACKHMMER_BINARY_PATH.value,
        nhmmer_binary_path=_NHMMER_BINARY_PATH.value,
        hmmalign_binary_path=_HMMALIGN_BINARY_PATH.value,
        hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH.value,
        hmmbuild_binary_path=_HMMBUILD_BINARY_PATH.value,
        small_bfd_database_path=replace_db_dir(_SMALL_BFD_DATABASE_PATH.value),
        mgnify_database_path=replace_db_dir(_MGNIFY_DATABASE_PATH.value),
        uniprot_cluster_annot_database_path=replace_db_dir(
            _UNIPROT_CLUSTER_ANNOT_DATABASE_PATH.value
        ),
        uniref90_database_path=replace_db_dir(_UNIREF90_DATABASE_PATH.value),
        ntrna_database_path=replace_db_dir(_NTRNA_DATABASE_PATH.value),
        rfam_database_path=replace_db_dir(_RFAM_DATABASE_PATH.value),
        rna_central_database_path=replace_db_dir(
            _RNA_CENTRAL_DATABASE_PATH.value
        ),
        pdb_database_path=replace_db_dir(_PDB_DATABASE_PATH.value),
        seqres_database_path=replace_db_dir(_SEQRES_DATABASE_PATH.value),
        jackhmmer_n_cpu=_JACKHMMER_N_CPU.value,
        nhmmer_n_cpu=_NHMMER_N_CPU.value,
    )
  else:
    print('Skipping running the data pipeline.')
    data_pipeline_config = None

  if _RUN_INFERENCE.value:
    devices = jax.local_devices(backend='gpu')
    print(f'Found local devices: {devices}')

    print('Building model from scratch...')
    model_runner = ModelRunner(
        model_class=diffusion_model.Diffuser,
        config=make_model_config(
            flash_attention_implementation=typing.cast(
                attention.Implementation, _FLASH_ATTENTION_IMPLEMENTATION.value
            )
        ),
        # Inference runs on the first local GPU only.
        device=devices[0],
        model_dir=pathlib.Path(_MODEL_DIR.value),
    )
  else:
    print('Skipping running model inference.')
    model_runner = None

  print(f'Processing {len(fold_inputs)} fold inputs.')
  # The bucket list is the same for every fold input, so convert it once.
  buckets = tuple(int(bucket) for bucket in _BUCKETS.value)
  for fold_input in fold_inputs:
    process_fold_input(
        fold_input=fold_input,
        data_pipeline_config=data_pipeline_config,
        model_runner=model_runner,
        output_dir=os.path.join(_OUTPUT_DIR.value, fold_input.sanitised_name()),
        buckets=buckets,
    )

  print(f'Done processing {len(fold_inputs)} fold inputs.')
if __name__ == '__main__':
  # --output_dir has no usable default, so absl must reject runs without it.
  required_flags = ['output_dir']
  flags.mark_flags_as_required(required_flags)
  app.run(main)