18. Final for Outbreak Science I#

import numpy as np 
import scipy
import networkx as nx 
import matplotlib.pyplot as plt 
import pandas as pd 
WMMM = pd.read_csv("https://raw.githubusercontent.com/computationalUncertaintyLab/outbreak_book/refs/heads/main/outbreak_science/WMM_data/wmm_spring2026__2026-04-30T13-44_export.csv")

WMMM = WMMM.loc[ WMMM.success == 1 ]
WMMM = WMMM.rename( columns = {"Actor":"Infector", "Audience":"Infectee"} )
print(WMMM.head(5))

#--create an empty graph
g = nx.Graph()

#--loops through every edge in the WMM dataframe
#--add an edge for every (infector, infected) pair. 
for index,row in WMMM.iterrows():
    g.add_edge( row.Infector, row.Infectee )

#--assign s,i,r states to every node
#--make sure everyone is susceptible
for node in g.nodes():
    g.nodes[node]["s"] = 1
    g.nodes[node]["i"] = 0
    g.nodes[node]["r"] = 0

#--add an infection 
def infect_random_node(g):
    nodes = list(g.nodes())            #--a list of all the nodes in the network
    patient0 = np.random.choice(nodes) #--choose a node at random
        
    #--change that node's disease state from susceptible to infected. 
    g.nodes[patient0]["s"]=0
    g.nodes[patient0]["i"]=1
    g.nodes[patient0]["r"]=0
infect_random_node(g) #<--Run function above


#--Add recovery times to your network
#--numnber of nodes in network
N = len(g)
recovery_times = np.random.gamma(5,1,size=N)

##--assigning infectious periods to everyone from the list recovery_times
for r,node in zip(recovery_times, g.nodes()):
    g.nodes[node]["t"] = r
   Unnamed: 0 Infector Infectee  success      timestamp
0           0   exp626   thm220        1   3/24/26 9:28
1           1   thm220   oya226        1  3/24/26 16:21
2           2   thm220   edj227        1  3/24/26 16:27
3           3   thm220   tbp226        1  3/24/26 16:28
5           5   thm220   jeo227        1  3/24/26 16:28
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

degree_dict = dict(g.degree())
degrees = np.array([degree_dict[n] for n in g.nodes()])

node_sizes = 30 + 120 * (degrees - degrees.min()) / (
    degrees.max() - degrees.min() + 1e-9
)

# Put all nodes on one radial shell
#pos = nx.shell_layout(g)
#pos = nx.nx_agraph.graphviz_layout(g, prog="twopi")
pos = nx.nx_pydot.graphviz_layout(g, prog="neato")

fig, ax = plt.subplots(figsize=(8, 8), dpi=150)

nx.draw_networkx_edges(
    g,
    pos,
    ax=ax,
    edge_color="black",
    alpha=0.08,
    width=0.5
)

nx.draw_networkx_nodes(
    g,
    pos,
    ax=ax,
    node_size=node_sizes,
    node_color=degrees,
    cmap="viridis",
    alpha=0.9,
    linewidths=0.5,
    edgecolors="white"
)

ax.set_title("Watermelon Meow Meow Network", fontsize=18, fontweight="bold", pad=15)
ax.set_axis_off()
ax.set_aspect("equal")
plt.tight_layout()
plt.show()
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[3], line 15
     11 
     12 # Put all nodes on one radial shell
     13 #pos = nx.shell_layout(g)
     14 #pos = nx.nx_agraph.graphviz_layout(g, prog="twopi")
---> 15 pos = nx.nx_pydot.graphviz_layout(g, prog="neato")
     16 
     17 fig, ax = plt.subplots(figsize=(8, 8), dpi=150)
     18 

File /opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/networkx/drawing/nx_pydot.py:280, in graphviz_layout(G, prog, root)
    250 def graphviz_layout(G, prog="neato", root=None):
    251     """Create node positions using Pydot and Graphviz.
    252 
    253     Returns a dictionary of positions keyed by node.
   (...)    278     This is a wrapper for pydot_layout.
    279     """
--> 280     return pydot_layout(G=G, prog=prog, root=root)

File /opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/networkx/drawing/nx_pydot.py:321, in pydot_layout(G, prog, root)
    283 def pydot_layout(G, prog="neato", root=None):
    284     """Create node positions using :mod:`pydot` and Graphviz.
    285 
    286     Parameters
   (...)    319 
    320     """
--> 321     import pydot
    323     P = to_pydot(G)
    324     if root is not None:

ModuleNotFoundError: No module named 'pydot'

Function one: From Susceptible to infected#

We need a function titled from_s_to_i that takes as input a node (n) and a network (g) and updates this node with the disease state infected. This means that we change the state S to 0, the state I to 1, and the state R to 0. Your function will need to select a node n from a graph g, using the g.nodes method and modify the “s”, “i”, and the “r” attributes. Hint: Take a look at the code above where i assigned all the nodes to the susceptible state.

#--this is the answer the function that you will need to implement.
def from_s_to_i(n,g):
    pass
    
#--this is for students to test their code
def test_function_one( fun ):
    import networkx as nx
    g = nx.gnm_random_graph(10, 5)
    
    for node in g.nodes():
        g.nodes[node]["s"] = 1
        g.nodes[node]["i"] = 0
        g.nodes[node]["r"] = 0

    #--change state 
    fun(1,g)
    
    node = g.nodes[1]
    if node["s"] == 0 and node["i"] == 1 and node["r"] == 0:
        print("Pass")
    else:
        print("Did not pass, keep trying!")
        
#--example test
test_function_one( from_s_to_i )
Did not pass, keep trying!

Function two: From Infected to Removed#

We need a function titled from_i_to_r that takes as input a node and a network and updates this node with the disease state removed. This means that we change the state S to 0, the state I to 0, and the state R to 1. This code should look extremely similar to the code you used to prouce from_s_to_i except that we are changing the disease state of a node to removed. Hint Start by copy and pasting your from_s_to_i code.

#--this is the answer
def from_i_to_r(n,g):
    pass

#--this is for students
def test_function_two( fun ):
    import networkx as nx
    g = nx.gnm_random_graph(10, 5)
    
    for node in g.nodes():
        g.nodes[node]["s"] = 1
        g.nodes[node]["i"] = 0
        g.nodes[node]["r"] = 0

    #--change state 
    fun(5,g)
    
    node = g.nodes[5]
    if node["s"] == 0 and node["i"] == 0 and node["r"] == 1:
        print("Pass")
    else:
        print("Did not pass, keep trying!")
        
test_function_two( from_i_to_r)
Did not pass, keep trying!

Function three: Check if a node is in the infected state#

We need a function titled is_infected that takes as input a node and a network and returns the value one if the node is in the infected state and the value zero otherwise. Remeber that you can access a dictionary of the attributes of a node using the g.nodes method.

For example, if i access the node titled “thm220” then a dictionary of attributes will be returned.

#--accessing a dictionary of attirbutes that are attached to the node "thm220"
g.nodes["thm220"]
{'s': 1, 'i': 0, 'r': 0, 't': np.float64(4.867878776169396)}
#--this is the correct answer
def is_infected(n,g):
    pass

#--this is for students
def test_function_three( fun ):
    import networkx as nx
    g = nx.gnm_random_graph(10, 5)
    
    for node in g.nodes():
        g.nodes[node]["s"] = 1
        g.nodes[node]["i"] = 0
        g.nodes[node]["r"] = 0

    #--change state 
    from_s_to_i(5,g)
    
    if is_infected(5,g)==1 and is_infected(2,g)==0:
        print("Pass")
    else:
        print("Did not pass, keep trying!")

test_function_three( is_infected )
Did not pass, keep trying!

Function four: Check if a node is in the removed state#

We need a function titled is_removed that takes as input a node and a network and returns the value one if the node is in the removed state and the value zero otherwise. This will look very close to the function above is_infected. Hint Start by copy and pasting your code above for is_infected.

#--this is the correct answer
def is_removed(n,g):
    pass

#--this is for students
def test_function_four( fun ):
    import networkx as nx
    g = nx.gnm_random_graph(10, 5)
    
    for node in g.nodes():
        g.nodes[node]["s"] = 1
        g.nodes[node]["i"] = 0
        g.nodes[node]["r"] = 0

    #--change state 
    from_s_to_i(5,g)
    from_i_to_r(3,g)
    
    if is_removed(3,g)==1 and is_removed(5,g)==0:
        print("Pass")
    else:
        print("Did not pass, keep trying!")

test_function_four( is_removed )
Did not pass, keep trying!

Function five: Count the number of infected nodes#

We need a function titled ninfected that takes as input a node and a network and returns the number of neighbors of node that are in the infected state. You will need the networkx method neighbors https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.neighbors.html and you should use the function you wrote above called is_infected to count all the nodes that are infected.

The neigbors method takes as input a node and returns an iterator of all the neighbors of that node. Iterators can be a bit tricky to work with and so we can change change an interator to a list with the list function. For example, suppose i wish to return the list of all nodes that are neighbors of the node “profm”. Then i can write

list(g.neighbors("thm220"))
['exp626',
 'oya226',
 'edj227',
 'tbp226',
 'jeo227',
 'mil426',
 'jek227',
 'sab926',
 'haw226',
 'dal226',
 'lel225',
 'rac327',
 'lom226',
 'mav228',
 'anp326',
 'abn226',
 'zcb226']
#--this is the correct answer
def ninfected(n,g):
    pass


#--this is for students to test
def test_function_five( fun ):
    import networkx as nx
    g = nx.gnm_random_graph(10, 5)
    
    for node in g.nodes():
        g.nodes[node]["s"] = 1
        g.nodes[node]["i"] = 0
        g.nodes[node]["r"] = 0

    #--choose node 5 and add neighbors until three
    num_neighbors = len(list(g.neighbors(5)))
    while num_neighbors <= 3:
            non_neighbors = set(np.arange(10)) - set(list(g.neighbors(5)))
            node_to_add = np.random.choice(list(non_neighbors))
            g.add_edge(5,node_to_add)
            
            num_neighbors = len(list(g.neighbors(5)))
    
    #--change state 
    random_neighbors = np.random.choice(list(g.neighbors(5)), 3, replace=False) 
    for node in random_neighbors:
        from_s_to_i(node,g)
    
    result = ninfected(5,g)
    
    if not isinstance(result, int):
        print("Make sure to return an integer")
    else:
        if ninfected(5,g)==3:
            print("Pass")
        else:
            print("Did not pass, keep trying!")
test_function_five(ninfected)
Make sure to return an integer

Function six: Probability of infection#

We need a function titled prob_of_infection that takes as input a node, network, and a probability of transmission p and returns the probability that at least one infected neighbor will infect node. You should use the function ninfected to compute the number of infected neighbors of node.

#--this is the correct answer
def prob_of_infection(n,g,p):
    pass


#--this is for students to test
def test_function_six( fun ):
    import networkx as nx
    g = nx.gnm_random_graph(10, 5)
    
    for node in g.nodes():
        g.nodes[node]["s"] = 1
        g.nodes[node]["i"] = 0
        g.nodes[node]["r"] = 0

    #--add some certain edges
    g.add_edge(1,5)
    g.add_edge(2,5)
    g.add_edge(9,5)
    
    #--infect
    from_s_to_i(1,g)
    from_s_to_i(9,g)
    
    #--test
    p=0.3
    
    correct_ans = 1 - (1-p)**2
    proposed_ans = prob_of_infection(5,g,p)
    
    if correct_ans == proposed_ans:
        print("Pass")
    else:
        print("Did not pass, keep trying!")
test_function_six(prob_of_infection)
Did not pass, keep trying!

Function seven: List all the infected nodes#

We will need a function titled infected_nodes that takes as input a network and returns a list of all the nodes that are in the infected state. You can create an empty list titled infected_nodes, loop through all nodes \((n)\) in a network g using the nodes() method, determine if the node \(n\) is infected and append it to the list infected_nodes. If a node is not infected than we skip that node and move to the next node in the loop.

Below is a template for how to loop through all the nodes in your network g

for node in g.nodes():
    <Code to determine if a node is infected and add to a list titled infected_nodes>
#--this is the correct answer
def infected_nodes(g):
    pass


#--this is for students to test
def test_function_seven( fun ):
    import networkx as nx
    g = nx.gnm_random_graph(10, 5)
    
    for node in g.nodes():
        g.nodes[node]["s"] = 1
        g.nodes[node]["i"] = 0
        g.nodes[node]["r"] = 0
    
    #--infect
    from_s_to_i(1,g)
    from_s_to_i(9,g)
    from_s_to_i(3,g)
    from_s_to_i(2,g)

    ans = fun(g)
    
    if not isinstance(ans, list):
        print("Must be a list")
    else:
        if set(ans) == set([1,9,3,2]):
            print("Pass")
        else:
            print("Did not pass, keep trying!")
test_function_seven(infected_nodes)
Must be a list

Function eight: Decrement recovery time#

We will need a function titled decrement_time that takes as input a node and a network and decreases the t attribute that is attached to node by one. This function does not need to return anything, it needs to decrement an attribute of our network.

#--this is the correct answer
def decrement_time(n,g):
    pass

#--this is for students to test
def test_function_eight( fun ):
    import networkx as nx
    g = nx.gnm_random_graph(10, 5)
    
    for node in g.nodes():
        g.nodes[node]["s"] = 1
        g.nodes[node]["i"] = 0
        g.nodes[node]["r"] = 0
        g.nodes[node]["t"] = 0
    
    g.nodes[5]["t"] = 1
        
    #--run function
    decrement_time(5,g)
    
    ans = g.nodes[5]["t"]
    correct_ans = 0
       
    if ans == correct_ans:
         print("Pass")
    else:
         print("Did not pass, keep trying!")
            
test_function_eight( decrement_time )
Did not pass, keep trying!

Function nine: Is time up#

We will need a function titled is_time_up that takes as input a node and a network and returns the value one if the t attribute is less than or equal to zero and returns the value zero otherwise. This function will determine if the infectious period for a node is over.

#--this is the correct answer
def is_time_up(n,g):
    pass


#--this is for students to test
def test_function_nine( fun ):
    import networkx as nx
    g = nx.gnm_random_graph(10, 5)
    
    for node in g.nodes():
        g.nodes[node]["s"] = 1
        g.nodes[node]["i"] = 0
        g.nodes[node]["r"] = 0
        g.nodes[node]["t"] = 0
    
    g.nodes[5]["t"] = 1
        
    #--run function
    decrement_time(5,g)
    
    ans = is_time_up(5,g)
    correct_ans = 1
       
    if ans == correct_ans:
         print("Pass")
    else:
         print("Did not pass, keep trying!")
            
test_function_nine( is_time_up )
Did not pass, keep trying!

Putting all of these functions to use to simulate an outbreak.#

Given a parameter p that describes the probability of an effective contact, we will simulate an outbreak on a network over T time steps. Here are all the steps that use your nine function to generate an outbreak.

#--set our parameters
T = 20
p = 0.90

#--infect one node at random. 
infect_random_node(g)

#--create an empty list that will store the number of prevalent infections over time 
infections_over_time = []

for time in range(T):
    break #--<<remove this when you start coding

    #--Create a var called previous_infections that is a list of all infected nodes
    previous_infections     = infected_nodes(g)
    
    #--Create a var called num_previous_infections that counts the number of infected nodes
    num_previous_infections = len(previous_infections)

    #--Append num_previous_infections to the list infections_over_time
    infections_over_time.append(num_previous_infections)

    #--Section: propogate infections
    for node in g.nodes():
        #--Compute the probability that node is infected (use your prob_of_infection function)
        #--MAKE SURE TO CALL THIS PROBABILITY "prob"
        prob = prob_of_infection(node,g,p)

        #--If the node is in the removed state, skip this node by using the "continue" comand
        #--Note: you will want to use your is_removed function. 
        if is_removed(node,g):
            continue
        
        #--This piece of code infects a node with probability prob
        if np.random.random() < prob:
            from_s_to_i(node,g)
            
    #--Section: Recovery
    #--Loop through all infected individuals 
    for infected in previous_infections:
        #--include code that changes the disease state of the node infected to the recovered state
        #--you will need your function titled from_i_to_r
        from_i_to_r(infected,g)

Please plot the list infections_over_time. Add an xlabel using either plt.xlabel or ax.set_xlabel with the text “Time” and a ylabel with the text “Number of infections”.

#--student will complete this plot 

FINAL STEP#

Please wrap all of the code above into a function titled simulate_outbreak. This function will take two arguments: T and p and returns infections_over_time.

def simulate_outbreak(T,p=0.1):
    #--make sure everyone is susceptible
    for node in g.nodes():
        g.nodes[node]["s"] = 1
        g.nodes[node]["i"] = 0
        g.nodes[node]["r"] = 0

    ##--assigning infectious periods to everyone if they get infected
    recovery_times = np.random.gamma(5,1, size=len(g))

    for r,node in zip(recovery_times, g.nodes()):
        g.nodes[node]["t"] = r

    #--add an infection 
    patient0 = np.random.choice(list(g.nodes()))

    g.nodes[patient0]["s"]=0
    g.nodes[patient0]["i"]=1
    g.nodes[patient0]["r"]=0

    infections_over_time = []

    for time in range(T):
        previous_infections     = infected_nodes(g)
        
        num_previous_infections = len(previous_infections)

        infections_over_time.append(num_previous_infections)

        #--propogate infections
        for node in g.nodes():
            prob = prob_of_infection(node,g,p)

            #--skip this node if it is removed
            if is_removed(node,g):
                continue
            
            #--infection!
            if np.random.random() < prob:
                from_s_to_i(node,g)

        #--remove all previous infections 
        for infected in previous_infections:
            decrement_time(infected,g)

            if is_time_up(infected,g):
                from_i_to_r(infected,g)
    return infections_over_time

The codebelow takes your simulate outbreak function, runs this 1000 times, and plots the number of prevalent cases for each iteration.

import sys
infections = []
for i in range(1000):
    break #<--remove this when you are ready
    
    sys.stdout.write("\r Simulating {:03d}".format(i))
    infections_over_time = simulate_outbreak(T=50,p=0.1)
    plt.plot(infections_over_time, color= "blue", alpha=0.1)

plt.ylabel("Incident infections")
plt.xlabel("time")
plt.show()
_images/e45ecfb0f1b9c300b380ce843b7452e567b29e0d3b0a90cd443b104543268847.png