update module 1
srush committed Sep 30, 2021
1 parent 23f347c commit 671fb13
Showing 11 changed files with 208 additions and 68 deletions.
89 changes: 63 additions & 26 deletions minitorch/autodiff.py
@@ -1,13 +1,19 @@
variable_count = 1


# ## Module 1

# Variable is the main class for autodifferentiation logic for scalars
# and tensors.


class Variable:
"""
Attributes:
history (:class:`History` or None) : the Function calls that created this variable or None if constant
derivative (variable type): the derivative with respect to this variable
grad (variable type) : alias for derivative (PyTorch name)
name (string) : an optional name for debugging
grad (variable type) : alias for derivative, used for tensors
name (string) : a globally unique name of the variable
"""

def __init__(self, history, name=None):
@@ -26,7 +32,6 @@ def __init__(self, history, name=None):
self.name = name
else:
self.name = self.unique_id

self.used = 0

def requires_grad_(self, val):
@@ -61,7 +66,6 @@ def is_leaf(self):
"True if this variable created by the user (no `last_fn`)"
return self.history.last_fn is None

## IGNORE
def accumulate_derivative(self, val):
"""
Add `val` to the derivative accumulated on this variable.
@@ -103,6 +107,9 @@ def zeros(self):
return 0.0


# Some helper functions for handling optional tuples.


def wrap_tuple(x):
"Turn a possible value into a tuple"
if isinstance(x, tuple):
@@ -117,15 +124,17 @@ def unwrap_tuple(x):
return x


# Classes for Functions.


class Context:
"""
Context class is used by `Function` to store information during the forward pass.
Attributes:
no_grad (bool) : do not save gradient information
saved_values (tuple) : tuple of values saved for backward pass
saved_tensors (tuple) : alias for saved_values (PyTorch name)
saved_tensors (tuple) : alias for saved_values
"""

def __init__(self, no_grad=False):
@@ -181,7 +190,8 @@ def backprop_step(self, d_output):
Returns:
list of numbers : a derivative with respect to `inputs`
"""
return self.last_fn.chain_rule(self.ctx, self.inputs, d_output)
# TODO: Implement for Task 1.4.
raise NotImplementedError('Need to implement for Task 1.4')


class FunctionBase:
@@ -195,10 +205,30 @@ class FunctionBase:

@staticmethod
def variable(raw, history):
# Implemented by child classes.
raise NotImplementedError()

@classmethod
def apply(cls, *vals):
"""
Apply is called by the user to run the Function.
Internally it does three things:
a) Creates a Context for the function call.
b) Calls forward to run the function.
c) Attaches the Context to the History of the new variable.
There is a bit of internal complexity in our implementation
to handle both scalars and tensors.
Args:
vals (list of Variables or constants) : The arguments to forward
Returns:
`Variable` : The new variable produced
"""
# Go through the variables to see if any needs grad.
raw_vals = []
need_grad = False
for v in vals:
@@ -209,12 +239,18 @@ def apply(cls, *vals):
raw_vals.append(v.get_data())
else:
raw_vals.append(v)

# Create the context.
ctx = Context(not need_grad)

# Call forward with the variables.
c = cls.forward(ctx, *raw_vals)
assert isinstance(c, cls.data_type), "Expected return type %s got %s" % (
cls.data_type,
type(c),
)

# Create a new variable from the result with a new history.
back = None
if need_grad:
back = History(cls, ctx, vals)
@@ -231,49 +267,50 @@ def chain_rule(cls, ctx, inputs, d_output):
d_output (number) : The `d_output` value in the chain rule.
Returns:
list of (`Variable`, number) A list of non-constant variables with their derivatives
list of (`Variable`, number) : A list of non-constant variables with their derivatives
(see `is_constant` to remove unneeded variables)
"""
# Tip: Note when implementing this function that
# cls.backward may return either a value or a tuple.
# TODO: Implement for Task 1.3.
raise NotImplementedError('Need to implement for Task 1.3')
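For reference, a minimal sketch of how this classmethod is often completed (illustration only, not part of this diff; it assumes `cls.backward` returns the input derivatives in the same order as `inputs`, either as a single value or as a tuple):

@classmethod
def chain_rule(cls, ctx, inputs, d_output):
    # Normalize backward's output (a value or a tuple) into a tuple.
    d_inputs = wrap_tuple(cls.backward(ctx, d_output))
    # Pair each input with its derivative, dropping constants.
    return [(inp, d) for inp, d in zip(inputs, d_inputs) if not is_constant(inp)]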


# Algorithms for backpropagation


def is_constant(val):
return not isinstance(val, Variable) or val.history is None


def topological_sort(variable):
"Returns nodes in topological order"
order = []
seen = set()
def visit(var):
if var.unique_id in seen:
return
if not var.is_leaf():
for m in var.history.inputs:
if not is_constant(m):
visit(m)
seen.add(var.unique_id)
order.insert(0, var)
visit(variable)
return order
"""
Computes the topological order of the computation graph.
Args:
variable (:class:`Variable`): The right-most variable
Returns:
list of Variables : Non-constant Variables in topological order
starting from the right.
"""
# TODO: Implement for Task 1.4.
raise NotImplementedError('Need to implement for Task 1.4')


def backpropagate(variable, deriv):
"""
Runs a breadth-first search on the computation graph in order to
backpropagate derivatives to the leaves.
Runs backpropagation on the computation graph in order to
compute derivatives for the leaf nodes.
See :doc:`backpropagate` for details on the algorithm.
Args:
variable (:class:`Variable`): The final variable
variable (:class:`Variable`): The right-most variable
deriv (number) : Its derivative that we want to propagate backward to the leaves.
No return. Should write its results to the derivative values of each leaf.
No return. Should write its results to the derivative values of each leaf through `accumulate_derivative`.
"""
# TODO: Implement for Task 1.4.
raise NotImplementedError('Need to implement for Task 1.4')
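For illustration, a rough sketch of how `backpropagate` can be structured (not part of this diff; it assumes `topological_sort` and `History.backprop_step` behave as their docstrings above describe):

def backpropagate(variable, deriv):
    # Seed the right-most variable with the incoming derivative.
    derivatives = {variable.unique_id: deriv}
    # Walk variables from right to left.
    for var in topological_sort(variable):
        d_output = derivatives.get(var.unique_id, 0.0)
        if var.is_leaf():
            # Leaves record their result via accumulate_derivative.
            var.accumulate_derivative(d_output)
        else:
            # Push the derivative back through the Function that produced var.
            for inp, d in var.history.backprop_step(d_output):
                if not is_constant(inp):
                    derivatives[inp.unique_id] = derivatives.get(inp.unique_id, 0.0) + d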
8 changes: 6 additions & 2 deletions minitorch/scalar.py
@@ -62,6 +62,9 @@ def __add__(self, b):
# TODO: Implement for Task 1.2.
raise NotImplementedError('Need to implement for Task 1.2')

def __bool__(self):
return bool(self.data)

def __lt__(self, b):
# TODO: Implement for Task 1.2.
raise NotImplementedError('Need to implement for Task 1.2')
@@ -272,8 +275,9 @@ class LT(ScalarFunction):
def forward(ctx, a, b):
# TODO: Implement for Task 1.2.
raise NotImplementedError('Need to implement for Task 1.2')
# TODO: Implement for Task 1.2.
raise NotImplementedError('Need to implement for Task 1.2')

@staticmethod
def backward(ctx, d_output):
# TODO: Implement for Task 1.4.
raise NotImplementedError('Need to implement for Task 1.4')

18 changes: 18 additions & 0 deletions minitorch/testing.py
@@ -12,6 +12,16 @@ def addConstant(a):
"Add contant to the argument"
return 5 + a

@staticmethod
def square(a):
"Manual square"
return a * a

@staticmethod
def cube(a):
"Manual cube"
return a * a * a

@staticmethod
def subConstant(a):
"Subtract a constant from the argument"
@@ -52,6 +62,10 @@ def exp(a):
"Apply exp to a smaller value"
return operators.exp(a - 200)

@staticmethod
def explog(a):
return operators.log(a + 100000) + operators.exp(a - 200)

@staticmethod
def add2(a, b):
"Add two arguments"
@@ -145,6 +159,10 @@ def relu(x):
def exp(a):
return (a - 200).exp()

@staticmethod
def explog(a):
return (a + 100000).log() + (a - 200).exp()

@staticmethod
def sum_red(a):
return a.sum(0)
19 changes: 17 additions & 2 deletions project/app.py
@@ -1,11 +1,18 @@
import streamlit as st
from interface.streamlit_utils import get_img_tag
from interface.train import render_train_interface
import sys
from argparse import ArgumentParser
from run_torch import TorchTrain
from math_interface import render_math_sandbox

module_num = int(sys.argv[1])
parser = ArgumentParser()
parser.add_argument("module_num", type=int)
parser.add_argument(
"--hide_function_defs", action="store_true", dest="hide_function_defs"
)
args = parser.parse_args()
module_num = args.module_num
hide_function_defs = args.hide_function_defs
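With these arguments in place, the app would presumably be launched through streamlit's argument pass-through, e.g. `streamlit run project/app.py -- 1 --hide_function_defs` (everything after `--` is forwarded to the script; the path and module number here are only illustrative).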

st.set_page_config(page_title="interactive minitorch")
st.sidebar.markdown(
@@ -70,11 +77,19 @@ def render_run_scalar_interface():

if module_selection == "Module 2":
from run_tensor import TensorTrain
from tensor_interface import render_tensor_sandbox
from show_expression_interface import render_show_expression

def render_run_tensor_interface():
st.header("Module 2 - Tensors")
render_train_interface(TensorTrain)

def render_m2_sandbox():
return render_math_sandbox(True, True)

PAGES["Tensor Sandbox"] = lambda: render_tensor_sandbox(hide_function_defs)
PAGES["Tensor Math Sandbox"] = render_m2_sandbox
PAGES["Autograd Sandbox"] = lambda: render_show_expression(True)
PAGES["Module 2: Tensor"] = render_run_tensor_interface


15 changes: 14 additions & 1 deletion project/graph_builder.py
@@ -15,6 +15,19 @@ def build_expression(code):
return out


def build_tensor_expression(code):
out = eval(
code,
{
"x": minitorch.tensor([[1.0, 2.0, 3.0]], requires_grad=True),
"y": minitorch.tensor([[1.0, 2.0, 3.0]], requires_grad=True),
"z": minitorch.tensor([[1.0, 2.0, 3.0]], requires_grad=True),
},
)
out.name = "out"
return out


class GraphBuilder:
def __init__(self):
self.op_id = 0
@@ -44,7 +57,7 @@ def run(self, final):
(cur,) = queue[0]
queue = queue[1:]

if cur.is_leaf():
if minitorch.is_constant(cur) or cur.is_leaf():
continue
else:
op = "%s (Op %d)" % (cur.history.last_fn.__name__, self.op_id)
14 changes: 14 additions & 0 deletions project/interface/streamlit_utils.py
@@ -1,3 +1,6 @@
import inspect
import streamlit as st

img_id_counter = 0


@@ -27,3 +30,14 @@ def get_img_tag(src, width=None):
""".format(
src, img_id, img_id, style
)


def render_function(fn):
st.markdown(
"""
```python
%s
```"""
% inspect.getsource(fn)
)
