forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTestBatched.test_while.expect
61 lines (61 loc) · 2.58 KB
/
TestBatched.test_while.expect
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
graph(%a.1_data : Tensor
%a.1_mask : Tensor
%a.1_dims : Tensor
%b_data : Tensor
%b_mask : Tensor
%b_dims : Tensor) {
%6 : int = prim::Constant[value=1]()
%7 : int = prim::Constant[value=9223372036854775807]()
%8 : Tensor = aten::gt(%a.1_data, %b_data)
%9 : Tensor = aten::mul(%a.1_mask, %b_mask)
%10 : Tensor = aten::__or__(%a.1_dims, %b_dims)
%11 : int = prim::Constant[value=0]()
%12 : Tensor = aten::mul(%8, %9)
%13 : Tensor = aten::sum(%12)
%14 : Tensor = aten::gt(%13, %11)
%15 : bool = prim::Bool(%14)
%16 : Tensor, %17 : Tensor, %a : Tensor, %19 : Tensor, %20 : Tensor = prim::Loop(%7, %15, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
block0(%loop_num : int, %cond_data.2 : Tensor, %cond_mask.2 : Tensor, %6_data : Tensor, %6_mask : Tensor, %6_dims : Tensor) {
%27 : Long() = prim::NumToTensor(%6)
%alpha : float = prim::Float(%27)
%data : Tensor = aten::sub(%6_data, %b_data, %alpha)
%mask : Tensor = aten::mul(%6_mask, %b_mask)
%dims : Tensor = aten::__or__(%6_dims, %b_dims)
%32 : Tensor = aten::gt(%data, %b_data)
%33 : Tensor = aten::mul(%mask, %b_mask)
%34 : bool = prim::Constant[value=1]()
%35 : int = prim::Constant[value=1]()
%36 : Tensor = aten::type_as(%cond_mask.2, %cond_data.2)
%data.2 : Tensor = aten::mul(%cond_data.2, %36)
%38 : int = aten::dim(%data.2)
%39 : bool = aten::eq(%38, %35)
%cond_data : Tensor, %cond_mask : Tensor = prim::If(%39)
block0() {
%42 : int = aten::dim(%data)
%43 : int = aten::sub(%42, %35)
%data.4 : Tensor = prim::Loop(%43, %34, %data.2)
block0(%45 : int, %46 : Tensor) {
%47 : int = aten::dim(%46)
%data.3 : Tensor = aten::unsqueeze(%46, %47)
-> (%34, %data.3)
}
%cond_data.1 : Tensor = aten::expand_as(%data.4, %data)
%cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask)
-> (%cond_data.1, %cond_mask.1)
}
block1() {
-> (%data.2, %data.2)
}
%res_data : Tensor = aten::where(%cond_data, %data, %6_data)
%res_mask : Tensor = aten::where(%cond_mask, %mask, %6_mask)
%res_dims : Tensor = aten::__or__(%dims, %6_dims)
%54 : int = prim::Constant[value=0]()
%55 : Tensor = aten::mul(%32, %33)
%56 : Tensor = aten::sum(%55)
%57 : Tensor = aten::gt(%56, %54)
%58 : bool = prim::Bool(%57)
-> (%58, %32, %33, %res_data, %res_mask, %res_dims)
}
%59 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%a, %19, %20)
return (%59);
}