forked from PaddlePaddle/PaddleSlim
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_dy2prog.py
78 lines (61 loc) · 2.3 KB
/
test_dy2prog.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import sys

# Put the parent directory of this file on the import path so the in-repo
# ``paddleslim`` package can be imported (presumably the repository root —
# confirm against the repo layout). The path is resolved relative to this
# file, not the current working directory, so the tests can be launched from
# any directory.
sys.path.append(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

# Set before ``paddle`` is imported below. NOTE(review): presumably Paddle reads
# FLAGS_* environment variables at import time, so the ordering matters — confirm.
os.environ['FLAGS_enable_eager_mode'] = "1"
import paddle
import unittest
from paddleslim.core import dygraph2program
class Model(paddle.nn.Layer):
    """Small conv -> adaptive-avg-pool -> linear network used as a test fixture.

    The tests in this module compare the op types recorded when this layer is
    converted against fixed lists, so the layer sequence here must not change.
    """

    def __init__(self):
        super().__init__()
        self.conv = paddle.nn.Conv2D(
            in_channels=1,
            out_channels=256,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Pools each of the 256 feature maps down to a single 1x1 value.
        self.pool2d_avg = paddle.nn.AdaptiveAvgPool2D([1, 1])
        self.out = paddle.nn.Linear(256, 10)

    def forward(self, inputs):
        # A 0 in the target shape copies that dimension from ``inputs``, so the
        # leading dimension is kept and the remainder is viewed as 1x28x28.
        images = paddle.reshape(inputs, shape=[0, 1, 28, 28])
        pooled = self.pool2d_avg(self.conv(images))
        # Flatten the (N, 256, 1, 1) pooled maps to (N, 256) for the linear layer.
        flat = paddle.reshape(pooled, shape=[-1, 256])
        return self.out(flat)
class TestEagerDygraph2Program(unittest.TestCase):
    """Check the op types ``dygraph2program`` records when converting ``Model``.

    Subclasses override ``prepare_inputs`` to feed other input forms; the
    layer construction and the comparison logic are shared.
    """

    def setUp(self):
        self.prepare_inputs()
        self.prepare_layer()

    def prepare_inputs(self):
        # Input given as a bare shape list.
        self.inputs = [3, 28, 28]
        # NOTE(review): the leading 'assign_value' presumably comes from
        # materializing an input from the shape list — confirm in dygraph2program.
        graph_ops = [
            'reshape2', 'conv2d', 'reshape2', 'elementwise_add', 'pool2d',
            'reshape2', 'matmul_v2', 'elementwise_add',
        ]
        self.ops = ['assign_value'] + graph_ops

    def prepare_layer(self):
        self.layer = Model()

    def test_dy2prog(self):
        self.assert_program(dygraph2program(self.layer, self.inputs))

    def assert_program(self, program):
        """Assert that block 0 of ``program`` holds exactly ``self.ops``, in order."""
        recorded = [op.type for op in program.block(0).ops]
        self.assertListEqual(recorded, self.ops)
class TestEagerDygraph2Program2(TestEagerDygraph2Program):
    """Same check, with the input shape wrapped in an outer list."""

    def prepare_inputs(self):
        self.inputs = [[3, 28, 28]]
        # Same expected op sequence as the bare shape-list case.
        self.ops = ['assign_value']
        self.ops.extend([
            'reshape2', 'conv2d', 'reshape2', 'elementwise_add', 'pool2d',
            'reshape2', 'matmul_v2', 'elementwise_add',
        ])
class TestEagerDygraph2Program3(TestEagerDygraph2Program):
    """Same check, fed a concrete random tensor instead of a shape list."""

    def prepare_inputs(self):
        self.inputs = paddle.randn([3, 28, 28])
        # Unlike the shape-list cases, no leading 'assign_value' op is expected.
        expected = (
            'reshape2', 'conv2d', 'reshape2', 'elementwise_add',
            'pool2d', 'reshape2', 'matmul_v2', 'elementwise_add',
        )
        self.ops = list(expected)
class TestEagerDygraph2Program4(TestEagerDygraph2Program):
    """Same check, fed a random tensor wrapped in a single-element list."""

    def prepare_inputs(self):
        sample = paddle.randn([3, 28, 28])
        self.inputs = [sample]
        # Same expectation as the bare-tensor case: no 'assign_value' op.
        self.ops = [
            'reshape2',
            'conv2d',
            'reshape2',
            'elementwise_add',
            'pool2d',
            'reshape2',
            'matmul_v2',
            'elementwise_add',
        ]
if __name__ == "__main__":
    # Run the TestCase classes defined in this module with the stock runner.
    unittest.main()