---> 25 utils.log_model_info(log_file,
26 train_state_initializer.global_train_state_shape,
27 partitioner)
File /kaggle/working/t5x/t5x/utils.py:1391, in log_model_info(log_file, full_train_state, partitioner)
1387 return
1389 state_dict = full_train_state.state_dict()
1390 total_num_params = jax.tree_util.tree_reduce(
-> 1391 np.add, jax.tree.map(np.size, state_dict['target'])
1392 )
1394 logical_axes = partitioner.get_logical_axes(full_train_state).state_dict()
1396 mesh_axes = jax.tree.map(
1397 lambda x: tuple(x) if x is not None else None,
1398 partitioner.get_mesh_axes(full_train_state).state_dict(),
1399 )
File /opt/conda/lib/python3.10/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
51 warnings.warn(message, DeprecationWarning, stacklevel=2)
52 return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'tree'
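For context, line 1391 of t5x/utils.py only counts the model parameters: it maps np.size over every array in the 'target' subtree and sums the results with tree_reduce. The sketch below does the same computation using only jax.tree_util, which is present in both older and newer JAX releases. The state_dict here is a hypothetical toy stand-in for full_train_state.state_dict(), not the real t5x train state.

```python
import jax
import numpy as np

# Hypothetical toy stand-in for full_train_state.state_dict();
# only the 'target' subtree (the model parameters) is used for the count.
state_dict = {
    "target": {
        "encoder": {"kernel": np.zeros((4, 8)), "bias": np.zeros((8,))},
        "decoder": {"kernel": np.zeros((8, 2))},
    }
}

# Same logic as utils.py line 1391, but with jax.tree_util.tree_map
# instead of jax.tree.map, so it also runs on JAX versions without jax.tree.
sizes = jax.tree_util.tree_map(np.size, state_dict["target"])  # one int per array
total_num_params = jax.tree_util.tree_reduce(np.add, sizes)    # sum of all sizes
print(total_num_params)  # 32 + 8 + 16 = 56
```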
same here
I use jax.tree_util.tree_map() instead of jax.tree.map(). It is documented in the jax.tree_util section of the JAX API reference.
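The jax.tree namespace only exists in newer JAX releases, so an older JAX install raises exactly this AttributeError. Rather than hard-coding one spelling, a minimal sketch of a version-tolerant helper might look like the following; the name tree_map is a hypothetical local alias, not part of t5x. In t5x/utils.py, the two jax.tree.map calls (lines 1391 and 1396) could then call this helper instead. Upgrading JAX should also provide jax.tree, though whether that works depends on which JAX version the rest of your t5x setup expects.

```python
import jax

# Hypothetical local alias: prefer jax.tree.map on newer JAX releases,
# fall back to jax.tree_util.tree_map on older ones that lack jax.tree.
# hasattr() returns False here because old JAX raises AttributeError.
if hasattr(jax, "tree"):
    tree_map = jax.tree.map
else:
    tree_map = jax.tree_util.tree_map

# Usage: apply a function to every leaf of a nested structure.
doubled = tree_map(lambda x: x * 2, {"a": 1, "b": [2, 3]})
print(doubled)  # {'a': 2, 'b': [4, 6]}
```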