# Generates C++ functions that wrap ATen tensor factory methods to turn them into Variables.
#
# This writes one file: variable_factories.h

import re

from .utils import CodeTemplate, write
from .gen_variable_type import format_trace

FUNCTION_TEMPLATE = CodeTemplate("""\
inline at::Tensor ${name}(${formals}) {
  ${pre_record_trace}
  at::Tensor tensor = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return at::${name}(${actuals});
  })();
  at::Tensor result =
      autograd::make_variable(std::move(tensor), /*requires_grad=*/${requires_grad});
  ${post_record_trace}
  return result;
}
""")

OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>")
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")


def fully_qualified_type(argument_type):
    def maybe_optional_type(t, opt_match):
        return 'c10::optional<{}>'.format(t) if opt_match else t

    opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type)
    if opt_match:
        argument_type = argument_type[opt_match.start(1):opt_match.end(1)]
    match = TYPE_PATTERN.match(argument_type)
    if match is None:
        return maybe_optional_type(argument_type, opt_match)
    index = match.start(1)
    qualified_type = "{}at::{}".format(argument_type[:index], argument_type[index:])
    return maybe_optional_type(qualified_type, opt_match)
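
# Behavior of fully_qualified_type, inferred from the regexes above:
#   "Tensor"                      -> "at::Tensor"
#   "const Tensor &"              -> "const at::Tensor &"
#   "c10::optional<ScalarType>"   -> "c10::optional<at::ScalarType>"
#   "bool"                        -> "bool"  (non-ATen, lowercase types are left untouched)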

def gen_variable_factories(out, declarations, template_path, disable_autograd=False):
    function_definitions = []
    for decl in declarations:
        has_tensor_options = any(a["simple_type"] == "TensorOptions" for a in decl["arguments"])
        is_namespace_fn = 'namespace' in decl['method_of']
        # Only namespace functions that take TensorOptions or are "*_like" factories get wrappers.
        if (has_tensor_options or decl["name"].endswith("_like")) and is_namespace_fn:
            function_definitions.append(
                process_function(decl, has_tensor_options, disable_autograd=disable_autograd))
    write(out,
          "variable_factories.h",
          CodeTemplate.from_file(template_path + "/variable_factories.h"),
          {"function_definitions": function_definitions})

def process_function(decl, has_tensor_options, disable_autograd):
    formals = []
    actuals = []
    for argument in decl["arguments"]:
        type = fully_qualified_type(argument["type"])
        default = " = {}".format(argument["default"]) if "default" in argument else ""
        formals.append("{} {}{}".format(type, argument["name"], default))
        actual = argument["name"]
        if argument["simple_type"] == "TensorOptions":
            # The generated wrapper rebuilds an at::TensorOptions from the forwarded argument.
            actual = "at::TensorOptions({})".format(actual)
        actuals.append(actual)
    # Forward requires_grad from the TensorOptions argument when the factory takes one.
    requires_grad = "options.requires_grad()" if has_tensor_options else "false"

    if not disable_autograd:
        pre_record_trace, post_record_trace = format_trace(decl)
    else:
        pre_record_trace, post_record_trace = '', ''

    return FUNCTION_TEMPLATE.substitute(
        name=decl["name"], formals=formals, actuals=actuals, requires_grad=requires_grad,
        pre_record_trace=pre_record_trace, post_record_trace=post_record_trace
    )