18. Final for Outbreak Science I#
import numpy as np
import scipy
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
WMMM = pd.read_csv("https://raw.githubusercontent.com/computationalUncertaintyLab/outbreak_book/refs/heads/main/outbreak_science/WMM_data/wmm_spring2026__2026-04-30T13-44_export.csv")
WMMM = WMMM.loc[ WMMM.success == 1 ]
WMMM = WMMM.rename( columns = {"Actor":"Infector", "Audience":"Infectee"} )
print(WMMM.head(5))
#--create an empty graph
g = nx.Graph()
#--loops through every edge in the WMM dataframe
#--add an edge for every (infector, infected) pair.
for index,row in WMMM.iterrows():
g.add_edge( row.Infector, row.Infectee )
#--assign s,i,r states to every node
#--make sure everyone is susceptible
for node in g.nodes():
g.nodes[node]["s"] = 1
g.nodes[node]["i"] = 0
g.nodes[node]["r"] = 0
#--add an infection
def infect_random_node(g):
nodes = list(g.nodes()) #--a list of all the nodes in the network
patient0 = np.random.choice(nodes) #--choose a node at random
#--change that node's disease state from susceptible to infected.
g.nodes[patient0]["s"]=0
g.nodes[patient0]["i"]=1
g.nodes[patient0]["r"]=0
infect_random_node(g) #<--Run function above
#--Add recovery times to your network
#--numnber of nodes in network
N = len(g)
recovery_times = np.random.gamma(5,1,size=N)
##--assigning infectious periods to everyone from the list recovery_times
for r,node in zip(recovery_times, g.nodes()):
g.nodes[node]["t"] = r
Unnamed: 0 Infector Infectee success timestamp
0 0 exp626 thm220 1 3/24/26 9:28
1 1 thm220 oya226 1 3/24/26 16:21
2 2 thm220 edj227 1 3/24/26 16:27
3 3 thm220 tbp226 1 3/24/26 16:28
5 5 thm220 jeo227 1 3/24/26 16:28
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
degree_dict = dict(g.degree())
degrees = np.array([degree_dict[n] for n in g.nodes()])
node_sizes = 30 + 120 * (degrees - degrees.min()) / (
degrees.max() - degrees.min() + 1e-9
)
# Put all nodes on one radial shell
#pos = nx.shell_layout(g)
#pos = nx.nx_agraph.graphviz_layout(g, prog="twopi")
pos = nx.nx_pydot.graphviz_layout(g, prog="neato")
fig, ax = plt.subplots(figsize=(8, 8), dpi=150)
nx.draw_networkx_edges(
g,
pos,
ax=ax,
edge_color="black",
alpha=0.08,
width=0.5
)
nx.draw_networkx_nodes(
g,
pos,
ax=ax,
node_size=node_sizes,
node_color=degrees,
cmap="viridis",
alpha=0.9,
linewidths=0.5,
edgecolors="white"
)
ax.set_title("Watermelon Meow Meow Network", fontsize=18, fontweight="bold", pad=15)
ax.set_axis_off()
ax.set_aspect("equal")
plt.tight_layout()
plt.show()
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
Cell In[3], line 15
11
12 # Put all nodes on one radial shell
13 #pos = nx.shell_layout(g)
14 #pos = nx.nx_agraph.graphviz_layout(g, prog="twopi")
---> 15 pos = nx.nx_pydot.graphviz_layout(g, prog="neato")
16
17 fig, ax = plt.subplots(figsize=(8, 8), dpi=150)
18
File /opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/networkx/drawing/nx_pydot.py:280, in graphviz_layout(G, prog, root)
250 def graphviz_layout(G, prog="neato", root=None):
251 """Create node positions using Pydot and Graphviz.
252
253 Returns a dictionary of positions keyed by node.
(...) 278 This is a wrapper for pydot_layout.
279 """
--> 280 return pydot_layout(G=G, prog=prog, root=root)
File /opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/networkx/drawing/nx_pydot.py:321, in pydot_layout(G, prog, root)
283 def pydot_layout(G, prog="neato", root=None):
284 """Create node positions using :mod:`pydot` and Graphviz.
285
286 Parameters
(...) 319
320 """
--> 321 import pydot
323 P = to_pydot(G)
324 if root is not None:
ModuleNotFoundError: No module named 'pydot'
Function one: From Susceptible to infected#
We need a function titled from_s_to_i that takes as input a node (n) and a network (g) and updates this node with the disease state infected.
This means that we change the state S to 0, the state I to 1, and the state R to 0.
Your function will need to select a node n from a graph g, using the g.nodes method and modify the “s”, “i”, and the “r” attributes.
Hint: Take a look at the code above where i assigned all the nodes to the susceptible state.
#--this is the answer the function that you will need to implement.
def from_s_to_i(n,g):
pass
#--this is for students to test their code
def test_function_one( fun ):
import networkx as nx
g = nx.gnm_random_graph(10, 5)
for node in g.nodes():
g.nodes[node]["s"] = 1
g.nodes[node]["i"] = 0
g.nodes[node]["r"] = 0
#--change state
fun(1,g)
node = g.nodes[1]
if node["s"] == 0 and node["i"] == 1 and node["r"] == 0:
print("Pass")
else:
print("Did not pass, keep trying!")
#--example test
test_function_one( from_s_to_i )
Did not pass, keep trying!
Function two: From Infected to Removed#
We need a function titled from_i_to_r that takes as input a node and a network and updates this node with the disease state removed. This means that we change the state S to 0, the state I to 0, and the state R to 1.
This code should look extremely similar to the code you used to prouce from_s_to_i except that we are changing the disease state of a node to removed. Hint Start by copy and pasting your from_s_to_i code.
#--this is the answer
def from_i_to_r(n,g):
pass
#--this is for students
def test_function_two( fun ):
import networkx as nx
g = nx.gnm_random_graph(10, 5)
for node in g.nodes():
g.nodes[node]["s"] = 1
g.nodes[node]["i"] = 0
g.nodes[node]["r"] = 0
#--change state
fun(5,g)
node = g.nodes[5]
if node["s"] == 0 and node["i"] == 0 and node["r"] == 1:
print("Pass")
else:
print("Did not pass, keep trying!")
test_function_two( from_i_to_r)
Did not pass, keep trying!
Function three: Check if a node is in the infected state#
We need a function titled is_infected that takes as input a node and a network and returns the value one if the node is in the infected state and the value zero otherwise.
Remeber that you can access a dictionary of the attributes of a node using the g.nodes method.
For example, if i access the node titled “thm220” then a dictionary of attributes will be returned.
#--accessing a dictionary of attirbutes that are attached to the node "thm220"
g.nodes["thm220"]
{'s': 1, 'i': 0, 'r': 0, 't': np.float64(4.867878776169396)}
#--this is the correct answer
def is_infected(n,g):
pass
#--this is for students
def test_function_three( fun ):
import networkx as nx
g = nx.gnm_random_graph(10, 5)
for node in g.nodes():
g.nodes[node]["s"] = 1
g.nodes[node]["i"] = 0
g.nodes[node]["r"] = 0
#--change state
from_s_to_i(5,g)
if is_infected(5,g)==1 and is_infected(2,g)==0:
print("Pass")
else:
print("Did not pass, keep trying!")
test_function_three( is_infected )
Did not pass, keep trying!
Function four: Check if a node is in the removed state#
We need a function titled is_removed that takes as input a node and a network and returns the value one if the node is in the removed state and the value zero otherwise.
This will look very close to the function above is_infected. Hint Start by copy and pasting your code above for is_infected.
#--this is the correct answer
def is_removed(n,g):
pass
#--this is for students
def test_function_four( fun ):
import networkx as nx
g = nx.gnm_random_graph(10, 5)
for node in g.nodes():
g.nodes[node]["s"] = 1
g.nodes[node]["i"] = 0
g.nodes[node]["r"] = 0
#--change state
from_s_to_i(5,g)
from_i_to_r(3,g)
if is_removed(3,g)==1 and is_removed(5,g)==0:
print("Pass")
else:
print("Did not pass, keep trying!")
test_function_four( is_removed )
Did not pass, keep trying!
Function five: Count the number of infected nodes#
We need a function titled ninfected that takes as input a node and a network and returns the number of neighbors of node that are in the infected state. You will need the networkx method neighbors https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.neighbors.html and you should use the function you wrote above called is_infected to count all the nodes that are infected.
The neigbors method takes as input a node and returns an iterator of all the neighbors of that node.
Iterators can be a bit tricky to work with and so we can change change an interator to a list with the list function.
For example, suppose i wish to return the list of all nodes that are neighbors of the node “profm”.
Then i can write
list(g.neighbors("thm220"))
['exp626',
'oya226',
'edj227',
'tbp226',
'jeo227',
'mil426',
'jek227',
'sab926',
'haw226',
'dal226',
'lel225',
'rac327',
'lom226',
'mav228',
'anp326',
'abn226',
'zcb226']
#--this is the correct answer
def ninfected(n,g):
pass
#--this is for students to test
def test_function_five( fun ):
import networkx as nx
g = nx.gnm_random_graph(10, 5)
for node in g.nodes():
g.nodes[node]["s"] = 1
g.nodes[node]["i"] = 0
g.nodes[node]["r"] = 0
#--choose node 5 and add neighbors until three
num_neighbors = len(list(g.neighbors(5)))
while num_neighbors <= 3:
non_neighbors = set(np.arange(10)) - set(list(g.neighbors(5)))
node_to_add = np.random.choice(list(non_neighbors))
g.add_edge(5,node_to_add)
num_neighbors = len(list(g.neighbors(5)))
#--change state
random_neighbors = np.random.choice(list(g.neighbors(5)), 3, replace=False)
for node in random_neighbors:
from_s_to_i(node,g)
result = ninfected(5,g)
if not isinstance(result, int):
print("Make sure to return an integer")
else:
if ninfected(5,g)==3:
print("Pass")
else:
print("Did not pass, keep trying!")
test_function_five(ninfected)
Make sure to return an integer
Function six: Probability of infection#
We need a function titled prob_of_infection that takes as input a node, network, and a probability of transmission p and returns the probability that at least one infected neighbor will infect node.
You should use the function ninfected to compute the number of infected neighbors of node.
#--this is the correct answer
def prob_of_infection(n,g,p):
pass
#--this is for students to test
def test_function_six( fun ):
import networkx as nx
g = nx.gnm_random_graph(10, 5)
for node in g.nodes():
g.nodes[node]["s"] = 1
g.nodes[node]["i"] = 0
g.nodes[node]["r"] = 0
#--add some certain edges
g.add_edge(1,5)
g.add_edge(2,5)
g.add_edge(9,5)
#--infect
from_s_to_i(1,g)
from_s_to_i(9,g)
#--test
p=0.3
correct_ans = 1 - (1-p)**2
proposed_ans = prob_of_infection(5,g,p)
if correct_ans == proposed_ans:
print("Pass")
else:
print("Did not pass, keep trying!")
test_function_six(prob_of_infection)
Did not pass, keep trying!
Function seven: List all the infected nodes#
We will need a function titled infected_nodes that takes as input a network and returns a list of all the nodes that are in the infected state. You can create an empty list titled infected_nodes, loop through all nodes \((n)\) in a network g using the nodes() method, determine if the node \(n\) is infected and append it to the list infected_nodes. If a node is not infected than we skip that node and move to the next node in the loop.
Below is a template for how to loop through all the nodes in your network g
for node in g.nodes():
<Code to determine if a node is infected and add to a list titled infected_nodes>
#--this is the correct answer
def infected_nodes(g):
pass
#--this is for students to test
def test_function_seven( fun ):
import networkx as nx
g = nx.gnm_random_graph(10, 5)
for node in g.nodes():
g.nodes[node]["s"] = 1
g.nodes[node]["i"] = 0
g.nodes[node]["r"] = 0
#--infect
from_s_to_i(1,g)
from_s_to_i(9,g)
from_s_to_i(3,g)
from_s_to_i(2,g)
ans = fun(g)
if not isinstance(ans, list):
print("Must be a list")
else:
if set(ans) == set([1,9,3,2]):
print("Pass")
else:
print("Did not pass, keep trying!")
test_function_seven(infected_nodes)
Must be a list
Function eight: Decrement recovery time#
We will need a function titled decrement_time that takes as input a node and a network and decreases the t attribute that is attached to node by one.
This function does not need to return anything, it needs to decrement an attribute of our network.
#--this is the correct answer
def decrement_time(n,g):
pass
#--this is for students to test
def test_function_eight( fun ):
import networkx as nx
g = nx.gnm_random_graph(10, 5)
for node in g.nodes():
g.nodes[node]["s"] = 1
g.nodes[node]["i"] = 0
g.nodes[node]["r"] = 0
g.nodes[node]["t"] = 0
g.nodes[5]["t"] = 1
#--run function
decrement_time(5,g)
ans = g.nodes[5]["t"]
correct_ans = 0
if ans == correct_ans:
print("Pass")
else:
print("Did not pass, keep trying!")
test_function_eight( decrement_time )
Did not pass, keep trying!
Function nine: Is time up#
We will need a function titled is_time_up that takes as input a node and a network and returns the value one if the t attribute is less than or equal to zero and returns the value zero otherwise.
This function will determine if the infectious period for a node is over.
#--this is the correct answer
def is_time_up(n,g):
pass
#--this is for students to test
def test_function_nine( fun ):
import networkx as nx
g = nx.gnm_random_graph(10, 5)
for node in g.nodes():
g.nodes[node]["s"] = 1
g.nodes[node]["i"] = 0
g.nodes[node]["r"] = 0
g.nodes[node]["t"] = 0
g.nodes[5]["t"] = 1
#--run function
decrement_time(5,g)
ans = is_time_up(5,g)
correct_ans = 1
if ans == correct_ans:
print("Pass")
else:
print("Did not pass, keep trying!")
test_function_nine( is_time_up )
Did not pass, keep trying!
Putting all of these functions to use to simulate an outbreak.#
Given a parameter p that describes the probability of an effective contact, we will simulate an outbreak on a network over T time steps.
Here are all the steps that use your nine function to generate an outbreak.
#--set our parameters
T = 20
p = 0.90
#--infect one node at random.
infect_random_node(g)
#--create an empty list that will store the number of prevalent infections over time
infections_over_time = []
for time in range(T):
break #--<<remove this when you start coding
#--Create a var called previous_infections that is a list of all infected nodes
previous_infections = infected_nodes(g)
#--Create a var called num_previous_infections that counts the number of infected nodes
num_previous_infections = len(previous_infections)
#--Append num_previous_infections to the list infections_over_time
infections_over_time.append(num_previous_infections)
#--Section: propogate infections
for node in g.nodes():
#--Compute the probability that node is infected (use your prob_of_infection function)
#--MAKE SURE TO CALL THIS PROBABILITY "prob"
prob = prob_of_infection(node,g,p)
#--If the node is in the removed state, skip this node by using the "continue" comand
#--Note: you will want to use your is_removed function.
if is_removed(node,g):
continue
#--This piece of code infects a node with probability prob
if np.random.random() < prob:
from_s_to_i(node,g)
#--Section: Recovery
#--Loop through all infected individuals
for infected in previous_infections:
#--include code that changes the disease state of the node infected to the recovered state
#--you will need your function titled from_i_to_r
from_i_to_r(infected,g)
Please plot the list infections_over_time.
Add an xlabel using either plt.xlabel or ax.set_xlabel with the text “Time” and a ylabel with the text “Number of infections”.
#--student will complete this plot
FINAL STEP#
Please wrap all of the code above into a function titled simulate_outbreak.
This function will take two arguments: T and p and returns infections_over_time.
def simulate_outbreak(T,p=0.1):
#--make sure everyone is susceptible
for node in g.nodes():
g.nodes[node]["s"] = 1
g.nodes[node]["i"] = 0
g.nodes[node]["r"] = 0
##--assigning infectious periods to everyone if they get infected
recovery_times = np.random.gamma(5,1, size=len(g))
for r,node in zip(recovery_times, g.nodes()):
g.nodes[node]["t"] = r
#--add an infection
patient0 = np.random.choice(list(g.nodes()))
g.nodes[patient0]["s"]=0
g.nodes[patient0]["i"]=1
g.nodes[patient0]["r"]=0
infections_over_time = []
for time in range(T):
previous_infections = infected_nodes(g)
num_previous_infections = len(previous_infections)
infections_over_time.append(num_previous_infections)
#--propogate infections
for node in g.nodes():
prob = prob_of_infection(node,g,p)
#--skip this node if it is removed
if is_removed(node,g):
continue
#--infection!
if np.random.random() < prob:
from_s_to_i(node,g)
#--remove all previous infections
for infected in previous_infections:
decrement_time(infected,g)
if is_time_up(infected,g):
from_i_to_r(infected,g)
return infections_over_time
The codebelow takes your simulate outbreak function, runs this 1000 times, and plots the number of prevalent cases for each iteration.
import sys
infections = []
for i in range(1000):
break #<--remove this when you are ready
sys.stdout.write("\r Simulating {:03d}".format(i))
infections_over_time = simulate_outbreak(T=50,p=0.1)
plt.plot(infections_over_time, color= "blue", alpha=0.1)
plt.ylabel("Incident infections")
plt.xlabel("time")
plt.show()