Skip to content

Commit

Permalink
Small modifications to crypten.load (facebookresearch#203)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: fairinternal/CrypTen#203

- Modified crypten.load to allow the input `f` to be None.
- Added assert to ensure exactly  one of `f` or `preloaded` is None
- Updated unit tests to ensure preloaded loads correctly.

Reviewed By: knottb

Differential Revision: D21212000

fbshipit-source-id: 82042e208a3a95328ec5472dba9ce964d8033075
  • Loading branch information
Shobha Venkataraman authored and facebook-github-bot committed Apr 24, 2020
1 parent 1aa8ced commit 1b295d9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
11 changes: 9 additions & 2 deletions crypten/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _setup_przs():


def load_from_party(
f,
f=None,
preloaded=None,
encrypted=False,
dummy_model=None,
Expand Down Expand Up @@ -245,9 +245,16 @@ def load_from_party(
src >= 0 and src < comm.get().get_world_size()
), "Load failed: src must be in [0, world_size)"

assert (f is None and (preloaded is not None)) or (
(f is not None) and preloaded is None
), "Exactly one of f and preloaded must not be None"

# source party
if comm.get().get_rank() == src:
result = preloaded if preloaded else load_closure(f, **kwargs)
if f is None:
result = preloaded
if preloaded is None:
result = load_closure(f, **kwargs)

# Zero out the tensors / modules to hide loaded data from broadcast
if torch.is_tensor(result):
Expand Down
10 changes: 10 additions & 0 deletions test/test_crypten.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,16 @@ def custom_save_function(obj, f):
complete_file, src=src, load_closure=(lambda f: None)
)

# test pre-loaded
encrypted_preloaded = crypten.load_from_party(
src=src, preloaded=tensor
)
self._check(
encrypted_preloaded,
reference,
"crypten.load() failed using preloaded",
)

def test_save_load_module(self):
"""Test that crypten.save and crypten.load properly save and load modules"""
import tempfile
Expand Down

0 comments on commit 1b295d9

Please sign in to comment.