Skip to content

Commit

Permalink
Fix communication timer for benchmarking communcation
Browse files Browse the repository at this point in the history
Summary: See #255.

Differential Revision: D28332892

fbshipit-source-id: c15177c347eccf9f9d650b13cf525d6efe285d6f
  • Loading branch information
knottb authored and facebook-github-bot committed May 10, 2021
1 parent c4d24c5 commit 4b9abeb
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions crypten/communicator/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,6 @@ def get_name(self):
"""Returns the party name of the current process."""
raise NotImplementedError("get_name is not implemented")

def reset_communication_stats(self):
"""Resets communication statistics."""
raise NotImplementedError("reset_communication_stats is not implemented")

def print_communication_stats(self):
"""Prints communication statistics."""
raise NotImplementedError("print_communication_stats is not implemented")

def _log_communication(self, nelement):
"""Updates log of communication statistics."""
raise NotImplementedError("_log_communication is not implemented")

def reset_communication_stats(self):
"""Resets communication statistics."""
self.comm_rounds = 0
Expand All @@ -135,10 +123,12 @@ def reset_communication_stats(self):

def print_communication_stats(self):
"""Prints communication statistics."""
logging.info("====Communication Stats====")
logging.info("Rounds: {}".format(self.comm_rounds))
logging.info("Bytes : {}".format(self.comm_bytes))
logging.info("Comm time: {}".format(self.comm_time))
import crypten

crypten.log("====Communication Stats====")
crypten.log("Rounds: {}".format(self.comm_rounds))
crypten.log("Bytes : {}".format(self.comm_bytes))
crypten.log("Comm time: {}".format(self.comm_time))

def _log_communication(self, nelement):
"""Updates log of communication statistics."""
Expand Down Expand Up @@ -201,9 +191,9 @@ def logging_wrapper(self, *args, **kwargs):
else: # one tensor communicated
self._log_communication(args[0].nelement())

tic = timeit.timeit()
tic = timeit.default_timer()
result = func(self, *args, **kwargs)
toc = timeit.timeit()
toc = timeit.default_timer()

self._log_communication_time(toc - tic)
return result
Expand Down

0 comments on commit 4b9abeb

Please sign in to comment.