# This file is part of SIMPLICITY
# Copyright (C) 2025 Pietro Gerletti
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Intra-host transient state model of SARS-CoV-2 infection
@author: Pietro Gerletti
"""
import numpy as np
import scipy.linalg
from multiprocessing import Pool
from tqdm import tqdm
import matplotlib.pyplot as plt
import simplicity.output_manager as om
import os
import pickle
from types import MethodType
class Host:
    '''
    Intra-host model of SARS-CoV-2 pathogenesis.

    The model is a linear chain of ``n_states`` states (21 with the
    sub-phase counts hard-coded in ``_get_A_matrix``): every state can only
    move to the next one and the last state is absorbing. ``A`` is the
    transition-rate matrix in column convention, i.e. ``expm(A * t)[:, s]``
    is the state distribution a time ``t`` after starting in state ``s``.
    '''

    def __init__(self, tau_1, tau_2, tau_3, tau_4, update_mode = 'matrix'):
        '''
        Build the rate matrix and the matrix-exponential lookup.

        Parameters
        ----------
        tau_1, tau_2, tau_3, tau_4 : float
            Time scale of each of the four phases; a phase split into n
            sub-states leaves each sub-state at rate n / tau.
        update_mode : str
            Either 'jump' or 'matrix'. Only 'matrix' loads (or precomputes)
            the table of matrix exponentials.
        '''
        # set up the model matrix
        self.tau_1 = tau_1
        self.tau_2 = tau_2
        self.tau_3 = tau_3
        self.tau_4 = tau_4
        self.A = self._get_A_matrix(tau_1, tau_2, tau_3, tau_4)
        # attributes for model solution
        self.n_states = self.A.shape[0]
        self.states = np.arange(0, self.n_states)
        self.update_mode = update_mode  # either jump or matrix
        self.use_precomputed_matrix = True
        # Always defined, so get_A_t also works (by falling back to a direct
        # expm evaluation) when no table has been loaded.
        self.exp_table = {}
        self.delta_t_not_in_table = []
        if self.use_precomputed_matrix and update_mode == 'matrix':
            self.exp_table = self._load_or_precompute_exponentials()
        self.get_A_t = self.factory_get_A_t()

    def get_update_mode(self):
        '''Return the update mode given at construction.'''
        return self.update_mode

    def get_jump_rate(self, state):
        '''Return the rate at which ``state`` is left (minus the diagonal of A).'''
        return -self.A[state, state]

    @staticmethod
    def _get_A_matrix(tau_1, tau_2, tau_3, tau_4):
        '''
        Generate the intra-host transition matrix A.

        Column ``j`` holds ``-r_j`` on the diagonal and ``+r_j`` directly
        below it, where ``r_j = n / tau`` for the phase state ``j`` belongs
        to. The last column is all zeros (absorbing state), so every column
        sums to zero.
        '''
        # (number of sub-states, phase time scale) for each phase;
        # one extra, final state is the recovered one
        phases = (
            (5, tau_1),    # pre-detection
            (1, tau_2),    # pre-symptomatic
            (13, tau_3),   # infectious
            (1, tau_4),    # post-infectious
        )
        counts = [n for n, _ in phases]
        dim = sum(counts) + 1
        # exit rate of every non-absorbing state, in state order
        rates = np.repeat([n / tau for n, tau in phases], counts)
        A = np.zeros((dim, dim))
        source = np.arange(dim - 1)
        A[source, source] = -rates      # outflow from each state
        A[source + 1, source] = rates   # inflow into the following state
        return A

    def factory_get_A_t(self):
        '''
        Build the ``get_A_t(delta_t)`` method bound to this instance.

        The returned method gives ``expm(A * delta_t)``. When
        ``use_precomputed_matrix`` is set it first looks up
        ``round(delta_t, 8)`` in ``exp_table``; on a miss the key is
        recorded in ``delta_t_not_in_table`` and the exponential is
        computed directly.
        '''
        if self.use_precomputed_matrix:
            def get_A_t(self, delta_t):
                key = round(delta_t, 8)
                try:
                    return self.exp_table[key]
                except KeyError:
                    self.delta_t_not_in_table.append(key)
                    return scipy.linalg.expm(self.A * delta_t)
        else:
            def get_A_t(self, delta_t):
                return scipy.linalg.expm(self.A * delta_t)
        return MethodType(get_A_t, self)

    def _load_or_precompute_exponentials(self):
        '''
        Load the table {rounded dt: expm(A * dt)} from disk, or build and save it.

        The file location comes from the output manager and depends on the
        four tau values.
        '''
        file_path = om.get_procomputed_matrix_table_filepath(
            self.tau_1, self.tau_2, self.tau_3, self.tau_4)
        if os.path.exists(file_path):
            # NOTE: pickle executes arbitrary code on load; this file must
            # only ever be a cache written by this same method.
            with open(file_path, "rb") as f:
                return pickle.load(f)

        print('Precomputing matrix exponentials...')
        exp_table = {
            round(dt, 8): scipy.linalg.expm(self.A * dt)
            for dt in self._generate_dts()
        }
        # Write to a per-process temporary file and rename it into place, so
        # that another process constructing a Host at the same time never
        # reads a partially written table.
        tmp_path = f"{file_path}.{os.getpid()}.tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(exp_table, f)
        os.replace(tmp_path, file_path)
        print(f"Saved {len(exp_table)} matrix exponentials to {file_path}")
        return exp_table

    @staticmethod
    def _generate_dts():
        '''Return the time steps tabulated in the precomputed table.'''
        # 200 log-spaced steps in [1e-5, 1e-2) plus 100 linear steps in [0.01, 10]
        dts_small = np.logspace(-5, -2, 200, endpoint=False)
        dts_large = np.linspace(0.01, 10, 100)
        return np.concatenate([dts_small, dts_large])

    def get_p_t(self, A_t, state):
        '''
        Compute state probability vector p(t) for given A^t and state.
        '''
        p0 = np.zeros(self.n_states)
        p0[state] = 1
        return np.matmul(A_t, p0)

    @staticmethod
    def update_state(p_t, tau):
        '''
        Sample the next state by inverting the cumulative distribution.

        Parameters
        ----------
        p_t : array-like
            Probability of each state.
        tau : float
            Uniform random number.

        Returns
        -------
        Index of the first state whose cumulative probability is >= tau.
        If rounding leaves the total probability marginally below tau, the
        last state with positive probability is returned instead.

        Raises
        ------
        IndexError
            If tau exceeds the total probability by more than rounding error
            (p_t is not a normalised distribution) or p_t is empty.
        '''
        p_cum = np.cumsum(p_t)
        reached = np.flatnonzero(tau <= p_cum)
        if reached.size:
            return reached[0]
        if tau - p_cum[-1] > 1e-8:
            raise IndexError(
                f"tau={tau} exceeds total probability {p_cum[-1]}")
        carrying = np.flatnonzero(np.asarray(p_t) > 0)
        return carrying[-1] if carrying.size else p_cum.size - 1

    def compute_all_probabilities(self, delta_t):
        '''
        Compute probability vectors for all initial states.
        '''
        A_t = self.get_A_t(delta_t)
        return [self.get_p_t(A_t, i) for i in self.states]

    def data_plot_ih_solution(self,state,time,step):
        '''
        Compute, on the time grid ``np.arange(0, time, step)``:
        p_inf - probability mass in states 5..18 (labelled infectious)
        p_det - probability mass in states 5..19 (labelled detectable)
        p_rec - probability of state 20 (labelled recovered)

        Parameters
        ----------
        state : int
            Intra-host model starting state
        time : float
            End of the time grid (exclusive)
        step : float
            Spacing of the time grid

        Returns
        -------
        p_inf, p_det, p_rec : lists with one value per time point. The
        output is used to plot the intra-host model results.
        '''
        p_inf = []
        p_det = []
        p_rec = []
        for time_point in np.arange(0, time, step):
            p_i = self.get_p_t(self.get_A_t(time_point), state)
            p_inf.append(np.sum(p_i[5:19]))
            p_det.append(np.sum(p_i[5:20]))
            p_rec.append(p_i[20])
        return p_inf,p_det,p_rec

    def simulate_trajectory(self, delta_t, rng=None, exponential_dt=False):
        '''
        Simulate disease progression trajectory from state 0 to the last state.

        Parameters
        ----------
        delta_t : float
            Mean time step or fixed step size
        rng : np.random.Generator
            Optional random generator for reproducibility
        exponential_dt : bool
            If True, draw each step from an exponential distribution with
            mean delta_t

        Returns
        -------
        trajectory : list of (time, state)
            The initial record (0, 0) followed by one record for every step
            at which the state changed.
        '''
        if rng is None:
            rng = np.random.default_rng()
        final_state = self.n_states - 1
        state = 0
        t = 0
        trajectory = [(t, state)]  # initial state
        # With a fixed step the propagator is the same at every iteration,
        # so it is evaluated once instead of once per step.
        fixed_A_t = None if exponential_dt else self.get_A_t(delta_t)
        while state < final_state:
            if exponential_dt:
                dt = rng.exponential(delta_t)
                A_t = self.get_A_t(dt)
            else:
                dt = delta_t
                A_t = fixed_A_t
            # only the distribution of the current state is needed
            p_t = self.get_p_t(A_t, state)
            tau = rng.random()
            new_state = self.update_state(p_t, tau)
            t += dt
            if new_state != state:
                trajectory.append((t, new_state))
            state = new_state
        return trajectory
def save_results(filename, data):
    """Pickle ``data`` and write it to ``filename``, replacing any existing file."""
    with open(filename, 'wb') as sink:
        sink.write(pickle.dumps(data))
def load_results(filename):
    """Read ``filename`` and return the object that was pickled into it."""
    with open(filename, 'rb') as source:
        stored = pickle.load(source)
    return stored
def _simulate_worker(args):
    """
    Run one trajectory from a single packed argument tuple.

    ``args`` is ``(delta_t, tau_1, tau_2, tau_3, tau_4, exponential_dt, seed)``;
    the seed initialises the random generator used for this run.
    """
    (delta_t, tau_1, tau_2, tau_3, tau_4, exponential_dt, seed) = args
    generator = np.random.default_rng(seed)
    model = Host(tau_1=tau_1, tau_2=tau_2, tau_3=tau_3, tau_4=tau_4)
    return model.simulate_trajectory(
        delta_t,
        rng=generator,
        exponential_dt=exponential_dt,
    )
def run_parallel_simulations(delta_t, n_runs=100, tau_1=2.86, tau_2=3.91, tau_3=7.5, tau_4=8.0, exponential_dt=False, base_seed=None):
    """
    Simulate ``n_runs`` trajectories in a process pool, with a progress bar.

    Run ``i`` is seeded with ``base_seed + i``, or left unseeded when
    ``base_seed`` is None. Returns the list of trajectories in run order.
    """
    jobs = []
    for run_index in range(n_runs):
        seed = base_seed + run_index if base_seed is not None else None
        jobs.append((delta_t, tau_1, tau_2, tau_3, tau_4, exponential_dt, seed))

    with Pool() as pool:
        progress = tqdm(pool.imap(_simulate_worker, jobs), total=n_runs)
        trajectories = list(progress)
    return trajectories
def compute_state_durations(trajectories, max_state=20):
    """
    Collect, per state, the time spent in it before each recorded transition.

    Each trajectory is a list of (time, state) records; the gap between two
    consecutive records is attributed to the state of the earlier one.
    Returns: dict of state -> list of durations, with a key for every state
    in 0..max_state.
    """
    residence = {}
    for state in range(max_state + 1):
        residence[state] = []

    for path in trajectories:
        # pair every record with its successor; empty and single-record
        # trajectories yield no pairs
        for (entered_at, state), (left_at, _) in zip(path, path[1:]):
            residence[state].append(left_at - entered_at)
    return residence
def compute_phase_durations(trajectories):
    """
    Sum, for every trajectory, the time spent in each group of states.

    States below 5 count as pre-infectious, 5..18 as infectious and 5..19
    as detectable (the last two overlap). Trajectories with fewer than two
    records are skipped.
    Returns: list of dicts with phase durations per individual; the
    "total_duration" entry is the time of the last record.
    """
    first_infectious, last_infectious = 5, 18
    first_detectable, last_detectable = 5, 19

    summaries = []
    for path in trajectories:
        if not (path and len(path) >= 2):
            continue
        pre = 0.0
        infectious = 0.0
        detectable = 0.0
        for (t_start, state), (t_stop, _) in zip(path, path[1:]):
            span = t_stop - t_start
            if state < first_infectious:
                pre += span
            if first_infectious <= state <= last_infectious:
                infectious += span
            if first_detectable <= state <= last_detectable:
                detectable += span
        summaries.append({
            "pre_infectious_duration": pre,
            "infectious_duration": infectious,
            "detectable_duration": detectable,
            "total_duration": path[-1][0],  # time of final transition
        })
    return summaries
def plot_state_duration_stats_grid(trajectories_dict, keys, title_prefix):
    """
    Draw, on a 2x3 grid of axes, one bar chart per key showing the mean
    duration (with standard deviation as error bar) of every intra-host state.

    Parameters
    ----------
    trajectories_dict : dict
        Dict of { Δt or λ : list of trajectories ([(t, s), ...]) }
    keys : list
        List of Δt or λ values to plot, one per axes (at most six)
    title_prefix : str
        Title prefix for each subplot (e.g., 'Fixed', 'Exp')
    """
    _, grid = plt.subplots(2, 3, figsize=(14, 8), sharey=True)
    panels = grid.flatten()

    for position, key in enumerate(keys):
        panel = panels[position]
        per_state = compute_state_durations(trajectories_dict[key])
        states = sorted(per_state)

        # states that were never left have no samples and are drawn as 0
        means = []
        spreads = []
        for s in states:
            samples = per_state[s]
            means.append(np.mean(samples) if samples else 0)
            spreads.append(np.std(samples) if samples else 0)

        panel.bar(states, means, yerr=spreads, capsize=4, color='skyblue', edgecolor='black')
        panel.set_title(f"{title_prefix} {key}")
        panel.set_xlabel("State")
        panel.set_ylabel("Mean Duration (days)")
        panel.set_xticks(states)
        panel.grid(True, axis='y')

    plt.tight_layout()
    plt.show()
def plot_duration_summary_scatter(fixed_results, exp_results, fixed_dts, exp_lambdas, x_shift=0.015):
    """
    Scatter the mean phase durations (std as error bars) against Δt (fixed)
    and λ (exp), one axes per phase category on a log x-axis.

    ``fixed_results[dt][category]`` and ``exp_results[lmbda][category]`` are
    the per-run values. The two series are offset by a relative ``x_shift``
    in opposite directions so that their markers do not overlap.
    """
    categories = ["pre_infectious_duration", "infectious_duration", "detectable_duration", "total_duration"]
    _, grid = plt.subplots(2, 3, figsize=(12, 8))
    panels = grid.flatten()

    for position, cat in enumerate(categories):
        panel = panels[position]

        # fixed Δt series, nudged to the left
        fixed_xs = [dt - x_shift*dt for dt in fixed_dts]
        fixed_means = []
        fixed_stds = []
        for dt in fixed_dts:
            values = fixed_results[dt][cat]
            fixed_means.append(np.mean(values))
            fixed_stds.append(np.std(values))
        panel.errorbar(fixed_xs, fixed_means, yerr=fixed_stds, fmt='o', label='Fixed Δt', color='blue', capsize=4)

        # exp λ series, nudged to the right
        exp_xs = [lmbda + x_shift*lmbda for lmbda in exp_lambdas]
        exp_means = []
        exp_stds = []
        for lmbda in exp_lambdas:
            values = exp_results[lmbda][cat]
            exp_means.append(np.mean(values))
            exp_stds.append(np.std(values))
        panel.errorbar(exp_xs, exp_means, yerr=exp_stds, fmt='s', label='Exp(λ)', color='green', capsize=4)

        panel.set_title(cat.replace('_', ' ').title())
        panel.set_xlabel("Δt / λ")
        panel.set_ylabel("Duration (days)")
        panel.grid(True)
        panel.legend()
        panel.set_xscale('log')

    plt.suptitle("Phase Durations: Fixed Δt vs Exp(λ)", fontsize=14)
    plt.tight_layout()
    plt.show()
def plot_infectious_duration_vs_step(fixed_phase_durations, exp_phase_durations, fixed_dts, exp_lambdas):
    """
    Plot mean infectious duration (std as error bars) vs Δt or λ on a log x-axis.

    Parameters
    ----------
    fixed_phase_durations : dict
        { Δt: list of dicts with phase durations }
    exp_phase_durations : dict
        { λ: list of dicts with phase durations }
    fixed_dts : list of floats
        Fixed step sizes
    exp_lambdas : list of floats
        Exponential step scales
    """
    def infectious_stats(phase_durations, steps):
        # mean and std of "infectious_duration" for every step value
        means = []
        stds = []
        for step in steps:
            values = [record["infectious_duration"] for record in phase_durations[step]]
            means.append(np.mean(values))
            stds.append(np.std(values))
        return means, stds

    _, panel = plt.subplots(figsize=(8, 5))

    # Fixed Δt
    y_fixed, yerr_fixed = infectious_stats(fixed_phase_durations, fixed_dts)
    panel.errorbar(fixed_dts, y_fixed, yerr=yerr_fixed, fmt='o', color='blue', label='Fixed Δt', capsize=4)

    # Exp(λ)
    y_exp, yerr_exp = infectious_stats(exp_phase_durations, exp_lambdas)
    panel.errorbar(exp_lambdas, y_exp, yerr=yerr_exp, fmt='s', color='green', label='Exp(λ)', capsize=4)

    panel.set_xscale('log')
    panel.set_xlabel("Δt / λ (log scale)")
    panel.set_ylabel("Infectious Duration (days)")
    panel.set_title("Infectious Duration vs Δt / λ")
    panel.grid(True, which='both', axis='both')
    panel.legend()
    plt.tight_layout()
    plt.show()
def plot_state_timeline_summary(state_durations_dict, phase_durations_dict, title_prefix=""):
    """
    Visualize average residence times for each state as a timeline-style plot (one per Δt or λ).

    Each state's mean duration is shown as a horizontal line on its own row.
    The lines are laid end to end along the time axis, starting at 0, so the
    last line ends at the sum of the mean durations (which is also the shared
    x-axis limit). The standard deviation is drawn as a horizontal error bar
    at the middle of each line. States are color-coded by the infection
    phase they belong to. Up to six keys are drawn on a 2x3 grid.

    Parameters
    ----------
    state_durations_dict : dict
        Dictionary of {Δt or λ: state_durations}, where each state_durations is a dict:
        { state_index -> list of durations }
        (e.g. output of compute_state_durations)
    phase_durations_dict : dict
        Dictionary of {Δt or λ: list of per-run dicts} with the keys
        "pre_infectious_duration", "infectious_duration",
        "detectable_duration" and "total_duration"
        (e.g. output of compute_phase_durations). Keys without an entry get
        no vertical marker lines.
    title_prefix : str
        Prefix to add to each subplot title, e.g. "Fixed" or "Exp"
    """
    # Define which states belong to which biological phase (color-coded)
    def get_phase_color(state):
        if state < 5:
            return "Pre-infectious", "#1f77b4"  # blue
        if state <= 18:
            return "Infectious", "#d62728"  # red
        if state == 19:
            return "Detectable", "#2ca02c"  # green
        return "Final", "gray"  # absorbing state

    _, axs = plt.subplots(2, 3, figsize=(14, 8), sharey=True)
    axs = axs.flatten()
    keys = list(state_durations_dict.keys())

    # compute global max time: the longest end-to-end timeline over all keys
    max_total_time = 0
    for durations in state_durations_dict.values():
        total_time = sum(np.mean(durations[s]) for s in range(20) if durations.get(s))
        max_total_time = max(max_total_time, total_time)

    for idx, key in enumerate(keys):
        ax = axs[idx]
        state_durations = state_durations_dict[key]
        elapsed = 0  # time at which the next state's line starts
        for state in range(20):
            durations = state_durations.get(state, [])
            if not durations:
                continue
            mean_dur = np.mean(durations)
            std_dur = np.std(durations)
            phase_label, color = get_phase_color(state)
            # The line for this state starts where the previous one ended.
            # (Previously the centre was advanced only after drawing, so
            # every line sat one slot too early and overlapped its
            # neighbour, with the first one half outside the axes.)
            start = elapsed
            end = elapsed + mean_dur
            ax.hlines(
                y=state,
                xmin=start,
                xmax=end,
                color=color,
                linewidth=3,
                # label only the first state of each phase to avoid
                # duplicate legend entries
                label=phase_label if state in [0, 5, 19] else ""
            )
            ax.errorbar(
                x=start + mean_dur / 2,
                y=state,
                xerr=std_dur / 2,
                fmt='none',
                ecolor='black',
                capsize=3
            )
            elapsed = end

        # === Compute mean durations for each phase and total ===
        phases = phase_durations_dict.get(key, [])
        if phases:
            pre_mean = np.mean([p["pre_infectious_duration"] for p in phases])
            inf_mean = np.mean([p["infectious_duration"] for p in phases])
            det_mean = np.mean([p["detectable_duration"] for p in phases])
            total_mean = np.mean([p["total_duration"] for p in phases])
            # === Plot vertical lines for mean durations ===
            # Every line is drawn at x = the mean duration itself.
            # NOTE(review): the infectious/detectable markers are therefore
            # not offset by the pre-infectious time - confirm whether
            # cumulative positions on the timeline were intended.
            ax.axvline(pre_mean, color='blue', linestyle='--')
            ax.text(pre_mean, 11, f"Pre: {pre_mean:.1f}", color='blue', ha='left', fontsize=8)
            ax.axvline(inf_mean, color='red', linestyle='--')
            ax.text(inf_mean, 9, f"Inf: {inf_mean:.1f}", color='red', ha='left', fontsize=8)
            ax.axvline(det_mean, color='green', linestyle='--')
            ax.text(det_mean, 7, f"Det: {det_mean:.1f}", color='green', ha='left', fontsize=8)
            ax.axvline(total_mean, color='black', linestyle=':', linewidth=1.5)
            ax.text(total_mean, 5, f"Total: {total_mean:.1f}", color='black', ha='left', fontsize=8)

        ax.set_title(f"{title_prefix} Δt = {key}")
        ax.set_xlabel("Time (days)")
        ax.set_ylabel("State")
        ax.set_yticks(range(0, 20))
        ax.grid(True, axis='x')
        ax.set_xlim(0, max_total_time)
        ax.legend()

    plt.suptitle("Timeline of Average State Residence Times (Phase Colored)", fontsize=14)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()
def main():
    """
    Compare fixed-step and exponentially distributed-step simulations.

    Trajectories are loaded from 'fixed_results.pkl' / 'exp_results.pkl'
    when those files exist, otherwise simulated and saved there. Per-state
    and per-phase durations are then computed, summarised on stdout and
    plotted.
    """
    fixed_dts = [0.0001, 0.001, 0.01, 0.1, 1.0]
    exp_lambdas = [0.0001, 0.001, 0.01, 0.1, 1.0]
    # (a coarser alternative that was tried for both lists: [1, 2, 4, 8, 10])
    n_runs = 100

    # File paths
    fixed_results_file = 'fixed_results.pkl'
    exp_results_file = 'exp_results.pkl'

    phase_keys = ("pre_infectious_duration", "infectious_duration",
                  "detectable_duration", "total_duration")

    def print_phase_stats(records):
        # one summary line per phase for a single step value
        for key in phase_keys:
            vals = [d[key] for d in records]
            print(f" {key}: mean = {np.mean(vals):.2f}, std = {np.std(vals):.2f}")

    def by_category(records):
        # list of per-run dicts -> dict of per-category lists
        return {k: [d[k] for d in records] for k in records[0]}

    # -----------------------------
    # Load or run simulations
    # -----------------------------
    if not os.path.exists(fixed_results_file):
        fixed_trajectories = {}
        for dt in fixed_dts:
            print(f"Running fixed Δt = {dt}")
            fixed_trajectories[dt] = run_parallel_simulations(
                delta_t=dt, n_runs=n_runs, exponential_dt=False, base_seed=13
            )
        save_results(fixed_results_file, fixed_trajectories)
    else:
        print("Loading fixed Δt results...")
        fixed_trajectories = load_results(fixed_results_file)

    if not os.path.exists(exp_results_file):
        exp_trajectories = {}
        for lmbda in exp_lambdas:
            print(f"Running exp(λ) = {lmbda}")
            exp_trajectories[lmbda] = run_parallel_simulations(
                delta_t=lmbda, n_runs=n_runs, exponential_dt=True, base_seed=42
            )
        save_results(exp_results_file, exp_trajectories)
    else:
        print("Loading exp(λ) results...")
        exp_trajectories = load_results(exp_results_file)

    # -----------------------------
    # Compute state & phase durations
    # -----------------------------
    fixed_state_durations = {}
    for dt in fixed_dts:
        fixed_state_durations[dt] = compute_state_durations(fixed_trajectories[dt])
    exp_state_durations = {}
    for lmbda in exp_lambdas:
        exp_state_durations[lmbda] = compute_state_durations(exp_trajectories[lmbda])
    fixed_phase_durations = {}
    for dt in fixed_dts:
        fixed_phase_durations[dt] = compute_phase_durations(fixed_trajectories[dt])
    exp_phase_durations = {}
    for lmbda in exp_lambdas:
        exp_phase_durations[lmbda] = compute_phase_durations(exp_trajectories[lmbda])

    # -----------------------------
    # Print summary stats
    # -----------------------------
    print("\nFixed Δt Phase Durations:")
    for dt in fixed_dts:
        print(f"Δt = {dt}")
        print_phase_stats(fixed_phase_durations[dt])
    print("\nExp(λ) Phase Durations:")
    for lmbda in exp_lambdas:
        print(f"λ = {lmbda}")
        print_phase_stats(exp_phase_durations[lmbda])

    # ----------------------------
    # Plotting
    # ----------------------------
    print("\nPlotting: State Duration Stats (grid)...")
    plot_state_duration_stats_grid(fixed_trajectories, fixed_dts, title_prefix="Fixed")
    plot_state_duration_stats_grid(exp_trajectories, exp_lambdas, title_prefix="Exp")

    print("\nPlotting: Phase Duration Summary (scatter)...")
    fixed_phase_for_plot = {dt: by_category(fixed_phase_durations[dt]) for dt in fixed_dts}
    exp_phase_for_plot = {lmbda: by_category(exp_phase_durations[lmbda]) for lmbda in exp_lambdas}
    plot_duration_summary_scatter(fixed_phase_for_plot, exp_phase_for_plot, fixed_dts, exp_lambdas)

    print("\nPlotting: Average State Residence Time Timelines...")
    plot_state_timeline_summary(fixed_state_durations, fixed_phase_durations, title_prefix="Fixed")
    plot_state_timeline_summary(exp_state_durations, exp_phase_durations, title_prefix="Exp")

    plot_infectious_duration_vs_step(
        fixed_phase_durations=fixed_phase_durations,
        exp_phase_durations=exp_phase_durations,
        fixed_dts=fixed_dts,
        exp_lambdas=exp_lambdas
    )


# Run the comparison only when executed as a script.
if __name__ == '__main__':
    main()
# import cProfile
# import pstats
# if __name__ == '__main__':
# with cProfile.Profile() as pr:
# main() # or whatever your entry function is
# stats = pstats.Stats(pr)
# stats.strip_dirs()
# stats.sort_stats("cumtime").print_stats(20) # top 20 slowest calls