-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsimulation_functions.py
More file actions
162 lines (129 loc) · 6.23 KB
/
simulation_functions.py
File metadata and controls
162 lines (129 loc) · 6.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import numpy as np
from typing import List, Tuple
def run_simulation(transition_matrix, mean_matrix, std_dev_matrix,
                   min_cutoff_matrix, max_cutoff_matrix, distribution_matrix,
                   initial_state_idx, states: List[str], *,
                   max_iterations: int = 1000) -> List[Tuple[str, float]]:
    """Run a single simulation of disease progression.

    Note: Input matrices contain times in days, but output timeline is in hours.

    The walk stops when the current state's transition row sums to zero,
    when a state named "Recovered" or "Deceased" is entered, or when
    ``max_iterations`` loop iterations have been used. A warning is printed
    only in the last case.

    Args:
        transition_matrix: Matrix of transition probabilities
        mean_matrix: Matrix of mean transition times
        std_dev_matrix: Matrix of standard deviations
        min_cutoff_matrix: Matrix of minimum transition times
        max_cutoff_matrix: Matrix of maximum transition times
        distribution_matrix: Matrix of distribution types
        initial_state_idx: Index of initial state
        states: List of state names
        max_iterations: Safety limit on the number of loop iterations
            (self-transitions count towards it)

    Returns:
        List of (state_name, time) tuples where time is in hours
    """
    HOURS_PER_DAY = 24  # Convert days to hours

    current_state = int(initial_state_idx)  # Ensure it's an integer
    timeline = [(states[current_state], 0.0)]  # Start at time 0
    total_time = 0.0

    for _ in range(max_iterations):
        transition_probs = transition_matrix[current_state]

        # If all transition probabilities are 0, we're in a terminal state
        if np.sum(transition_probs) == 0:
            break

        # Choose next state based on transition probabilities
        next_state = int(np.random.choice(len(states), p=transition_probs))

        # A self-transition uses up an iteration but adds neither elapsed
        # time nor a timeline entry
        if next_state == current_state:
            continue

        # Generate time for this transition (in days)
        time_days = generate_transition_time(
            mean_matrix[current_state][next_state],
            std_dev_matrix[current_state][next_state],
            min_cutoff_matrix[current_state][next_state],
            max_cutoff_matrix[current_state][next_state],
            distribution_matrix[current_state][next_state],
        )

        total_time += time_days * HOURS_PER_DAY
        timeline.append((states[next_state], total_time))
        current_state = next_state

        # Check if we've reached a terminal state (Recovered or Deceased)
        if states[current_state] in ("Recovered", "Deceased"):
            break
    else:
        # Only reached when the loop ran out of iterations without a break,
        # so finishing exactly on the last iteration no longer warns.
        print(f"Warning: Simulation stopped after {max_iterations} iterations")

    return timeline
def generate_transition_time(mean: float, std_dev: float,
                             min_cutoff: float, max_cutoff: float,
                             distribution_type: int, *,
                             max_attempts: int = 10_000) -> float:
    """Generate time for transitioning between states.

    For the random distributions, values are drawn by rejection sampling
    until one falls inside ``[min_cutoff, max_cutoff]``. Sampling is bounded:
    if no draw lands in the window within ``max_attempts`` tries, the last
    draw is clamped to the nearest cutoff instead of looping forever.

    Args:
        mean: Mean time for transition
        std_dev: Standard deviation
        min_cutoff: Minimum allowed time
        max_cutoff: Maximum allowed time
        distribution_type: Type of distribution (0-4)
            0: Fixed (mean)
            1: Normal
            2: Uniform
            3: Log-normal
            4: Gamma
        max_attempts: Maximum number of rejection-sampling draws before
            falling back to clamping

    Returns:
        float: Generated transition time. For the fixed type this is
        ``mean`` itself; the cutoffs are not applied to it.

    Raises:
        ValueError: If the distribution type is unknown, ``max_attempts`` is
            less than 1, ``min_cutoff`` exceeds ``max_cutoff``, or a
            log-normal/gamma distribution is requested with a non-positive
            mean.
    """
    if distribution_type == 0:  # Fixed time
        return mean
    if distribution_type not in (1, 2, 3, 4):
        raise ValueError(f"Unknown distribution type: {distribution_type}")
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    if min_cutoff > max_cutoff:
        # An empty window can never be satisfied by any sample.
        raise ValueError(
            f"min_cutoff ({min_cutoff}) must not exceed max_cutoff ({max_cutoff})"
        )
    if distribution_type in (3, 4) and mean <= 0:
        # Both parameter conversions below divide by the mean.
        raise ValueError(
            "Log-normal and gamma distributions require a positive mean"
        )

    if std_dev == 0:
        # Zero spread: every draw would equal the mean (and the gamma
        # conversion would divide by zero), so skip sampling entirely.
        return min(max(mean, min_cutoff), max_cutoff)

    # Resolve the sampler and its parameters once, outside the retry loop.
    if distribution_type == 1:  # Normal
        sampler, params = np.random.normal, (mean, std_dev)
    elif distribution_type == 2:  # Uniform
        sampler, params = np.random.uniform, (mean - std_dev, mean + std_dev)
    elif distribution_type == 3:  # Log-normal
        # Convert mean and std_dev to the parameters of the underlying normal
        mu = np.log(mean**2 / np.sqrt(std_dev**2 + mean**2))
        sigma = np.sqrt(np.log(1 + (std_dev**2 / mean**2)))
        sampler, params = np.random.lognormal, (mu, sigma)
    else:  # Gamma
        # Convert mean and std_dev to shape/scale parameters
        shape = (mean / std_dev)**2
        scale = std_dev**2 / mean
        sampler, params = np.random.gamma, (shape, scale)

    for _ in range(max_attempts):
        sample = sampler(*params)
        # Ensure time falls within allowed range
        if min_cutoff <= sample <= max_cutoff:
            return sample

    # The window is (practically) unreachable for this distribution; return
    # the nearest allowed value rather than hanging the simulation.
    return min(max(sample, min_cutoff), max_cutoff)
def validate_matrices(transition_matrix, mean_matrix, std_dev_matrix,
                      min_cutoff_matrix, max_cutoff_matrix, distribution_matrix):
    """Validate all matrices have correct properties.

    Each row of the transition matrix must sum to 1 (a regular state) or to
    0 (a terminal state with no outgoing transitions). The check is made
    row by row, so a matrix mixing both kinds of rows is accepted.

    Args:
        transition_matrix: Matrix of transition probabilities
        mean_matrix: Matrix of mean transition times
        std_dev_matrix: Matrix of standard deviations
        min_cutoff_matrix: Matrix of minimum transition times
        max_cutoff_matrix: Matrix of maximum transition times
        distribution_matrix: Matrix of distribution types

    Raises:
        ValueError: If any matrix fails validation
    """
    # Accept nested lists as well as arrays; the comparisons below need
    # element-wise semantics.
    transition_matrix = np.asarray(transition_matrix, dtype=float)
    mean_matrix = np.asarray(mean_matrix, dtype=float)
    std_dev_matrix = np.asarray(std_dev_matrix, dtype=float)
    min_cutoff_matrix = np.asarray(min_cutoff_matrix, dtype=float)
    max_cutoff_matrix = np.asarray(max_cutoff_matrix, dtype=float)
    distribution_matrix = np.asarray(distribution_matrix)

    # Check transition matrix properties
    if not np.all((transition_matrix >= 0) & (transition_matrix <= 1)):
        raise ValueError("Transition probabilities must be between 0 and 1")

    row_sums = np.sum(transition_matrix, axis=1)
    # Per-row test: np.allclose over the whole vector would demand that
    # every row sums to 1, or every row sums to 0, rejecting mixed matrices.
    rows_valid = np.isclose(row_sums, 1.0) | np.isclose(row_sums, 0.0)
    if not np.all(rows_valid):
        raise ValueError("Transition matrix rows must sum to 1 or 0")

    # Check other matrices
    if not np.all(mean_matrix >= 0):
        raise ValueError("Mean times must be non-negative")

    if not np.all(std_dev_matrix >= 0):
        raise ValueError("Standard deviations must be non-negative")

    if not np.all(min_cutoff_matrix <= max_cutoff_matrix):
        raise ValueError("Min cutoff must be less than or equal to max cutoff")

    if not np.all((distribution_matrix >= 0) & (distribution_matrix <= 4)):
        raise ValueError("Distribution types must be between 0 and 4")