
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat2 in method wrapper_CUDA_bmm) #34695

Open
ra-MANUJ-an opened this issue Nov 12, 2024 · 0 comments

Reproduction

I am trying to finetune a Qwen2-0.5B model on some training data in a multi-GPU setup. The same code (given further below) works in a single-GPU setting (i.e. when I set CUDA_VISIBLE_DEVICES=0), but in the multi-GPU setting it fails with the traceback below:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[15], line 4
      2 import torch
      3 torch.autograd.set_detect_anomaly(True)
----> 4 main()

Cell In[12], line 15, in main()
      8 trainer = Trainer(env_params=env_params,
      9                   model_params=model_params,
     10                   optimizer_params=optimizer_params,
     11                   trainer_params=trainer_params)
     13 copy_all_src(trainer.result_folder)
---> 15 trainer.run()

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTrainerTransformer.py:82, in TSPTrainer.run(self)
     79 self.scheduler.step()
     81 # Train
---> 82 train_score, train_loss = self._train_one_epoch(epoch)
     83 self.result_log.append('train_score', epoch, train_score)
     84 self.result_log.append('train_loss', epoch, train_loss)

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTrainerTransformer.py:139, in TSPTrainer._train_one_epoch(self, epoch)
    136 remaining = train_num_episode - episode
    137 batch_size = min(self.trainer_params['train_batch_size'], remaining)
--> 139 avg_score, avg_loss = self._train_one_batch(batch_size)
    140 score_AM.update(avg_score, batch_size)
    141 loss_AM.update(avg_loss, batch_size)

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTrainerTransformer.py:175, in TSPTrainer._train_one_batch(self, batch_size)
    173 # print(4, type(state), state)
    174 while not done:
--> 175     selected, prob = self.model.module(state)
    176     # print(3, selected.shape)
    177     state, reward, done = self.env.step(selected)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTransformerModelQuant_b.py:52, in TSPTransformer.forward(self, state)
     50     return self._init_sequence(batch_size, pomo_size)
     51 else:
---> 52     return self._continue_sequence(state, batch_size, pomo_size)

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTransformerModelQuant_b.py:100, in TSPTransformer._continue_sequence(self, state, batch_size, pomo_size)
     96     state.ninf_mask = state.ninf_mask.to(self.device)
     98 # Get probabilities from decoder
--> 100 probs = self.decoder(self.seq_so_far, self.input_mask, state.ninf_mask)
    102 # Select next node
    103 if self.training or self.model_params['eval_type'] == 'softmax':

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTransformerModelQuant_b.py:261, in Decoder.forward(self, seq_so_far, inp_mask, ninf_mask)
    258 flat_mask = flat_mask.to(self.device)
    260 # Forward pass through model
--> 261 outputs = self.model(inputs_embeds=flat_seq, attention_mask=flat_mask)
    262 logits = outputs.logits.to(self.device)
    264 # Get last valid position

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/peft/peft_model.py:1644, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1642     with self._enable_peft_forward_hooks(**kwargs):
   1643         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1644         return self.base_model(
   1645             input_ids=input_ids,
   1646             attention_mask=attention_mask,
   1647             inputs_embeds=inputs_embeds,
   1648             labels=labels,
   1649             output_attentions=output_attentions,
   1650             output_hidden_states=output_hidden_states,
   1651             return_dict=return_dict,
   1652             **kwargs,
   1653         )
   1655 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1656 if attention_mask is not None:
   1657     # concat prompt attention mask

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
    196 def forward(self, *args: Any, **kwargs: Any):
--> 197     return self.model.forward(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:1164, in Qwen2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
   1161 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1163 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1164 outputs = self.model(
   1165     input_ids=input_ids,
   1166     attention_mask=attention_mask,
   1167     position_ids=position_ids,
   1168     past_key_values=past_key_values,
   1169     inputs_embeds=inputs_embeds,
   1170     use_cache=use_cache,
   1171     output_attentions=output_attentions,
   1172     output_hidden_states=output_hidden_states,
   1173     return_dict=return_dict,
   1174     cache_position=cache_position,
   1175 )
   1177 hidden_states = outputs[0]
   1178 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:871, in Qwen2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    868 hidden_states = inputs_embeds
    870 # create position embeddings to be shared across the decoder layers
--> 871 position_embeddings = self.rotary_emb(hidden_states, position_ids)
    873 # decoder layers
    874 all_hidden_states = () if output_hidden_states else None

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/second/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:163, in Qwen2RotaryEmbedding.forward(self, x, position_ids)
    161 device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    162 with torch.autocast(device_type=device_type, enabled=False):
--> 163     freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    164     emb = torch.cat((freqs, freqs), dim=-1)
    165     cos = emb.cos()

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)
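The only difference between the run that works and the run that fails is how many GPUs are visible to the process. A minimal sketch of how I toggle between the two (the surrounding launch code is assumed, not part of the script below):

import os

# Single-GPU run (works): expose only cuda:0 before CUDA is initialized.
# Comment this line out to expose both GPUs (fails with the RuntimeError above).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
print(torch.cuda.device_count())  # 1 in the working run, 2 in the failing run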

The code that produces the above error is given below:

Trainer.py

import torch
from logging import getLogger

from TSPEnvQuant import TSPEnv as Env
from TSPTransformerModelQuant_b import TSPTransformer as Model

from torch.optim import Adam as Optimizer
from torch.optim.lr_scheduler import MultiStepLR as Scheduler

from utils.utils import *


class TSPTrainer:
    def __init__(self,
                 env_params,
                 model_params,
                 optimizer_params,
                 trainer_params):

        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.optimizer_params = optimizer_params
        self.trainer_params = trainer_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()
        self.result_log = LogData()

        # cuda
        USE_CUDA = self.trainer_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.trainer_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            self.device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')

        # print(self.device)
        # Main Components
        self.model = Model(**self.model_params).to(self.device)
        self.env = Env(**self.env_params)
        self.optimizer = Optimizer(self.model.parameters(), **self.optimizer_params['optimizer'])
        self.scheduler = Scheduler(self.optimizer, **self.optimizer_params['scheduler'])

        # Wrap the model in DataParallel if multiple GPUs are available
        if torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model, device_ids=[0, 1])  # Specify device IDs
            self.logger.info(f"Using {torch.cuda.device_count()} GPUs for DataParallel.")

        # device = torch.device('cuda' if trainer_params['use_cuda'] else 'cpu')
        self.model.to(self.device)        
        
        # Restore
        self.start_epoch = 1
        model_load = trainer_params['model_load']
        if model_load['enable']:
            checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
            checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.start_epoch = 1 + model_load['epoch']
            self.result_log.set_raw_data(checkpoint['result_log'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.last_epoch = model_load['epoch'] - 1
            self.logger.info('Saved Model Loaded !!')

        # utility
        self.time_estimator = TimeEstimator()

    def run(self):
        self.time_estimator.reset(self.start_epoch)
        for epoch in range(self.start_epoch, self.trainer_params['epochs']+1):
            self.logger.info('=================================================================')

            # LR Decay
            self.scheduler.step()

            # Train
            train_score, train_loss = self._train_one_epoch(epoch)
            self.result_log.append('train_score', epoch, train_score)
            self.result_log.append('train_loss', epoch, train_loss)

            ############################
            # Logs & Checkpoint
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.trainer_params['epochs'])
            self.logger.info("Epoch {:3d}/{:3d}: Time Est.: Elapsed[{}], Remain[{}]".format(
                epoch, self.trainer_params['epochs'], elapsed_time_str, remain_time_str))

            all_done = (epoch == self.trainer_params['epochs'])
            model_save_interval = self.trainer_params['logging']['model_save_interval']
            img_save_interval = self.trainer_params['logging']['img_save_interval']

            if epoch > 1:  # save latest images, every epoch
                self.logger.info("Saving log_image")
                image_prefix = '{}/latest'.format(self.result_folder)
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'],
                                               self.result_log, labels=['train_score'])
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'],
                                               self.result_log, labels=['train_loss'])

            if all_done or (epoch % model_save_interval) == 0:
                self.logger.info("Saving trained_model")
                checkpoint_dict = {
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'result_log': self.result_log.get_raw_data()
                }
                torch.save(checkpoint_dict, '{}/checkpoint-{}.pt'.format(self.result_folder, epoch))

            if all_done or (epoch % img_save_interval) == 0:
                image_prefix = '{}/img/checkpoint-{}'.format(self.result_folder, epoch)
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'],
                                               self.result_log, labels=['train_score'])
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'],
                                               self.result_log, labels=['train_loss'])

            if all_done:
                self.logger.info(" *** Training Done *** ")
                self.logger.info("Now, printing log array...")
                util_print_log_array(self.logger, self.result_log)

    def _train_one_epoch(self, epoch):
        score_AM = AverageMeter()
        loss_AM = AverageMeter()

        train_num_episode = self.trainer_params['train_episodes']
        episode = 0
        loop_cnt = 0
        while episode < train_num_episode:
            remaining = train_num_episode - episode
            batch_size = min(self.trainer_params['train_batch_size'], remaining)

            avg_score, avg_loss = self._train_one_batch(batch_size)
            score_AM.update(avg_score, batch_size)
            loss_AM.update(avg_loss, batch_size)

            episode += batch_size

            # Log First 10 Batch, only at the first epoch
            if epoch == self.start_epoch:
                loop_cnt += 1
                if loop_cnt <= 10:
                    self.logger.info('Epoch {:3d}: Train {:3d}/{:3d}({:1.1f}%)  Score: {:.4f},  Loss: {:.4f}'
                                     .format(epoch, episode, train_num_episode, 100. * episode / train_num_episode,
                                             score_AM.avg, loss_AM.avg))

        # Log Once, for each epoch
        self.logger.info('Epoch {:3d}: Train ({:3.0f}%)  Score: {:.4f},  Loss: {:.4f}'
                         .format(epoch, 100. * episode / train_num_episode,
                                 score_AM.avg, loss_AM.avg))

        return score_AM.avg, loss_AM.avg

    def _train_one_batch(self, batch_size):
        self.model.train()
        # print(5, batch_size)
        self.env.load_problems(batch_size)
        reset_state, _, _ = self.env.reset()
        # self.model.pre_forward(reset_state)
        # Modify this line in _train_one_batch
        self.model.module.pre_forward(reset_state)

        prob_list = torch.zeros(size=(batch_size, self.env.pomo_size, 0)).to(self.device)  # Explicitly move to device

        # POMO Rollout
        state, reward, done = self.env.pre_step()
        # print(4, type(state), state)
        while not done:
            selected, prob = self.model.module(state)
            # print(3, selected.shape)
            state, reward, done = self.env.step(selected)
            prob = prob.to(prob_list.device)
            prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)

        advantage = reward - reward.float().mean(dim=1, keepdims=True)
        log_prob = prob_list.log().sum(dim=2)
        loss = -(advantage * log_prob).mean()

        max_pomo_reward, _ = reward.max(dim=1)
        score_mean = -max_pomo_reward.float().mean()

        self.model.zero_grad()
        loss.backward()
        self.optimizer.step()
        return score_mean.item(), loss.item()

Model.py

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel, TaskType
from typing import Optional, Dict, Any, Tuple

class TSPTransformer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.model_params = kwargs
        # Set device first
        self.device = kwargs.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        
        # Initialize components
        self.encoder = Encoder(**kwargs).to(self.device)
        self.embedding_size = kwargs.get('embedding_dim', 896)
        
        # Load the model with LoRA and 4-bit quantization if needed
        self.model = load_model(kwargs)
        kwargs['device'] = self.device  # Ensure decoder gets the same device
        self.decoder = Decoder(self.model, **kwargs)
        
        # Initialize state storage
        self.encoded_nodes = None
        self.seq_so_far = None
        self.input_mask = None
        self.t = None
    
    def pre_forward(self, reset_state):
        """Initialize model state for new sequence"""
        # Move input to correct device
        problems = reset_state.problems.to(self.device)
        self.encoded_nodes = self.encoder(problems)
        self.problem_size = problems.size(1)
        self.batch_size = problems.size(0)
    
    def forward(self, state) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Ensure state tensors are on correct device
        state.BATCH_IDX = state.BATCH_IDX.to(self.device)
        state.POMO_IDX = state.POMO_IDX.to(self.device)
        if state.ninf_mask is not None:
            state.ninf_mask = state.ninf_mask.to(self.device)
        if state.current_node is not None:
            state.current_node = state.current_node.to(self.device)

        batch_size = state.BATCH_IDX.size(0)
        pomo_size = state.BATCH_IDX.size(1)

        if state.current_node is None:
            return self._init_sequence(batch_size, pomo_size)
        else:
            return self._continue_sequence(state, batch_size, pomo_size)

    def _init_sequence(self, batch_size: int, pomo_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize sequence state"""
        self.t = 0  # Start at 0 instead of -1
        
        # Create new tensors instead of modifying in place
        selected = torch.arange(pomo_size, device=self.device).expand(batch_size, pomo_size)
        prob = torch.ones(size=(batch_size, pomo_size), device=self.device)
        
        # Initialize sequence storage with proper dimensions
        self.seq_so_far = torch.zeros(
            (batch_size, pomo_size, self.problem_size, self.embedding_size),
            device=self.device
        )
        
        self.input_mask = torch.zeros(
            (batch_size, pomo_size, self.problem_size),
            dtype=torch.bool,
            device=self.device
        )
        
        return selected, prob

    def _continue_sequence(self, state, batch_size: int, pomo_size: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Continue sequence generation"""
        # Get encoded representation of current node
        encoded_current = self._get_encoded_node(state.current_node)
        
        # Move tensors to correct device
        encoded_current = encoded_current.to(self.device)
        
        # Create new tensor for updated sequence
        new_seq = self.seq_so_far.clone().to(self.device)
        new_seq[:, :, self.t, :] = encoded_current
        self.seq_so_far = new_seq
        
        # Create new tensor for updated mask
        new_mask = self.input_mask.clone().to(self.device)
        new_mask[:, :, self.t] = True
        self.input_mask = new_mask
        
        # Ensure state.ninf_mask is on correct device
        if state.ninf_mask is not None:
            state.ninf_mask = state.ninf_mask.to(self.device)
        
        # Get probabilities from decoder

        probs = self.decoder(self.seq_so_far, self.input_mask, state.ninf_mask)
        
        # Select next node
        if self.training or self.model_params['eval_type'] == 'softmax':
            selected, prob = self._sample_node(probs, state, batch_size, pomo_size)
        else:
            selected = probs.argmax(dim=2)
            prob = None
        
        self.t += 1
        return selected, prob

    def _get_encoded_node(self, node_indices: torch.Tensor) -> torch.Tensor:
        """Get encoded representation of nodes safely"""
        batch_size, pomo_size = node_indices.shape
        embedding_dim = self.encoded_nodes.size(2)
        
        # Create gathering indices
        gather_idx = node_indices[:, :, None].expand(batch_size, pomo_size, embedding_dim)
        # gather_idx = gather_idx.to(self.encoded_nodes.device)

        self.encoded_nodes = self.encoded_nodes.to(self.device)
        gather_idx = gather_idx.to(self.device)

        # Ensure gather_idx is within the range of self.encoded_nodes.size(1)
        max_valid_index = self.encoded_nodes.size(1) - 1
        gather_idx = torch.clamp(gather_idx, min=0, max=max_valid_index)
        # assert gather_idx.max() <= max_valid_index, "gather_idx contains indices out of bounds"
        
        # Gather encoded representations
        return self.encoded_nodes.gather(dim=1, index=gather_idx)

    def _sample_node(self, probs: torch.Tensor, state, batch_size: int, pomo_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample next node with retry logic"""
        max_attempts = 100
        for _ in range(max_attempts):
            # 
            probs = probs.to(self.device)
            # Reshape for sampling
            flat_probs = probs.reshape(batch_size * pomo_size, -1)
            
            # Sample indices
            selected = flat_probs.multinomial(1, replacement=True)
            selected = selected.reshape(batch_size, pomo_size)
            
            # Calculate probabilities
            prob = probs[state.BATCH_IDX, state.POMO_IDX, selected]
            prob = prob.reshape(batch_size, pomo_size)
            
            if (prob > 0).all():
                return selected, prob
        
        raise RuntimeError(f"Failed to sample valid nodes after {max_attempts} attempts")

class Encoder(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.embedding_dim = kwargs.get('embedding_dim', 896)
        self.device = kwargs.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        
        # Linear layer to embed node coordinates
        self.embed_layer = nn.Linear(2, self.embedding_dim)
        
        # Multi-head self-attention layer
        self.num_heads = kwargs.get('num_attention_heads', 8)
        self.attention_layer = nn.MultiheadAttention(
            embed_dim=self.embedding_dim,
            num_heads=self.num_heads,
            batch_first=True
        )
        
        # Register positional encoding as a buffer so it's not updated during training
        self.register_buffer("positional_encoding", self._generate_positional_encoding(kwargs.get('problem_size', 20), self.embedding_dim))

    def _generate_positional_encoding(self, problem_size: int, embed_dim: int) -> torch.Tensor:
        """Generate sinusoidal positional encoding for input sequences."""
        # Create a matrix of shape (problem_size, embed_dim) to hold the positional encodings
        encoding = torch.zeros(problem_size, embed_dim)
        position = torch.arange(0, problem_size, dtype=torch.float).unsqueeze(1)  # Shape: (problem_size, 1)
        
        # Compute the division terms for sine and cosine functions
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / embed_dim))
        
        # Apply sine to even indices and cosine to odd indices
        encoding[:, 0::2] = torch.sin(position * div_term)  # Sine for even dimensions
        encoding[:, 1::2] = torch.cos(position * div_term)  # Cosine for odd dimensions
        
        return encoding.unsqueeze(0)  # Shape: (1, problem_size, embed_dim) to allow broadcasting

    def forward(self, problems):
        # Ensure `problems` is on the same device as the embedding layer
        problems = problems.to(self.embed_layer.weight.device)
    
        batch_size, problem_size = problems.shape[:2]
    
        # Embed node coordinates
        embedded = self.embed_layer(problems.reshape(-1, 2))
        embedded = embedded.reshape(batch_size, problem_size, self.embedding_dim)
    
        # Align positional encoding to the device of `embedded`
        pos_encoding = self.positional_encoding[:, :problem_size, :].to(embedded.device)
        embedded = embedded + pos_encoding  # Broadcast positional encoding across the batch
    
        # Apply self-attention
        attention_output, _ = self.attention_layer(embedded, embedded, embedded)
        
        # Create position indices for concatenation if needed
        # ids = torch.arange(problem_size, device=self.device).expand(batch_size, problem_size)
        
        # Concatenate position indices with attention output (optional, depends on the architecture)
        # return torch.cat([ids.unsqueeze(-1).float(), attention_output], dim=-1)
        return attention_output


class Decoder(nn.Module):
    def __init__(self, model: nn.Module, **kwargs):
        super().__init__()
        self.model = model
        self.problem_size = kwargs.get('problem_size', 20)
        self.use_lora = kwargs.get('use_lora', True)
        self.device = kwargs.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        
        self._setup_model()
    
    
    def _setup_model(self):
        """Configure model architecture"""
        # Get base model if wrapped in DataParallel
        base_model = self.model
        
        # Modify output size
        base_model.lm_head = nn.Linear(
            base_model.config.hidden_size,
            self.problem_size
        ).to(self.device)
        
        # Apply LoRA if requested
        if self.use_lora:
            lora_config = LoraConfig(
                r=4,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.1,
                bias="none",
                task_type=TaskType.CAUSAL_LM
            )
            self.model = get_peft_model(base_model, lora_config)
    
    def forward(self, seq_so_far: torch.Tensor, inp_mask: torch.Tensor, ninf_mask: torch.Tensor) -> torch.Tensor:
        batch_size, pomo_size, problem_size, embedding_dim = seq_so_far.shape
        
        # Reshape inputs
        flat_seq = seq_so_far.reshape(batch_size * pomo_size, problem_size, embedding_dim)
        flat_mask = inp_mask.reshape(batch_size * pomo_size, problem_size)
        
        try:
            # Ensure inputs are on the correct device
            flat_seq = flat_seq.to(self.device)
            flat_mask = flat_mask.to(self.device)
            
            # Forward pass through model
            outputs = self.model(inputs_embeds=flat_seq, attention_mask=flat_mask)
            logits = outputs.logits.to(self.device)
            
            # Get last valid position
            last_positions = flat_mask.sum(dim=1).long() - 1
            
            # Gather logits for last positions
            batch_indices = torch.arange(batch_size * pomo_size, device=self.device)
            gathered_logits = logits[batch_indices, last_positions]
            
            # Reshape and apply mask
            logits = gathered_logits.reshape(batch_size, pomo_size, problem_size)
            ninf_mask = ninf_mask.to(self.device)
            masked_logits = logits + ninf_mask.float()
            
            # Return probabilities
            return torch.softmax(masked_logits, dim=2)
            
        except Exception as e:
            print(f"Error in decoder forward pass: {e}")
            print(f"Device info - Model device: {self.device}, Input: {flat_seq.device}, Mask: {flat_mask.device}")
            raise

def load_model(config: Dict[str, Any]) -> nn.Module:
    """Load model with proper configuration"""
    # print(config)
    device = config.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    
    if config.get('checkpoint_path'):
        # print('checkpoint_path')
        try:
            return PeftModel.from_pretrained(
                config['model_name'],
                config['checkpoint_path'],
                is_trainable=True
            ).to(device)
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Falling back to base model...")
    
    if config.get('use_4bit', True):
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_quant_type="nf4",
            llm_int8_threshold=6.0,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            config['model_name'],
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config
        )
        model = prepare_model_for_kbit_training(model)
        model.config.use_cache = False
    else:
        # print('else')
        model = AutoModelForCausalLM.from_pretrained(
            config['model_name'],
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        ).to(device)
    
    return model

Expected behavior
Expected behavior is that the model should train in a multi-GPU setting without throwing any errors. The same script works in a single-GPU setting but throws the above error in a multi-GPU setting.
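For reference, the standard nn.DataParallel training pattern I expect the Trainer above to follow is sketched below (a minimal, self-contained sketch with a toy model standing in for TSPTransformer; the shapes and hyperparameters are placeholders, not taken from the script above):

import torch
import torch.nn as nn

# Toy stand-in for TSPTransformer; placeholder shapes.
model = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 20))
device = torch.device('cuda', 0)
model = model.to(device)

if torch.cuda.device_count() > 1:
    # Replicate the model on cuda:0 and cuda:1, as in TSPTrainer.__init__
    model = nn.DataParallel(model, device_ids=[0, 1])

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Inputs live on the source device (cuda:0); DataParallel scatters them,
# runs the replicas, and gathers the outputs back on cuda:0.
x = torch.rand(64, 2, device=device)
target = torch.randint(0, 20, (64,), device=device)

logits = model(x)
loss = nn.functional.cross_entropy(logits, target)

model.zero_grad()
loss.backward()
optimizer.step()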
