-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Expand file tree
/
Copy pathspt_spiking_tracker.rs
More file actions
457 lines (387 loc) · 15.5 KB
/
spt_spiking_tracker.rs
File metadata and controls
457 lines (387 loc) · 15.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
//! Spiking neural network tracker — spatial reasoning module (ADR-041).
//!
//! Bio-inspired person tracking using Leaky Integrate-and-Fire (LIF) neurons
//! with STDP learning. 32 input neurons (one per subcarrier) feed into
//! 4 output neurons (one per spatial zone). The zone with the highest
//! spike rate indicates person location; zone transitions track velocity.
//!
//! Event IDs: 770-773 (Spatial Reasoning series).
use libm::fabsf;
// ── Constants ────────────────────────────────────────────────────────────────
/// Number of input neurons (one per subcarrier).
const N_INPUT: usize = 32;
/// Number of output neurons (one per spatial zone).
const N_OUTPUT: usize = 4;
/// Input neurons per output zone; input `i` has home zone `i / INPUTS_PER_ZONE`.
const INPUTS_PER_ZONE: usize = N_INPUT / N_OUTPUT; // = 8
// Integer division above truncates: if N_INPUT were not a multiple of N_OUTPUT,
// trailing inputs would map past the last zone and get no home-zone weight.
// Reject such a configuration at compile time instead.
const _: () = assert!(
    N_INPUT % N_OUTPUT == 0,
    "N_INPUT must be a multiple of N_OUTPUT"
);
/// Firing threshold shared by the input and output neurons.
const THRESHOLD: f32 = 1.0;
/// Multiplicative membrane leak applied to each input neuron every frame.
const LEAK: f32 = 0.95;
/// Membrane potential an input neuron is reset to after it spikes.
const RESET: f32 = 0.0;
/// STDP learning rate (potentiation, pre and post both fired).
const STDP_LR_PLUS: f32 = 0.01;
/// STDP learning rate (depression, pre fired but post stayed silent).
const STDP_LR_MINUS: f32 = 0.005;
/// STDP time window in frames (approximation of 20ms at 50Hz).
const STDP_WINDOW: u32 = 1;
/// EMA factor for per-zone spike rate smoothing.
const RATE_ALPHA: f32 = 0.1;
/// EMA factor for velocity smoothing.
const VEL_ALPHA: f32 = 0.2;
/// A zone's smoothed spike rate must exceed this to count as active.
const MIN_SPIKE_RATE: f32 = 0.05;
/// Lower clamp for synaptic weights.
const W_MIN: f32 = 0.0;
/// Upper clamp for synaptic weights.
const W_MAX: f32 = 2.0;
// ── Event IDs ────────────────────────────────────────────────────────────────
/// Track update; value is the active zone ID (0-3). Only emitted while a zone
/// is active — a lost track is reported via [`EVENT_TRACK_LOST`] instead.
pub const EVENT_TRACK_UPDATE: i32 = 770;
/// Estimated velocity: EMA-smoothed reciprocal of the number of frames between
/// zone transitions, i.e. zone transitions per *frame* (not per second).
pub const EVENT_TRACK_VELOCITY: i32 = 771;
/// Mean of the per-zone (output neuron) EMA spike rates, in [0, 1].
/// Emitted on every frame, whether or not a track is active.
pub const EVENT_SPIKE_RATE: i32 = 772;
/// Emitted on the frame the track is lost (no zone active after one was);
/// value is the last tracked zone ID.
pub const EVENT_TRACK_LOST: i32 = 773;
// ── State ────────────────────────────────────────────────────────────────────
/// Capacity of each tracker's event buffer (at most three events are emitted
/// per frame today).
const MAX_EVENTS: usize = 4;

/// Spiking neural network person tracker.
///
/// Leaky integrate-and-fire input neurons (one per subcarrier) feed output
/// neurons (one per zone) through STDP-trained weights. Each instance owns its
/// own event buffer, so independent trackers never overwrite each other's output.
pub struct SpikingTracker {
    /// Membrane potential of each input neuron.
    membrane: [f32; N_INPUT],
    /// Synaptic weights from input to output neurons.
    /// `weights[i][z]` = connection strength from input `i` to output zone `z`.
    weights: [[f32; N_OUTPUT]; N_INPUT],
    /// Frame number of each input neuron's last spike (0 = never fired).
    input_spike_time: [u32; N_INPUT],
    /// Frame number of each output neuron's last spike (0 = never fired).
    output_spike_time: [u32; N_OUTPUT],
    /// EMA-smoothed spike rate per zone.
    zone_rate: [f32; N_OUTPUT],
    /// Raw spike count per zone this frame (0 or 1).
    zone_spikes: [u32; N_OUTPUT],
    /// Last active zone, or -1 if no zone has been active yet.
    prev_zone: i8,
    /// Velocity EMA (zone transitions per frame).
    velocity_ema: f32,
    /// Whether the track is currently active.
    track_active: bool,
    /// Frame counter (wraps on overflow).
    frame_count: u32,
    /// Frames since last zone transition (saturates on overflow).
    frames_since_transition: u32,
    /// Storage for the events returned by [`SpikingTracker::process_frame`].
    events: [(i32, f32); MAX_EVENTS],
}

impl Default for SpikingTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl SpikingTracker {
    /// Create a tracker with zeroed state and home-zone-biased weights.
    ///
    /// Input `i` belongs to home zone `i / INPUTS_PER_ZONE`; its weight to that
    /// zone starts at 1.0 and to every other zone at 0.25. This is a `const fn`,
    /// so it can initialise a `static`.
    pub const fn new() -> Self {
        let mut weights = [[0.25f32; N_OUTPUT]; N_INPUT];
        let mut i = 0;
        while i < N_INPUT {
            let home_zone = i / INPUTS_PER_ZONE;
            if home_zone < N_OUTPUT {
                weights[i][home_zone] = 1.0;
            }
            i += 1;
        }
        Self {
            membrane: [0.0; N_INPUT],
            weights,
            input_spike_time: [0; N_INPUT],
            output_spike_time: [0; N_OUTPUT],
            zone_rate: [0.0; N_OUTPUT],
            zone_spikes: [0; N_OUTPUT],
            prev_zone: -1,
            velocity_ema: 0.0,
            track_active: false,
            frame_count: 0,
            frames_since_transition: 0,
            events: [(0, 0.0); MAX_EVENTS],
        }
    }

    /// Process one CSI frame and return the events it produced.
    ///
    /// `phases` — per-subcarrier phase values (up to 32 are used).
    /// `prev_phases` — previous frame phases; `|phases[i] - prev_phases[i]|` is
    /// injected as input current. Only the first
    /// `min(phases.len(), prev_phases.len(), 32)` subcarriers are processed.
    ///
    /// Returns (event_id, value) pairs. The slice borrows this tracker's own
    /// buffer and is overwritten by the next call.
    pub fn process_frame(&mut self, phases: &[f32], prev_phases: &[f32]) -> &[(i32, f32)] {
        let n_sc = phases.len().min(prev_phases.len()).min(N_INPUT);
        // Long-running devices must not panic (debug builds) on counter overflow.
        self.frame_count = self.frame_count.wrapping_add(1);
        self.frames_since_transition = self.frames_since_transition.saturating_add(1);

        let input_spikes = self.integrate_inputs(&phases[..n_sc], &prev_phases[..n_sc]);
        let output_spikes = self.fire_winner(&input_spikes);
        self.apply_stdp(&input_spikes, &output_spikes);
        self.update_zone_rates();
        let best_zone = self.select_zone();
        // Must run before `prev_zone` is updated below.
        self.update_velocity(best_zone);

        let was_active = self.track_active;
        self.track_active = best_zone >= 0;
        if best_zone >= 0 {
            self.prev_zone = best_zone;
        }
        self.build_events(best_zone, was_active)
    }

    /// Step 1: leaky-integrate `|Δphase|` into each input neuron and fire those
    /// reaching `THRESHOLD`. Inputs beyond the slices' length never spike.
    fn integrate_inputs(&mut self, phases: &[f32], prev_phases: &[f32]) -> [bool; N_INPUT] {
        let mut spikes = [false; N_INPUT];
        for (i, (&cur, &prev)) in phases.iter().zip(prev_phases).enumerate() {
            let delta = cur - prev;
            // A NaN would stick in the membrane forever (NaN * LEAK stays NaN and
            // never reaches THRESHOLD), silencing the neuron permanently; treat
            // non-finite deltas as "no current".
            let current = if delta.is_finite() { fabsf(delta) } else { 0.0 };
            let potential = self.membrane[i] * LEAK + current;
            if potential >= THRESHOLD {
                spikes[i] = true;
                self.membrane[i] = RESET;
                self.input_spike_time[i] = self.frame_count;
            } else {
                self.membrane[i] = potential;
            }
        }
        spikes
    }

    /// Step 2: propagate input spikes through the weights; at most one output
    /// neuron fires (lateral inhibition / winner-take-all).
    ///
    /// Without inhibition, one zone's 8 inputs give every *other* zone
    /// 8 × 0.25 = 2.0 ≥ `THRESHOLD`, so all outputs fired together, zone rates
    /// stayed identical and location could never be resolved. Now only the zone
    /// with the highest potential (lowest index on ties) fires, and only if it
    /// reaches `THRESHOLD`.
    fn fire_winner(&mut self, input_spikes: &[bool; N_INPUT]) -> [bool; N_OUTPUT] {
        let mut potential = [0.0f32; N_OUTPUT];
        for (i, &fired) in input_spikes.iter().enumerate() {
            if fired {
                for (p, &w) in potential.iter_mut().zip(self.weights[i].iter()) {
                    *p += w;
                }
            }
        }

        let mut winner = 0usize;
        for z in 1..N_OUTPUT {
            if potential[z] > potential[winner] {
                winner = z;
            }
        }

        let mut spikes = [false; N_OUTPUT];
        self.zone_spikes = [0; N_OUTPUT];
        if potential[winner] >= THRESHOLD {
            spikes[winner] = true;
            self.zone_spikes[winner] = 1;
            self.output_spike_time[winner] = self.frame_count;
        }
        spikes
    }

    /// Step 3: STDP on synapses whose presynaptic input fired this frame.
    ///
    /// Post fired within `STDP_WINDOW` → potentiate (capped at `W_MAX`);
    /// post silent → depress (floored at `W_MIN`). Silent inputs are skipped.
    /// Both spike times compared here were recorded this frame, so the window
    /// check currently always passes.
    fn apply_stdp(&mut self, input_spikes: &[bool; N_INPUT], output_spikes: &[bool; N_OUTPUT]) {
        for (i, &fired) in input_spikes.iter().enumerate() {
            if !fired {
                continue;
            }
            let pre_t = self.input_spike_time[i];
            for z in 0..N_OUTPUT {
                let w = self.weights[i][z];
                self.weights[i][z] = if output_spikes[z] {
                    if pre_t.abs_diff(self.output_spike_time[z]) <= STDP_WINDOW {
                        (w + STDP_LR_PLUS).min(W_MAX)
                    } else {
                        w
                    }
                } else {
                    (w - STDP_LR_MINUS).max(W_MIN)
                };
            }
        }
    }

    /// Step 4: fold this frame's 0/1 zone spikes into the per-zone rate EMA.
    fn update_zone_rates(&mut self) {
        for (rate, &spikes) in self.zone_rate.iter_mut().zip(self.zone_spikes.iter()) {
            *rate = RATE_ALPHA * spikes as f32 + (1.0 - RATE_ALPHA) * *rate;
        }
    }

    /// Step 5: zone with the highest smoothed rate above `MIN_SPIKE_RATE`
    /// (lowest index on ties), or -1 if none qualifies.
    fn select_zone(&self) -> i8 {
        let mut best_zone: i8 = -1;
        let mut best_rate = MIN_SPIKE_RATE;
        for (z, &rate) in self.zone_rate.iter().enumerate() {
            if rate > best_rate {
                best_rate = rate;
                best_zone = z as i8;
            }
        }
        best_zone
    }

    /// Step 6: on a zone change between two active zones, fold
    /// `1 / frames_since_transition` into the velocity EMA and restart the count.
    fn update_velocity(&mut self, best_zone: i8) {
        if best_zone >= 0 && self.prev_zone >= 0 && best_zone != self.prev_zone {
            let transition_speed = if self.frames_since_transition > 0 {
                1.0 / (self.frames_since_transition as f32)
            } else {
                0.0
            };
            self.velocity_ema =
                VEL_ALPHA * transition_speed + (1.0 - VEL_ALPHA) * self.velocity_ema;
            self.frames_since_transition = 0;
        }
    }

    /// Step 7: fill this tracker's event buffer and return the used prefix.
    ///
    /// Active zone → TRACK_UPDATE, TRACK_VELOCITY, SPIKE_RATE.
    /// No zone → SPIKE_RATE, plus TRACK_LOST (value = last zone) if a track
    /// was active on the previous frame.
    fn build_events(&mut self, zone: i8, was_active: bool) -> &[(i32, f32)] {
        let mean_rate =
            self.zone_rate.iter().fold(0.0f32, |acc, &r| acc + r) / N_OUTPUT as f32;

        let n = if zone >= 0 {
            self.events[0] = (EVENT_TRACK_UPDATE, zone as f32);
            self.events[1] = (EVENT_TRACK_VELOCITY, self.velocity_ema);
            self.events[2] = (EVENT_SPIKE_RATE, mean_rate);
            3
        } else {
            self.events[0] = (EVENT_SPIKE_RATE, mean_rate);
            if was_active {
                // `prev_zone` is only updated while active, so it holds the last zone.
                self.events[1] = (EVENT_TRACK_LOST, self.prev_zone as f32);
                2
            } else {
                1
            }
        };
        &self.events[..n]
    }

    /// Get the current tracked zone (-1 if lost).
    pub fn current_zone(&self) -> i8 {
        if self.track_active { self.prev_zone } else { -1 }
    }

    /// Get the smoothed spike rate for a zone (0.0 for out-of-range zones).
    pub fn zone_spike_rate(&self, zone: usize) -> f32 {
        if zone < N_OUTPUT { self.zone_rate[zone] } else { 0.0 }
    }

    /// Get the EMA-smoothed velocity (zone transitions per frame).
    pub fn velocity(&self) -> f32 {
        self.velocity_ema
    }

    /// Check if a track is currently active.
    pub fn is_tracking(&self) -> bool {
        self.track_active
    }
}
// ── Tests ────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    /// Phases with a strong (2.0) change on every subcarrier of `zone`.
    fn zone_phases(zone: usize) -> [f32; 32] {
        let mut p = [0.0f32; 32];
        for v in p.iter_mut().skip(zone * INPUTS_PER_ZONE).take(INPUTS_PER_ZONE) {
            *v = 2.0;
        }
        p
    }

    #[test]
    fn test_const_constructor() {
        // Must be usable in const context.
        const TRACKER: SpikingTracker = SpikingTracker::new();
        let st = TRACKER;
        assert_eq!(st.frame_count, 0);
        assert!(!st.track_active);
        assert_eq!(st.prev_zone, -1);
        assert_eq!(st.current_zone(), -1);
    }

    #[test]
    fn test_initial_weights() {
        let st = SpikingTracker::new();
        for i in 0..N_INPUT {
            let home = i / INPUTS_PER_ZONE;
            for z in 0..N_OUTPUT {
                let expected = if z == home { 1.0 } else { 0.25 };
                assert!((st.weights[i][z] - expected).abs() < 1e-6, "w[{}][{}]", i, z);
            }
        }
    }

    #[test]
    fn test_no_activity_no_track() {
        let mut st = SpikingTracker::new();
        let phases = [0.0f32; 32];
        let prev = [0.0f32; 32];
        let events = st.process_frame(&phases, &prev);
        // No phase change -> no spikes -> only the (zero) spike-rate report.
        assert_eq!(events, &[(EVENT_SPIKE_RATE, 0.0)]);
        assert!(!st.is_tracking());
    }

    #[test]
    fn test_zone_activation() {
        let mut st = SpikingTracker::new();
        let prev = [0.0f32; 32];
        let phases = zone_phases(0);
        for _ in 0..100 {
            st.process_frame(&phases, &prev);
        }
        let r0 = st.zone_spike_rate(0);
        assert!(r0 > MIN_SPIKE_RATE, "zone 0 should be active, rate={}", r0);
        assert!(st.is_tracking());
    }

    #[test]
    fn test_zone_transition_velocity() {
        let mut st = SpikingTracker::new();
        let prev = [0.0f32; 32];
        let phases_z0 = zone_phases(0);
        for _ in 0..30 {
            st.process_frame(&phases_z0, &prev);
        }
        let phases_z2 = zone_phases(2);
        for _ in 0..30 {
            st.process_frame(&phases_z2, &prev);
        }
        // Velocity is an EMA of reciprocal frame counts, so it must stay
        // finite and non-negative.
        let v = st.velocity();
        assert!(v.is_finite() && v >= 0.0, "velocity={}", v);
    }

    #[test]
    fn test_stdp_strengthens_active_connections() {
        let mut st = SpikingTracker::new();
        let prev = [0.0f32; 32];
        let initial_w = st.weights[0][0];
        let phases = zone_phases(0);
        for _ in 0..50 {
            st.process_frame(&phases, &prev);
        }
        assert!(st.weights[0][0] >= initial_w);
        assert!(st.weights[0][0] <= W_MAX);
    }

    #[test]
    fn test_track_lost_event() {
        let mut st = SpikingTracker::new();
        let prev = [0.0f32; 32];
        let phases = zone_phases(0);
        for _ in 0..30 {
            st.process_frame(&phases, &prev);
        }
        assert!(st.is_tracking());

        // Go silent: the rate EMA decays below MIN_SPIKE_RATE in ~30 frames.
        let silent = [0.0f32; 32];
        let mut lost_value = None;
        for _ in 0..100 {
            for e in st.process_frame(&silent, &prev) {
                if e.0 == EVENT_TRACK_LOST {
                    lost_value = Some(e.1);
                }
            }
        }
        assert_eq!(lost_value, Some(0.0), "TRACK_LOST must report the last zone");
        assert!(!st.is_tracking());
        assert_eq!(st.current_zone(), -1);
    }

    #[test]
    fn test_membrane_leak() {
        let mut st = SpikingTracker::new();
        st.membrane[0] = 0.5;
        let phases = [0.0f32; 32];
        let prev = [0.0f32; 32];
        st.process_frame(&phases, &prev);
        // No injected current: membrane decays by exactly one LEAK factor.
        assert!((st.membrane[0] - 0.5 * LEAK).abs() < 1e-6);
    }

    #[test]
    fn test_short_slices_only_touch_prefix() {
        let mut st = SpikingTracker::new();
        let phases = [0.5f32; 4];
        let prev = [0.0f32; 32];
        st.process_frame(&phases, &prev);
        for i in 0..4 {
            assert!((st.membrane[i] - 0.5).abs() < 1e-6);
        }
        assert_eq!(st.membrane[4], 0.0);
    }
}