Treat a falsy return value (e.g. None) from mapper/reducer callables as
"no output", and hand the input path to mappers as a keyword argument.

dumbo/core.py
  * mapfunc_iter, redfunc_iter: iterate over `result or ()`, so a callable
    that returns None produces no output instead of failing when the
    result is iterated.

dumbo/lib/__init__.py
  * import inspect.
  * __call__normalkey, __call__joinkey: each matching mapper is now called
    with an extra `path=path` keyword argument.
  * configure(): inspect.getargspec() is run on the mapper (on
    mapper.__call__ when the mapper is an mrbase_class instance); if the
    signature has no **kwargs, self.mapper becomes a lambda that accepts
    and drops the extra keyword arguments.
  * __call__(key, value, **kwargs): forwards kwargs to self.mapper and
    treats a falsy result as empty.
  * JoinCombiner.__call__: primary and secondary results are iterated as
    `result or ()`, replacing the explicit `if output:` guard.

NOTE(review): line breaks and indentation in this patch were reconstructed
from the hunk headers and the Python block structure; the hunk line counts
match, but context whitespace should be verified against the target tree
(or apply with `git apply --ignore-whitespace`).

diff --git a/dumbo/core.py b/dumbo/core.py
index 56c28e0..1c12ac8 100644
--- a/dumbo/core.py
+++ b/dumbo/core.py
@@ -463,7 +463,7 @@ def valwrapper(data, valfunc):
 
 def mapfunc_iter(data, mapfunc):
     for (key, value) in data:
-        for output in mapfunc(key, value):
+        for output in mapfunc(key, value) or ():
             yield output
 
 
@@ -478,7 +478,7 @@ def itermap(data, mapfunc, valfunc=None):
 
 def redfunc_iter(data, redfunc):
     for (key, values) in data:
-        for output in redfunc(key, values):
+        for output in redfunc(key, values) or ():
             yield output
 
 
diff --git a/dumbo/lib/__init__.py b/dumbo/lib/__init__.py
index 2193388..fc066e6 100644
--- a/dumbo/lib/__init__.py
+++ b/dumbo/lib/__init__.py
@@ -17,6 +17,7 @@
 import heapq
 import os
 import types
+import inspect
 from itertools import chain, imap, izip
 from math import sqrt
 from copy import copy
@@ -128,7 +129,7 @@ def __call__normalkey(self, data):
             path, key = key
             for pattern, mapper in mappers:
                 if pattern in path:
-                    for output in mapper(key, value):
+                    for output in mapper(key, value, path=path):
                         yield output
 
     def __call__joinkey(self, data):
@@ -138,7 +139,7 @@ def __call__joinkey(self, data):
             key.body = key.body[1]
             for pattern, mapper in mappers:
                 if pattern in path:
-                    for output in mapper(key, value):
+                    for output in mapper(key, value, path=path):
                         yield output
 
     def add(self, pattern, mapper):
@@ -169,15 +170,20 @@ def configure(self):
             mapper.configure()
         if hasattr(mapper, 'close'):
             self.closefunc = mapper.close
+        mapper_call = mapper
+        if isinstance(mapper_call, mrbase_class):
+            mapper_call = mapper.__call__
         self.mapper = mapper
+        if not inspect.getargspec(mapper_call).keywords:
+            self.mapper = lambda key, value, **kwargs: mapper(key, value)
 
     def close(self):
         if self.closefunc:
             self.closefunc()
 
-    def __call__(self, key, value):
+    def __call__(self, key, value, **kwargs):
         key.isprimary = self.isprimary
-        for k, v in self.mapper(key.body, value):
+        for k, v in self.mapper(key.body, value, **kwargs) or ():
             jk = copy(key)
             jk.body = k
             yield jk, v
@@ -202,14 +208,12 @@ class JoinCombiner(object):
     def __call__(self, key, values):
         if key.isprimary:
             self._key = key.body
-            output = self.primary(key.body, values)
-            if output:
-                for k, v in output:
-                    jk = copy(key)
-                    jk.body = k
-                    yield jk, v
+            for k, v in self.primary(key.body, values) or ():
+                jk = copy(key)
+                jk.body = k
+                yield jk, v
         elif not self.secondary_blocked(key.body):
-            for k, v in self.secondary(key.body, values):
+            for k, v in self.secondary(key.body, values) or ():
                 jk = copy(key)
                 jk.body = k
                 yield jk, v