Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ dist
*.egg-info
build
.DS_Store
.idea
.vscode
Empty file added models/jax_utils/__init__.py
Empty file.
552 changes: 552 additions & 0 deletions models/jax_utils/modeling_flax_pytorch_utils.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions models/llama3/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def generate(
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
logits_processor: Optional[Callable[[Array | KVTensor, Array | KVTensor], Array | KVTensor]] = None,
) -> Generator:
params = self.model.params

Expand Down Expand Up @@ -455,11 +455,11 @@ def sample_top_p(probs, p):
Perform top-p (nucleus) sampling on a probability distribution.

Args:
probs (torch.Tensor): Probability distribution tensor.
probs (Array | KVTensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.

Returns:
torch.Tensor: Sampled token indices.
Array | KVTensor: Sampled token indices.

Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
Expand Down
28 changes: 14 additions & 14 deletions models/llama3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def forward(self, x):
return output * self.weight


def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
def apply_scaling(freqs: Array | KVTensor) -> Array | KVTensor:
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
Expand Down Expand Up @@ -72,7 +72,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled:
return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
def reshape_for_broadcast(freqs_cis: Array | KVTensor, x: Array | KVTensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
Expand All @@ -81,10 +81,10 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):


def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq: Array | KVTensor,
xk: Array | KVTensor,
freqs_cis: Array | KVTensor,
) -> Tuple[Array | KVTensor, Array | KVTensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
Expand All @@ -93,7 +93,7 @@ def apply_rotary_emb(
return xq_out.type_as(xq), xk_out.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
def repeat_kv(x: Array | KVTensor, n_rep: int) -> Array | KVTensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
Expand Down Expand Up @@ -163,10 +163,10 @@ def __init__(self, args: ModelArgs):

def forward(
self,
x: torch.Tensor,
x: Array | KVTensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
freqs_cis: Array | KVTensor,
mask: Optional[Array | KVTensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
Expand Down Expand Up @@ -244,10 +244,10 @@ def __init__(self, layer_id: int, args: ModelArgs):

def forward(
self,
x: torch.Tensor,
x: Array | KVTensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
freqs_cis: Array | KVTensor,
mask: Optional[Array | KVTensor],
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
Expand Down Expand Up @@ -278,7 +278,7 @@ def __init__(self, params: ModelArgs):
)

@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int):
def forward(self, tokens: Array | KVTensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
Expand Down
4 changes: 2 additions & 2 deletions models/llama3/multimodal/encoder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_sc


def build_encoder_attention_mask(
x: torch.Tensor,
ar: torch.Tensor,
x: Array | KVTensor,
ar: Array | KVTensor,
ntok: int,
num_chunks: int,
n_heads: int,
Expand Down
14 changes: 7 additions & 7 deletions models/llama3/multimodal/image_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def get_factors(n: int) -> Set[int]:
factors_set.add(n // i)
return factors_set

def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> Array | KVTensor:
"""
Computes all of the allowed resolutions for a fixed number of chunks
and patch_size. Useful for when dividing an image into chunks.
Expand All @@ -101,7 +101,7 @@ def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> to
patch_size (int): Size of the side of the patch.

Returns:
torch.Tensor: List of possible resolutions as tuples (height, width).
Array | KVTensor: List of possible resolutions as tuples (height, width).

Example:
>>> max_num_chunks = 5
Expand Down Expand Up @@ -182,7 +182,7 @@ def _pad(self, image: Image.Image, target_size) -> Image.Image:
new_im.paste(image)
return new_im

def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
def _split(self, image: Array | KVTensor, ncw: int, nch: int) -> Array | KVTensor:
# Split image into number of required tiles (width x height)
num_channels, height, width = image.size()
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
Expand All @@ -194,10 +194,10 @@ def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:

def resize_without_distortion(
self,
image: torch.Tensor,
image: Array | KVTensor,
target_size: Tuple[int, int],
max_upscaling_size: Optional[int],
) -> torch.Tensor:
) -> Array | KVTensor:
"""
Used to resize an image to target_resolution, without distortion.

Expand Down Expand Up @@ -259,7 +259,7 @@ def resize_without_distortion(
def get_best_fit(
self,
image_size: Tuple[int, int],
possible_resolutions: torch.Tensor,
possible_resolutions: Array | KVTensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
"""
Expand All @@ -283,7 +283,7 @@ def get_best_fit(

Args:
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
possible_resolutions (Array | KVTensor): A tensor of shape (N, 2) where each
row represents a possible resolution (height, width).
resize_to_max_canvas (bool): If True, will return the largest upscaling resolution.

Expand Down
Loading