Skip to content

Commit

Permalink
Add support for mixed effects.
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewwardrop committed Jul 11, 2021
1 parent 8350aea commit 03c05aa
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 12 deletions.
10 changes: 10 additions & 0 deletions formulaic/model_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,15 @@ def __init__(self, data, spec=None):
def model_spec(self):
    """Return the model spec stored on this instance as `_self_model_spec`."""
    spec = self._self_model_spec
    return spec

def subset(self, tags=None, exact=True):
    """
    Select the columns produced by terms that carry particular tags.

    Args:
        tags: The tag or tags to match. A falsy value (`None`, an empty
            collection) means "no tags". A single string is treated as
            one tag rather than as an iterable of characters.
        exact: If true, a term matches only when its set of tags equals
            `tags`; otherwise a term matches when it shares at least one
            tag with `tags`.

    Returns:
        The result of indexing `self` with the list of matching column
        names, in the order they appear in `self.model_spec.structure`.
    """
    # Normalise once, up front: `set("group")` would otherwise yield the
    # individual characters and silently match nothing.
    if not tags:
        wanted = set()
    elif isinstance(tags, str):
        wanted = {tags}
    else:
        wanted = set(tags)

    def matches(term):
        term_tags = set(term.tags)
        if exact:
            return term_tags == wanted
        return bool(wanted & term_tags)

    # Each structure entry is unpacked as (term, factors, column names).
    cols = [
        name
        for term, _factors, names in self.model_spec.structure
        if matches(term)
        for name in names
    ]
    return self[cols]

def __repr__(self):
    """Delegate the textual representation to the wrapped object."""
    wrapped = self.__wrapped__
    return wrapped.__repr__()  # pragma: no cover
5 changes: 4 additions & 1 deletion formulaic/parser/algos/infix_to_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def infix_to_ast(tokens, operator_resolver):

output_queue = []
operator_stack = []
group_operators = []

def stack_operator(operator, token):
    """Push `operator` onto the operator stack.

    The entry also records the current length of the output queue.
    """
    entry = OrderedOperator(operator, token, len(output_queue))
    operator_stack.append(entry)
Expand Down Expand Up @@ -44,16 +45,18 @@ def operate(ordered_operator, output_queue):
else:
if token.token == '(':
stack_operator(token, token)
group_operators.append('(')
elif token.token == ')':
while operator_stack and operator_stack[-1].token != '(':
output_queue = operate(operator_stack.pop(), output_queue)
if operator_stack and operator_stack[-1].token == '(':
operator_stack.pop()
group_operators.pop()
else:
raise exc_for_token(token, "Could not find matching parenthesis.")
else:
max_prefix_arity = len(output_queue) - operator_stack[-1].index if operator_stack else len(output_queue)
operators = operator_resolver.resolve(token, max_prefix_arity)
operators = operator_resolver.resolve(token, max_prefix_arity, group_kind=group_operators[-1] if group_operators else None, group_depth=len(group_operators))

for operator in operators:

Expand Down
16 changes: 12 additions & 4 deletions formulaic/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ def power(arg, power):
return [
Operator("~", arity=2, precedence=-100, associativity=None, to_terms=formula_separator_expansion),
Operator("~", arity=1, precedence=-100, associativity=None, fixity='prefix', to_terms=lambda expr: (expr.to_terms(), )),
Operator("|", arity=2, precedence=50, associativity='left', to_terms=lambda *args: {
functools.reduce(lambda x, y: x * y, term).with_tag('group')
for term in itertools.product(*[arg.to_terms() for arg in args])
}, group_kinds={'('}, group_depths={1}),
Operator("||", arity=2, precedence=70, associativity='left', to_terms=lambda *args: {
functools.reduce(lambda x, y: x * y, term).with_tag('group_independent')
for term in itertools.product(*[arg.to_terms() for arg in args])
}, group_kinds={'('}, group_depths={1}),
Operator("+", arity=2, precedence=100, associativity='left', to_terms=lambda *args: set(itertools.chain(*[arg.to_terms() for arg in args]))),
Operator("-", arity=2, precedence=100, associativity='left', to_terms=lambda left, right: set(set(left.to_terms()).difference(right.to_terms()))),
Operator("+", arity=1, precedence=100, associativity='right', fixity='prefix', to_terms=lambda arg: arg.to_terms()),
Expand All @@ -137,9 +145,9 @@ def power(arg, power):
Operator("**", arity=2, precedence=500, associativity='right', to_terms=power),
]

def resolve(self, token: Token, max_prefix_arity) -> List[Operator]:
def resolve(self, token: Token, max_prefix_arity, group_kind, group_depth) -> List[Operator]:
if token.token in self.operator_table:
return super().resolve(token, max_prefix_arity)
return super().resolve(token, max_prefix_arity, group_kind, group_depth)

symbol = token.token

Expand All @@ -148,9 +156,9 @@ def resolve(self, token: Token, max_prefix_arity) -> List[Operator]:
symbol = re.sub(r'[+]{2,}', '+', symbol) # multiple sequential '+' -> '+'

if symbol in self.operator_table:
return [self._resolve(token, symbol, max_prefix_arity)]
return [self._resolve(token, symbol, max_prefix_arity, group_kind, group_depth)]

return [
self._resolve(token, sym, max_prefix_arity if i == 0 else 0)
self._resolve(token, sym, max_prefix_arity if i == 0 else 0, group_kind, group_depth)
for i, sym in enumerate(symbol)
]
4 changes: 3 additions & 1 deletion formulaic/parser/types/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ class Fixity(Enum):
INFIX = 'infix'
POSTFIX = 'postfix'

def __init__(self, symbol, *, arity=None, precedence=None, associativity=None, fixity='infix', to_terms=None):
def __init__(
    self,
    symbol,
    *,
    arity=None,
    precedence=None,
    associativity=None,
    fixity='infix',
    group_kinds=None,
    group_depths=None,
    to_terms=None,
):
    """
    Store the description of an operator on the instance.

    Args:
        symbol: The operator's symbol.
        arity: Stored as `arity`.
        precedence: Stored as `precedence`.
        associativity: Stored as `associativity`.
        fixity: Stored as `fixity`; defaults to 'infix'.
        group_kinds: Stored as `group_kinds`. Presumably the kinds of
            enclosing group in which the operator is valid, with `None`
            meaning unrestricted; confirm against the resolver.
        group_depths: Stored as `group_depths`. Presumably the group
            nesting depths at which the operator is valid, with `None`
            meaning unrestricted; confirm against the resolver.
        to_terms: Stored privately as `_to_terms`.
    """
    # What the operator is and how it binds.
    self.symbol = symbol
    self.arity = arity
    self.precedence = precedence
    self.associativity = associativity
    self.fixity = fixity

    # Grouping attributes.
    self.group_kinds = group_kinds
    self.group_depths = group_depths

    # Callable kept private.
    self._to_terms = to_terms

@property
Expand Down
19 changes: 14 additions & 5 deletions formulaic/parser/types/operator_resolver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import List
from typing import List, Optional

from ..utils import exc_for_token
from .operator import Operator
Expand All @@ -19,16 +19,25 @@ def __init__(self):
def operators(self) -> List[Operator]:
    """Return the operators known to this resolver; here an empty list."""
    no_operators = []
    return no_operators  # pragma: no cover

def resolve(self, token: Token, max_prefix_arity) -> List[Operator]:
return [self._resolve(token, token.token, max_prefix_arity)]
def resolve(self, token: Token, max_prefix_arity, group_kind, group_depth) -> List[Operator]:
    """
    Resolve `token` into a one-element list of operators.

    The token's own text (`token.token`) is used as the symbol, and all
    other arguments are forwarded unchanged to `_resolve`.
    """
    symbol = token.token
    resolved = self._resolve(token, symbol, max_prefix_arity, group_kind, group_depth)
    return [resolved]

def _resolve(self, token: Token, symbol: str, max_prefix_arity: int) -> Operator:
def _resolve(self, token: Token, symbol: str, max_prefix_arity: int, group_kind: Optional[str], group_depth: int) -> Operator:
if symbol not in self.operator_table:
raise exc_for_token(token, f"Unknown operator '{symbol}'.")
candidates = [
candidate
for candidate in self.operator_table[symbol]
if max_prefix_arity == 0 and candidate.fixity is Operator.Fixity.PREFIX or max_prefix_arity > 0 and candidate.fixity is not Operator.Fixity.PREFIX
if (candidate.group_kinds is None or group_kind in candidate.group_kinds) and (candidate.group_depths is None or group_depth in candidate.group_depths) and (
(
max_prefix_arity == 0
and candidate.fixity is Operator.Fixity.PREFIX
)
or (
max_prefix_arity > 0
and candidate.fixity is not Operator.Fixity.PREFIX
)
)
]
if not candidates:
raise exc_for_token(token, f"Operator `{symbol}` is incorrectly used.")
Expand Down
7 changes: 6 additions & 1 deletion formulaic/parser/types/term.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
class Term:

def __init__(self, factors):
def __init__(self, factors, tags=None):
    """
    Create a term from its factors and optional tags.

    Args:
        factors: Iterable of factors; stored as a set.
        tags: Optional tag or iterable of tags. A falsy value gives an
            empty tag list, and a single string is treated as one tag.
    """
    self.factors = set(factors)
    # Always keep a private list: `with_tag` appends to `self.tags`, so
    # aliasing the caller's list would mutate it, and a tuple or set
    # would not support `append` at all.
    if not tags:
        self.tags = []
    elif isinstance(tags, str):
        self.tags = [tags]
    else:
        self.tags = list(tags)

def __mul__(self, other):
if isinstance(other, Term):
Expand Down Expand Up @@ -33,3 +34,7 @@ def __lt__(self, other):

def __repr__(self):
    """Render the term as the entries of `self._tuple` joined by colons."""
    separator = ':'
    return separator.join(self._tuple)

def with_tag(self, tag):
    """
    Append `tag` to this term's tags and return the term itself.

    The term is modified in place; no copy is made.
    """
    current_tags = self.tags
    current_tags.append(tag)
    return self

0 comments on commit 03c05aa

Please sign in to comment.