@@ -25,7 +25,7 @@ using AutogradCtx = torch::autograd::AutogradContext;
 // returns {out, next_state} where:
 // out (B, H) = sigmoid(proj) * mingru_out
 // next_state (B, H) = mingru_out (for recurrence)
-vector<Tensor> mingru_gate(Tensor state, Tensor combined) {
+void mingru_gate(Tensor state, Tensor combined, Tensor out, Tensor next_state) {
   TORCH_CHECK(state.is_cuda(), "state must be on CUDA");
   TORCH_CHECK(combined.is_cuda(), "combined must be on CUDA");
   TORCH_CHECK(state.dtype() == combined.dtype(), "dtypes must match");
@@ -36,9 +36,6 @@ vector<Tensor> mingru_gate(Tensor state, Tensor combined) {
 
   int B = static_cast<int>(state.size(0));
   int H = static_cast<int>(state.size(1));
-
-  auto out = torch::empty_like(state);
-  auto next_state = torch::empty_like(state);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   mingru_gate_inference_kernel<<<grid_size(B * H), BLOCK_SIZE, 0, stream>>>(
@@ -47,7 +44,6 @@ vector<Tensor> mingru_gate(Tensor state, Tensor combined) {
     (const precision_t *)combined.data_ptr(),
     (const precision_t *)state.data_ptr(),
     H, B);
-  return {out, next_state};
 }
 
 // PrefixScan: checkpointed associative scan for MinGRU training
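
With the out-parameter signature above, allocation moves to the caller, which can reuse the same buffers across decode steps instead of paying two torch::empty_like calls per step. A minimal caller-side sketch, assuming a per-step input combined_steps and step count T (both illustrative, not from this commit):

    // Double-buffer the hidden state so the kernel never reads and writes
    // through the same tensor; swap roles each step instead of reallocating.
    auto out   = torch::empty_like(state);  // gated output, reused every step
    auto s_in  = state;                     // current hidden state
    auto s_out = torch::empty_like(state);  // receives next_state
    for (int t = 0; t < T; ++t) {
      mingru_gate(s_in, combined_steps[t], out, s_out);
      std::swap(s_in, s_out);  // s_in now holds the state for step t + 1
    }
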
@@ -452,16 +448,16 @@ void ppo_loss_fwd_bwd(
   TORCH_CHECK(act_sizes.is_cuda() && act_sizes.dtype() == torch::kInt32,
               "act_sizes must be int32 on CUDA");
 
-  // Make inputs contiguous for both kernels
-  logits = logits.contiguous();
-  values_pred = values_pred.contiguous();
-  old_logprobs = old_logprobs.contiguous();
-  advantages = advantages.contiguous();
-  prio = prio.contiguous();
-  values = values.contiguous();
-  returns = returns.contiguous();
+  // logits/values_pred may be non-contiguous (fused decoder output); the kernel handles them via strides
+  // Grad outputs use contiguous layout (nt * A_total indexing)
+  TORCH_CHECK(old_logprobs.is_contiguous(), "old_logprobs must be contiguous");
+  TORCH_CHECK(advantages.is_contiguous(), "advantages must be contiguous");
+  TORCH_CHECK(prio.is_contiguous(), "prio must be contiguous");
+  TORCH_CHECK(values.is_contiguous(), "values must be contiguous");
+  TORCH_CHECK(returns.is_contiguous(), "returns must be contiguous");
 
   bool is_continuous = logstd.defined() && logstd.numel() > 0;
+  // TODO: pre-allocate a contiguous logstd buffer to remove this alloc
   if (is_continuous) logstd = logstd.contiguous();
 
   int N = static_cast<int>(logits.size(0));
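
The new comment asserts that the kernel consumes logits/values_pred through their strides. A sketch of what stride-aware addressing looks like for a 2-D slice; the pointer and index names are hypothetical, since the kernel body is not part of this hunk:

    // Host side: pass element strides instead of forcing a contiguous copy.
    int64_t l_s0 = logits.stride(0);  // elements between consecutive rows
    int64_t l_s1 = logits.stride(1);  // elements between consecutive columns
    // Device side, per (row, col), a strided load replaces flat row * C + col:
    //   precision_t x = logits_ptr[row * l_s0 + col * l_s1];
    // For a contiguous (N, C) tensor, l_s0 == C and l_s1 == 1, so this reduces
    // to the usual flat index; a fused-decoder slice only changes the strides.
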
@@ -471,7 +467,8 @@ void ppo_loss_fwd_bwd(
   int total = N * T;
 
   auto [adv_var, adv_mean] = torch::var_mean(advantages);
-  auto actions_flat = actions.reshape({total, num_atns}).contiguous();
+  auto actions_flat = actions.reshape({total, num_atns});
+  TORCH_CHECK(actions_flat.is_contiguous(), "actions must be contiguous");
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
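
Dropping .contiguous() makes the flatten a pure view whenever actions already has the standard layout, and the new TORCH_CHECK turns the remaining layout assumption into a loud failure instead of a silent copy. A small sketch of the invariant, reusing N, T, and num_atns from the surrounding function:

    // reshape of a contiguous (N, T, num_atns) tensor to (N*T, num_atns) is a
    // view over the same storage, so the check passes and nothing is copied:
    auto a = torch::zeros({N, T, num_atns}, torch::kCUDA);
    auto flat = a.reshape({N * T, num_atns});
    TORCH_CHECK(flat.is_contiguous());  // holds: strides are {num_atns, 1}
    // If reshape can alias a non-contiguous input (e.g. the requested shape
    // equals the input's shape), the check fires and surfaces the layout bug.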