From d114934f05e206782fbe4d4b6fd34d7d6a582127 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 10 May 2024 15:19:46 -0400 Subject: [PATCH] =?UTF-8?q?perf:=20=E2=9A=A1=EF=B8=8F=20Improve=20array=20?= =?UTF-8?q?size=20increase.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/leibnetz/leibnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/leibnetz/leibnet.py b/src/leibnetz/leibnet.py index 38aad90..f9a1d3b 100644 --- a/src/leibnetz/leibnet.py +++ b/src/leibnetz/leibnet.py @@ -313,12 +313,12 @@ def is_valid_input_shape(self, input_key, input_shape): == 0 ).all() - def step_up_size(self, steps: int = 1): + def step_up_size(self, steps: int = 1, step_size: int = 1): for n in range(steps): target_arrays = {} for name, metadata in self.output_shapes.items(): target_arrays[name] = tuple( - (tuple(s + 1 for s in metadata["shape"]), metadata["scale"]) + (tuple(s + step_size for s in metadata["shape"]), metadata["scale"]) ) self.compute_shapes(target_arrays, set=True)