Skip to content

Commit 4ce0bfa

Browse files
committed
pass static activ buffers
1 parent 16851fd commit 4ce0bfa

6 files changed

Lines changed: 225 additions & 128 deletions

File tree

pufferlib/extensions/bindings.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ TORCH_LIBRARY_IMPL(pufferlib, CUDA, m) {
226226
}
227227

228228
TORCH_LIBRARY(_C, m) {
229-
m.def("mingru_gate(Tensor state, Tensor combined) -> (Tensor, Tensor)");
229+
m.def("mingru_gate(Tensor state, Tensor combined, Tensor out, Tensor next_state) -> ()");
230230
m.def("fc_max(Tensor x, Tensor W, Tensor b) -> Tensor");
231231
}
232232

@@ -351,8 +351,8 @@ PYBIND11_MODULE(_C, m) {
351351
int num_layers, int num_atns, bool continuous) {
352352
return new Policy(alloc, input, hidden, output, num_layers, num_atns, continuous);
353353
}))
354-
.def("forward", &Policy::forward)
355-
.def("forward_train", &Policy::forward_train)
354+
.def("forward", static_cast<std::tuple<Logits, Tensor, Tensor> (Policy::*)(Tensor, Tensor)>(&Policy::forward))
355+
.def("forward_train", static_cast<std::tuple<Logits, Tensor> (Policy::*)(Tensor, Tensor)>(&Policy::forward_train))
356356
.def("init_weights", &Policy::init_weights)
357357
.def("parameters", &Policy::parameters)
358358
.def("named_parameters", [](Policy& self) {

pufferlib/extensions/cuda/kernels.cu

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -638,8 +638,12 @@ __global__ void ppo_loss_backward_kernel_optimized(
638638
int t = idx % T_seq;
639639
int nt = n * T_seq + t;
640640

641+
// Input strides (for reading non-contiguous logits/values_pred)
641642
int logits_base = n * logits_stride_n + t * logits_stride_t;
642643
int values_idx = n * values_stride_n + t * values_stride_t;
644+
// Output indices (for writing to contiguous grad buffers)
645+
int grad_logits_base = nt * A_total;
646+
int grad_values_idx = nt;
643647

644648
float old_logp = to_float(old_logprobs[nt]);
645649
float adv = float(advantages[nt]);
@@ -672,7 +676,7 @@ __global__ void ppo_loss_backward_kernel_optimized(
672676
} else {
673677
d_val_pred = val_pred - ret;
674678
}
675-
grad_values_pred[values_idx] = dL * vf_coef * d_val_pred;
679+
grad_values_pred[grad_values_idx] = dL * vf_coef * d_val_pred;
676680

677681
if (is_continuous) {
678682
// Continuous: compute total log prob first for ratio
@@ -724,14 +728,14 @@ __global__ void ppo_loss_backward_kernel_optimized(
724728

725729
// Gradient wrt mean: d_log_prob/d_mean = (action - mean) / var
726730
float d_mean = d_new_logp * diff / var;
727-
grad_logits[logits_base + h * logits_stride_a] = d_mean;
731+
grad_logits[grad_logits_base + h] = d_mean;
728732

729733
// Gradient wrt log_std:
730734
// d_log_prob/d_log_std = (action - mean)^2 / var - 1
731735
// d_entropy/d_log_std = 1
732736
// Total: d_new_logp * ((diff^2/var) - 1) + d_entropy_term * 1
733737
float d_log_std = d_new_logp * (diff * diff / var - 1.0f) + d_entropy_term;
734-
grad_logstd[logits_base + h * logits_stride_a] = d_log_std;
738+
grad_logstd[grad_logits_base + h] = d_log_std;
735739
}
736740
} else {
737741
// Discrete: original implementation
@@ -822,7 +826,7 @@ __global__ void ppo_loss_backward_kernel_optimized(
822826
// Each head's entropy contributes independently to total entropy
823827
d_logit += d_entropy_term * p * (-ent - logp);
824828

825-
grad_logits[logits_base + (logits_offset + a) * logits_stride_a] = d_logit;
829+
grad_logits[grad_logits_base + logits_offset + a] = d_logit;
826830
}
827831

828832
logits_offset += A;

pufferlib/extensions/cuda/modules.cu

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ using AutogradCtx = torch::autograd::AutogradContext;
2525
// returns {out, next_state} where:
2626
// out (B, H) = sigmoid(proj) * mingru_out
2727
// next_state (B, H) = mingru_out (for recurrence)
28-
vector<Tensor> mingru_gate(Tensor state, Tensor combined) {
28+
void mingru_gate(Tensor state, Tensor combined, Tensor out, Tensor next_state) {
2929
TORCH_CHECK(state.is_cuda(), "state must be on CUDA");
3030
TORCH_CHECK(combined.is_cuda(), "combined must be on CUDA");
3131
TORCH_CHECK(state.dtype() == combined.dtype(), "dtypes must match");
@@ -36,9 +36,6 @@ vector<Tensor> mingru_gate(Tensor state, Tensor combined) {
3636

3737
int B = static_cast<int>(state.size(0));
3838
int H = static_cast<int>(state.size(1));
39-
40-
auto out = torch::empty_like(state);
41-
auto next_state = torch::empty_like(state);
4239
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
4340

4441
mingru_gate_inference_kernel<<<grid_size(B * H), BLOCK_SIZE, 0, stream>>>(
@@ -47,7 +44,6 @@ vector<Tensor> mingru_gate(Tensor state, Tensor combined) {
4744
(const precision_t*)combined.data_ptr(),
4845
(const precision_t*)state.data_ptr(),
4946
H, B);
50-
return {out, next_state};
5147
}
5248

5349
// PrefixScan: checkpointed associative scan for MinGRU training
@@ -452,16 +448,16 @@ void ppo_loss_fwd_bwd(
452448
TORCH_CHECK(act_sizes.is_cuda() && act_sizes.dtype() == torch::kInt32,
453449
"act_sizes must be int32 on CUDA");
454450

455-
// Make inputs contiguous for both kernels
456-
logits = logits.contiguous();
457-
values_pred = values_pred.contiguous();
458-
old_logprobs = old_logprobs.contiguous();
459-
advantages = advantages.contiguous();
460-
prio = prio.contiguous();
461-
values = values.contiguous();
462-
returns = returns.contiguous();
451+
// logits/values_pred may be non-contiguous (fused decoder output) — kernel handles via strides
452+
// Grad outputs use contiguous layout (nt * A_total indexing)
453+
TORCH_CHECK(old_logprobs.is_contiguous(), "old_logprobs must be contiguous");
454+
TORCH_CHECK(advantages.is_contiguous(), "advantages must be contiguous");
455+
TORCH_CHECK(prio.is_contiguous(), "prio must be contiguous");
456+
TORCH_CHECK(values.is_contiguous(), "values must be contiguous");
457+
TORCH_CHECK(returns.is_contiguous(), "returns must be contiguous");
463458

464459
bool is_continuous = logstd.defined() && logstd.numel() > 0;
460+
// TODO: pre-allocate contiguous logstd buffer to remove this alloc
465461
if (is_continuous) logstd = logstd.contiguous();
466462

467463
int N = static_cast<int>(logits.size(0));
@@ -471,7 +467,8 @@ void ppo_loss_fwd_bwd(
471467
int total = N * T;
472468

473469
auto [adv_var, adv_mean] = torch::var_mean(advantages);
474-
auto actions_flat = actions.reshape({total, num_atns}).contiguous();
470+
auto actions_flat = actions.reshape({total, num_atns});
471+
TORCH_CHECK(actions_flat.is_contiguous(), "actions must be contiguous");
475472

476473
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
477474

0 commit comments

Comments (0)