From 4b9abebaa815349c0b1ffaa1e5e9dc15421334d9 Mon Sep 17 00:00:00 2001 From: Brian Knott Date: Mon, 10 May 2021 13:26:46 -0700 Subject: [PATCH] Fix communication timer for benchmarking communcation Summary: See https://github.com/facebookresearch/CrypTen/issues/255. Differential Revision: D28332892 fbshipit-source-id: c15177c347eccf9f9d650b13cf525d6efe285d6f --- crypten/communicator/communicator.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/crypten/communicator/communicator.py b/crypten/communicator/communicator.py index 38e95e0b..0788b098 100644 --- a/crypten/communicator/communicator.py +++ b/crypten/communicator/communicator.py @@ -115,18 +115,6 @@ def get_name(self): """Returns the party name of the current process.""" raise NotImplementedError("get_name is not implemented") - def reset_communication_stats(self): - """Resets communication statistics.""" - raise NotImplementedError("reset_communication_stats is not implemented") - - def print_communication_stats(self): - """Prints communication statistics.""" - raise NotImplementedError("print_communication_stats is not implemented") - - def _log_communication(self, nelement): - """Updates log of communication statistics.""" - raise NotImplementedError("_log_communication is not implemented") - def reset_communication_stats(self): """Resets communication statistics.""" self.comm_rounds = 0 @@ -135,10 +123,12 @@ def reset_communication_stats(self): def print_communication_stats(self): """Prints communication statistics.""" - logging.info("====Communication Stats====") - logging.info("Rounds: {}".format(self.comm_rounds)) - logging.info("Bytes : {}".format(self.comm_bytes)) - logging.info("Comm time: {}".format(self.comm_time)) + import crypten + + crypten.log("====Communication Stats====") + crypten.log("Rounds: {}".format(self.comm_rounds)) + crypten.log("Bytes : {}".format(self.comm_bytes)) + crypten.log("Comm time: {}".format(self.comm_time)) def _log_communication(self, nelement): """Updates log of communication statistics.""" @@ -201,9 +191,9 @@ def logging_wrapper(self, *args, **kwargs): else: # one tensor communicated self._log_communication(args[0].nelement()) - tic = timeit.timeit() + tic = timeit.default_timer() result = func(self, *args, **kwargs) - toc = timeit.timeit() + toc = timeit.default_timer() self._log_communication_time(toc - tic) return result