Skip to content

Commit

Permalink
sls + layernorm test (pytorch#43799)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#43799

Test Plan: https://www.internalfb.com/intern/testinfra/testconsole/testrun/3096224784866350/

Reviewed By: venkatacrc

Differential Revision: D23383351

fbshipit-source-id: c312d481ad15bded83bea90beaaae7742d0c54b8
  • Loading branch information
Hector Yuen authored and facebook-github-bot committed Dec 14, 2020
1 parent be849ed commit 9e3c25f
Showing 1 changed file with 141 additions and 0 deletions.
141 changes: 141 additions & 0 deletions caffe2/contrib/fakelowp/test/test_sls_8bit_nnpi_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,147 @@ def test_small_sls(self, seed):
)
assert 0

@given(seed=st.integers(0, 65535))
@settings(deadline=datetime.timedelta(seconds=10))
def test_sls_layernorm(self, seed):
np.random.seed(seed)
workspace.ResetWorkspace()

n = 2
DIM = 3
data = 4 * (np.random.random_sample((n, DIM)) + 1).astype(np.float32)

lengths = np.array([n], dtype=np.int32)
indices = np.array(range(n), dtype=np.int64)
weights = np.random.uniform(low=0.01, high=0.5, size=[n]).astype(np.float32)

pred_net = caffe2_pb2.NetDef()
pred_net.name = "pred"
pred_net.external_input.extend(
["quantized_data", "weights", "indices", "lengths"]
)
pred_net.external_output.append("Y_norm")
pred_net.external_output.append("Y_mean")
pred_net.external_output.append("Y_std")

pred_net.op.add().CopyFrom(
core.CreateOperator(
"SparseLengthsWeightedSumFused8BitRowwise",
["quantized_data", "weights", "indices", "lengths"],
["Y"],
)
)

pred_net.op.add().CopyFrom(
core.CreateOperator(
"LayerNorm",
["Y"],
["Y_norm", "Y_mean", "Y_std"],
epsilon=1e-4,
)
)

ref_net = caffe2_pb2.NetDef()
ref_net.name = "ref"
ref_net.external_input.extend(
["quantized_data", "weights", "indices", "lengths"]
)
ref_net.external_output.append("Y_norm")
ref_net.external_output.append("Y_mean")
ref_net.external_output.append("Y_std")

ref_net.op.add().CopyFrom(
core.CreateOperator(
"SparseLengthsWeightedSumFused8BitRowwiseFakeFP16NNPI",
["quantized_data", "weights", "indices", "lengths"],
["Y"],
)
)

ref_net.op.add().CopyFrom(
core.CreateOperator(
"LayerNormFakeFP16NNPI",
["Y"],
["Y_norm", "Y_mean", "Y_std"],
epsilon=1e-4,
axis=1,
elementwise_affine=False
)
)

workspace.FeedBlob("data", data)
workspace.RunOperatorOnce(
core.CreateOperator(
"FloatToFused8BitRowwiseQuantized", ["data"], ["quantized_data"]
)
)

quantized_data = workspace.FetchBlob("quantized_data")

onnxified_net = onnxifi_caffe2_net(
pred_net,
{},
max_batch_size=1,
max_seq_size=n,
debug=True,
adjust_batch=True,
use_onnx=False,
)
print("before", pred_net)
print("after", onnxified_net)
workspace.FeedBlob("indices", indices)
workspace.FeedBlob("lengths", lengths)
workspace.FeedBlob("weights", weights)

workspace.CreateNet(onnxified_net)
workspace.CreateNet(ref_net)

workspace.RunNet(onnxified_net.name)
Y_glow = workspace.FetchBlob("Y_norm")
Y_mean_glow = workspace.FetchBlob("Y_mean")
Y_std_glow = workspace.FetchBlob("Y_std")

workspace.RunNet(ref_net.name)
Y = workspace.FetchBlob("Y")
print("pre normalization", Y)
Y_ref = workspace.FetchBlob("Y_norm")
Y_mean_ref = workspace.FetchBlob("Y_mean")
Y_std_ref = workspace.FetchBlob("Y_std")

# print(Y_ref, Y_glow)
# print(Y_ref.shape, Y_glow.shape)

diff = np.abs(Y_ref - Y_glow)
max_err = np.max(diff, axis=1)
num_offenders = (max_err > 0).sum()
if num_offenders > 0:
np.set_printoptions(precision=12)
print(
"ref",
Y_ref.astype(np.float16).astype(np.float32),
"glow",
Y_glow.astype(np.float16).astype(np.float32),
)
print_test_debug_info(
"slws_fused_8bit_rowwise_inv_scale",
{
"seed": seed,
"indices": indices,
"data": data,
"quantized_data": quantized_data,
"lengths": lengths,
"weights": weights,
"Y_norm_glow": Y_glow,
"Y_norm_ref": Y_ref,
"Y_mean_glow": Y_mean_glow,
"Y_std_glow": Y_std_glow,
"Y_mean_ref": Y_mean_ref,
"Y_std_ref": Y_std_ref,
"diff": diff,
"rowwise_diff": np.max(diff, axis=1),
},
)
assert 0


if __name__ == '__main__':
Expand Down

0 comments on commit 9e3c25f

Please sign in to comment.