From 25794a477f4d309bb911a8838dcc1be55865cb29 Mon Sep 17 00:00:00 2001 From: Anshul Singhal Date: Thu, 5 Jun 2025 04:00:48 +0530 Subject: [PATCH] tsai_core.py edit for correct conversion by importing torch --- tsai/data/core.py | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/tsai/data/core.py b/tsai/data/core.py index 69d554c7..10001374 100644 --- a/tsai/data/core.py +++ b/tsai/data/core.py @@ -335,8 +335,26 @@ def __init__(self, items, tfms=None, splits=None, split_idx=None, types=None, do def subset(self, i, **kwargs): return type(self)(self.items, splits=self.splits[i], split_idx=i, do_setup=False, types=self.types, **kwargs) def __getitem__(self, it): - if hasattr(self.items, 'oindex'): return self.items.oindex[self._splits[it]] - else: return self.items[self._splits[it]] + # if hasattr(self.items, 'oindex'): return self.items.oindex[self._splits[it]] + # else: return self.items[self._splits[it]] + # Changed from the original lookup (kept above for reference): NumPy index types are normalised to plain Python types before indexing. + # 1) compute the raw index or slice from self._splits + idx = self._splits[it] + + # 2) if it's a NumPy scalar (e.g. dtype=np.int8), turn it into a plain Python int + if isinstance(idx, np.generic): + idx = idx.item() + # 3) if it's a NumPy array (e.g. 
dtype=int8 or int32), turn it into a Python list + elif isinstance(idx, np.ndarray): + idx = idx.tolist() + + # 4) finally, index self.items (or self.items.oindex) with that pure‐Python index + if hasattr(self.items, 'oindex'): + return self.items.oindex[idx] + else: + return self.items[idx] + + def __len__(self): return len(self._splits) def __repr__(self): if hasattr(self.items, "shape"): @@ -486,12 +504,20 @@ def __init__(self, X=None, y=None, items=None, sel_vars=None, sel_steps=None, tf self.tls = L(lt(item, t, **kwargs) for lt,item,t in zip(lts, items, self.tfms)) # if len(self.tls) > 0 and len(self.tls[0]) > 0: # self.typs = [type(tl.items[0]) if isinstance(tl.items[0], torch.Tensor) else self.typs[i] for i,tl in enumerate(self.tls)] + # if self.inplace and (tfms is None or tfms == [None] * len(self.tls)): + # for tl,typ in zip(self.tls, self.typs): + # tl.items = typ(tl.items) + # Replaces the typ(tl.items) conversion above with an explicit torch.as_tensor call; the author's stated reason is that the dtype could not be inferred automatically from a numpy.int64 array without a valid PyTorch tensor type. NOTE(review): this casts every ndarray in self.tls to float32 and no longer applies self.typs - confirm that is intended. + import torch + if self.inplace and (tfms is None or tfms == [None] * len(self.tls)): - for tl,typ in zip(self.tls, self.typs): - tl.items = typ(tl.items) + for tl in self.tls: + if isinstance(tl.items, np.ndarray): + tl.items = torch.as_tensor(tl.items, dtype=torch.float32) + self.ptls = self.tls self.no_tfm = True - else: + else: self.ptls = L([typ(stack(tl[:]))[...,self.sel_vars, self.sel_steps] if (i==0 and self.multi_index) else typ(stack(tl[:])) \ for i,(tl,typ) in enumerate(zip(self.tls,self.typs))]) if inplace else self.tls self.no_tfm = False @@ -653,7 +679,9 @@ def create_batch(self, b): if hasattr(self, "split_idxs"): self.input_idxs = self.split_idxs[b] else: self.input_idxs = self.idxs - return self.dataset[b] + # return self.dataset[b] + return self.dataset[[int(i) for i in b]] + def create_item(self, s): if self.indexed: return self.dataset[s or 0]