|
| 1 | +//! TCP CUBIC congestion control algorithm. |
| 2 | +
|
| 3 | +#![allow(dead_code)] |
| 4 | +#![allow(non_snake_case)] |
| 5 | +#![allow(unused_variables)] |
| 6 | + |
| 7 | +use core::cmp::max; |
| 8 | +use core::num::NonZeroU32; |
| 9 | +use hystart::HystartDetect; |
| 10 | +use kernel::net::tcp; |
| 11 | +use kernel::net::tcp::cong::{self, hystart}; |
| 12 | +use kernel::prelude::*; |
| 13 | +use kernel::time; |
| 14 | +use kernel::{c_str, module_cca}; |
| 15 | + |
| 16 | +const BICTCP_BETA_SCALE: u32 = 1024; |
| 17 | + |
| 18 | +// TODO: Convert to module parameters once they are available. Currently these |
| 19 | +// are the defaults from the C implementation. |
| 20 | +// TODO: Use NonZeroU32 where appropriate. |
| 21 | +/// Whether to use fast-convergence. |
| 22 | +const FAST_CONVERGENCE: bool = true; |
| 23 | +/// The factor for multiplicative decrease of cwnd upon a loss event. Will be |
| 24 | +/// divided by `BICTCP_BETA_SCALE`, approximately 0.7. |
| 25 | +const BETA: u32 = 717; |
| 26 | +/// The initial value of ssthresh for new connections. Setting this to `None` |
| 27 | +/// implies `i32::MAX`. |
| 28 | +const INITIAL_SSTHRESH: Option<u32> = None; |
| 29 | +/// TODO |
| 30 | +const BIC_SCALE: u32 = 41; |
| 31 | +/// TODO |
| 32 | +const TCP_FRIENDLINESS: bool = true; |
| 33 | +/// Whether to use the HyStart slow start algorithm. |
| 34 | +const HYSTART: bool = true; |
| 35 | + |
| 36 | +impl hystart::HyStart for Cubic { |
| 37 | + /// Which mechanism to use for deciding when it is time to exit slow start. |
| 38 | + const DETECT: HystartDetect = HystartDetect::Both; |
| 39 | + /// Lower bound for cwnd during hybrid slow start. |
| 40 | + const LOW_WINDOW: u32 = 16; |
| 41 | + /// Spacing between ACKs indicating an ACK-train. |
| 42 | + /// (Dimension: time. Unit: microseconds). |
| 43 | + const ACK_DELTA: time::Usecs32 = 2000; |
| 44 | +} |
| 45 | + |
| 46 | +// TODO: Those are computed based on the module parameters in the init. Even |
| 47 | +// with module parameters available this will be a bit tricky to do in Rust. |
| 48 | +/// Factor of `8/3 * (1 + beta) / (1 - beta)` that is used in various |
| 49 | +/// calculations. (Dimension: none) |
| 50 | +const BETA_SCALE: u32 = ((8 * (BICTCP_BETA_SCALE + BETA)) / 3) / (BICTCP_BETA_SCALE - BETA); |
| 51 | +/// TODO |
| 52 | +const CUBE_RTT_SCALE: u32 = BIC_SCALE * 10; |
| 53 | +/// TODO |
| 54 | +const CUBE_FACTOR: u64 = (1u64 << 40) / (CUBE_RTT_SCALE as u64); |
| 55 | + |
| 56 | +module_cca! { |
| 57 | + type: Cubic, |
| 58 | + name: "tcp_cubic_rust", |
| 59 | + author: "Rust for Linux Contributors", |
| 60 | + description: "TCP CUBIC congestion control algorithm, Rust implementation", |
| 61 | + license: "GPL v2", |
| 62 | +} |
| 63 | + |
| 64 | +struct Cubic {} |
| 65 | + |
| 66 | +#[vtable] |
| 67 | +impl cong::Algorithm for Cubic { |
| 68 | + type Data = CubicState; |
| 69 | + |
| 70 | + const NAME: &'static CStr = c_str!("bic_rust"); |
| 71 | + |
| 72 | + fn init(sk: &mut cong::Sock<'_, Self>) { |
| 73 | + if HYSTART { |
| 74 | + <Self as hystart::HyStart>::reset(sk) |
| 75 | + } else if let Some(ssthresh) = INITIAL_SSTHRESH { |
| 76 | + sk.tcp_sk_mut().set_snd_ssthresh(ssthresh); |
| 77 | + } |
| 78 | + |
| 79 | + // TODO: remove |
| 80 | + pr_info!("Socket created: start {}", sk.inet_csk_ca().start_time); |
| 81 | + } |
| 82 | + |
| 83 | + // TODO: remove |
| 84 | + fn release(sk: &mut cong::Sock<'_, Self>) { |
| 85 | + pr_info!( |
| 86 | + "Socket destroyed: start {}, end {}", |
| 87 | + sk.inet_csk_ca().start_time, |
| 88 | + (time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, |
| 89 | + ); |
| 90 | + } |
| 91 | + |
| 92 | + fn cwnd_event(sk: &mut cong::Sock<'_, Self>, ev: cong::Event) { |
| 93 | + if matches!(ev, cong::Event::TxStart) { |
| 94 | + // Here we cannot avoid jiffies as the `lsndtime` field is measured |
| 95 | + // in jiffies. |
| 96 | + let now = time::jiffies32(); |
| 97 | + let delta: time::Jiffies32 = now.wrapping_sub(sk.tcp_sk().lsndtime()); |
| 98 | + |
| 99 | + if (delta as i32) <= 0 { |
| 100 | + return; |
| 101 | + } |
| 102 | + |
| 103 | + let ca = sk.inet_csk_ca_mut(); |
| 104 | + // Ok, lets switch to SI time units. |
| 105 | + let now = time::jiffies_to_msecs(now as time::Jiffies); |
| 106 | + let delta = time::jiffies_to_msecs(delta as time::Jiffies); |
| 107 | + if ca.epoch_start != 0 { |
| 108 | + ca.epoch_start += delta; |
| 109 | + if tcp::after(ca.epoch_start, now) { |
| 110 | + ca.epoch_start = now; |
| 111 | + } |
| 112 | + }; |
| 113 | + } |
| 114 | + } |
| 115 | + |
| 116 | + fn set_state(sk: &mut cong::Sock<'_, Self>, new_state: cong::State) { |
| 117 | + if matches!(new_state, cong::State::Loss) { |
| 118 | + pr_info!( |
| 119 | + // TODO: remove |
| 120 | + "Retransmission timeout fired: time {}, start {}", |
| 121 | + (time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, |
| 122 | + sk.inet_csk_ca().start_time |
| 123 | + ); |
| 124 | + sk.inet_csk_ca_mut().reset(); |
| 125 | + <Self as hystart::HyStart>::reset(sk); |
| 126 | + } |
| 127 | + } |
| 128 | + |
| 129 | + fn pkts_acked(sk: &mut cong::Sock<'_, Self>, sample: &cong::AckSample) { |
| 130 | + todo!() |
| 131 | + } |
| 132 | + |
| 133 | + fn ssthresh(sk: &mut cong::Sock<'_, Self>) -> u32 { |
| 134 | + let cwnd = sk.tcp_sk().snd_cwnd(); |
| 135 | + let ca = sk.inet_csk_ca_mut(); |
| 136 | + |
| 137 | + pr_info!( |
| 138 | + // TODO: remove |
| 139 | + "Enter fast retransmit: time {}, start {}", |
| 140 | + (time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, |
| 141 | + ca.start_time |
| 142 | + ); |
| 143 | + |
| 144 | + // Epoch has ended. |
| 145 | + ca.epoch_start = 0; |
| 146 | + ca.last_max_cwnd = if cwnd < ca.last_max_cwnd && FAST_CONVERGENCE { |
| 147 | + (cwnd * (BETA_SCALE + BETA)) / (2 * BETA_SCALE) |
| 148 | + } else { |
| 149 | + cwnd |
| 150 | + }; |
| 151 | + |
| 152 | + max((cwnd * BETA) / BETA_SCALE, 2) |
| 153 | + } |
| 154 | + |
| 155 | + fn undo_cwnd(sk: &mut cong::Sock<'_, Self>) -> u32 { |
| 156 | + pr_info!( |
| 157 | + // TODO: remove |
| 158 | + "Undo cwnd reduction: time {}, start {}", |
| 159 | + (time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, |
| 160 | + sk.inet_csk_ca().start_time |
| 161 | + ); |
| 162 | + |
| 163 | + cong::reno::undo_cwnd(sk) |
| 164 | + } |
| 165 | + |
| 166 | + fn cong_avoid(sk: &mut cong::Sock<'_, Self>, _ack: u32, mut acked: u32) { |
| 167 | + if !sk.tcp_is_cwnd_limited() { |
| 168 | + return; |
| 169 | + } |
| 170 | + |
| 171 | + let tp = sk.tcp_sk_mut(); |
| 172 | + |
| 173 | + if tp.in_slow_start() { |
| 174 | + acked = tp.slow_start(acked); |
| 175 | + if acked == 0 { |
| 176 | + pr_info!( |
| 177 | + // TODO: remove |
| 178 | + "New cwnd {}, time {}, ssthresh {}, start {}, ss 1", |
| 179 | + sk.tcp_sk().snd_cwnd(), |
| 180 | + (time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, |
| 181 | + sk.tcp_sk().snd_ssthresh(), |
| 182 | + sk.inet_csk_ca().start_time |
| 183 | + ); |
| 184 | + return; |
| 185 | + } |
| 186 | + } |
| 187 | + |
| 188 | + let cwnd = tp.snd_cwnd(); |
| 189 | + let cnt = sk.inet_csk_ca_mut().update(cwnd, acked); |
| 190 | + sk.tcp_sk_mut().cong_avoid_ai(cnt, acked); |
| 191 | + |
| 192 | + pr_info!( |
| 193 | + // TODO: remove |
| 194 | + "New cwnd {}, time {}, ssthresh {}, start {}, ss 0", |
| 195 | + sk.tcp_sk().snd_cwnd(), |
| 196 | + (time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, |
| 197 | + sk.tcp_sk().snd_ssthresh(), |
| 198 | + sk.inet_csk_ca().start_time |
| 199 | + ); |
| 200 | + } |
| 201 | +} |
| 202 | + |
| 203 | +#[allow(non_snake_case)] |
| 204 | +struct CubicState { |
| 205 | + /// Increase the cwnd by one step after `cnt` ACKs. |
| 206 | + cnt: NonZeroU32, |
| 207 | + /// W__last_max |
| 208 | + last_max_cwnd: u32, |
| 209 | + last_cwnd: u32, |
| 210 | + /// Time when `last_cwnd` was updated. |
| 211 | + last_time: time::Msecs32, |
| 212 | + origin_point: u32, |
| 213 | + K: time::Msecs32, |
| 214 | + /// Time when the current epoch has started. |
| 215 | + epoch_start: time::Msecs32, |
| 216 | + ack_cnt: u32, |
| 217 | + /// Estimate for the cwnd of TCP Reno. |
| 218 | + tcp_cwnd: u32, |
| 219 | + /// State of the HyStart slow start algorithm. |
| 220 | + hystart_state: hystart::HyStartState, |
| 221 | + /// Time when the connection was created. |
| 222 | + // TODO: remove |
| 223 | + start_time: time::Usecs32, |
| 224 | +} |
| 225 | + |
| 226 | +impl hystart::HasHyStartState for CubicState { |
| 227 | + fn hy(&self) -> &hystart::HyStartState { |
| 228 | + &self.hystart_state |
| 229 | + } |
| 230 | + |
| 231 | + fn hy_mut(&mut self) -> &mut hystart::HyStartState { |
| 232 | + &mut self.hystart_state |
| 233 | + } |
| 234 | +} |
| 235 | + |
| 236 | +impl Default for CubicState { |
| 237 | + fn default() -> Self { |
| 238 | + Self { |
| 239 | + // NOTE: Initializing this to 1 deviates from the C code. It does |
| 240 | + // not change the behavior. |
| 241 | + cnt: NonZeroU32::MIN, |
| 242 | + last_max_cwnd: 0, |
| 243 | + last_cwnd: 0, |
| 244 | + last_time: 0, |
| 245 | + origin_point: 0, |
| 246 | + K: 0, |
| 247 | + epoch_start: 0, |
| 248 | + ack_cnt: 0, |
| 249 | + tcp_cwnd: 0, |
| 250 | + hystart_state: hystart::HyStartState::default(), |
| 251 | + // TODO: remove |
| 252 | + start_time: (time::ktime_get_boot_fast_ns() / time::NSEC_PER_USEC) as time::Usecs32, |
| 253 | + } |
| 254 | + } |
| 255 | +} |
| 256 | + |
| 257 | +impl CubicState { |
| 258 | + fn update(&mut self, cwnd: u32, acked: u32) -> NonZeroU32 { |
| 259 | + todo!() |
| 260 | + } |
| 261 | + |
| 262 | + fn reset(&mut self) { |
| 263 | + // TODO: remove |
| 264 | + let tmp = self.start_time; |
| 265 | + |
| 266 | + *self = Self::default(); |
| 267 | + |
| 268 | + // TODO: remove |
| 269 | + self.start_time = tmp; |
| 270 | + } |
| 271 | +} |
0 commit comments