Sick Gaming
[Tut] Creating Beautiful Heatmaps with Seaborn - Printable Version




[Tut] Creating Beautiful Heatmaps with Seaborn - xSicKxBot - 12-26-2020

Creating Beautiful Heatmaps with Seaborn

<div><p>Heatmaps are a type of plot that combines color schemes with numerical values to represent complex, multi-dimensional datasets. They are widely used in data science applications that involve large amounts of numerical data, in fields such as biology, economics and medicine. </p>
<p>In this video we will see how to create a heatmap that represents the total number of COVID-19 cases in the different US states on different days. To achieve this, we will use <em>Seaborn</em>, a Python package that provides many powerful functions for plotting data.</p>
<figure class="wp-block-embed is-type-video is-provider-youtube wp-block-embed-youtube wp-embed-aspect-16-9 wp-has-aspect-ratio">
<div class="wp-block-embed__wrapper">
<div class="ast-oembed-container"><iframe title="Creating Beautiful Heatmaps with Seaborn" width="1400" height="788" src="https://www.youtube.com/embed/lQwLsa0WY8A?feature=oembed" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe></div>
</div>
</figure>
<p>Here’s the code to be discussed:</p>
<pre class="EnlighterJSRAW" data-enlighter-language="generic" data-enlighter-theme="" data-enlighter-highlight="" data-enlighter-linenumbers="" data-enlighter-lineoffset="" data-enlighter-title="" data-enlighter-group="">import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# url of the .csv file
url = r"path of the .csv file"

# import the .csv file into a pandas DataFrame
df = pd.read_csv(url, sep = ';', thousands = ',')

# defining the array containing the states present in the study
states = np.array(df['state'].drop_duplicates())[:40]

# extracting the total cases for each day and each state
overall_cases = []
for state in states:
    tot_cases = []
    for i in range(len(df['state'])):
        if df['state'][i] == state:
            tot_cases.append(df['tot_cases'][i])
    overall_cases.append(tot_cases[:30])

data = pd.DataFrame(overall_cases).T
data.columns = states

# Plotting
fig = plt.figure()
ax = fig.subplots()
ax = sns.heatmap(data, annot = True, fmt="d", linewidths=0, cmap = 'viridis', xticklabels = True)
ax.invert_yaxis()
ax.set_xlabel('States')
ax.set_ylabel('Day n°')
plt.show()
</pre>
<p>Let’s dive into the code to learn Seaborn’s heatmap functionality in a step-by-step manner. </p>
<h2>Importing the required libraries for this example</h2>
<p>We start our script by importing the libraries required for running this example, namely <a href="https://blog.finxter.com/numpy-tutorial/" target="_blank" rel="noreferrer noopener" title="NumPy Tutorial – Everything You Need to Know to Get Started">NumPy</a>, <a href="https://blog.finxter.com/pandas-quickstart/" target="_blank" rel="noreferrer noopener" title="10 Minutes to Pandas (in 5 Minutes)">Pandas</a>, <a href="https://blog.finxter.com/matplotlib-line-plot/" target="_blank" rel="noreferrer noopener" title="Matplotlib Line Plot – A Helpful Illustrated Guide">Matplotlib</a> and Seaborn.</p>
<pre class="EnlighterJSRAW" data-enlighter-language="generic" data-enlighter-theme="" data-enlighter-highlight="" data-enlighter-linenumbers="" data-enlighter-lineoffset="" data-enlighter-title="" data-enlighter-group="">import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
</pre>
<h2>What’s in the data?</h2>
<p>As mentioned in the introduction, we will use the COVID-19 data that were also used in the article about the <a href="https://blog.finxter.com/exponential-fit-with-scipys-curve_fit/" target="_blank" rel="noreferrer noopener" title="https://blog.finxter.com/exponential-fit-with-scipys-curve_fit/"><em>Scipy.curve_fit()</em></a> function. The data have been downloaded as a .csv file from the official website of the <a href="https://data.cdc.gov/Case-Surveillance/United-States-COVID-19-Cases-and-Deaths-by-State-o/9mfq-cb36" target="_blank" rel="noreferrer noopener" title="https://data.cdc.gov/Case-Surveillance/United-States-COVID-19-Cases-and-Deaths-by-State-o/9mfq-cb36">“Centers for Disease Control and Prevention”</a>. </p>
<p>The file reports several pieces of information about the COVID-19 pandemic in the different US states, such as the total number of cases, the number of new cases and the number of deaths; each of them has been recorded every day for multiple US states. </p>
<p>We will generate a heatmap in which each cell displays the total number of cases recorded on a particular day in a particular US state. To do that, the first step is to import the .csv file and store it in a Pandas DataFrame.</p>
<h2>Importing the data with Pandas</h2>
<p>The data are stored in a .csv file; the values are separated by semicolons, while the thousands separator is a comma. To import the .csv file into our Python script, we use the Pandas function <em>.read_csv()</em>, which accepts the path of the file as input and converts its content into a <a href="https://blog.finxter.com/how-to-create-a-dataframe-in-pandas/" title="How to Create a DataFrame in Pandas?" target="_blank" rel="noreferrer noopener">Pandas DataFrame</a>. </p>
<p>It is important to note that, when calling <em>.read_csv(), </em>we specify the separator, which in our case is “;” by saying “sep = ‘;’” and the symbol used for denoting the thousands, by writing “thousands = ‘,’”. All these things are contained in the following code lines:</p>
<pre class="EnlighterJSRAW" data-enlighter-language="generic" data-enlighter-theme="" data-enlighter-highlight="" data-enlighter-linenumbers="" data-enlighter-lineoffset="" data-enlighter-title="" data-enlighter-group="">#url of the .csv file
url = r"path of the file"

# import the .csv file into a pandas DataFrame
df = pd.read_csv(url, sep = ';', thousands = ',')
</pre>
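<p>The effect of the two <em>.read_csv()</em> options can be checked on a tiny in-memory file. The sketch below is only an illustration: the three rows and the “submission_date” column are made-up sample values, not the real CDC file, and <em>io.StringIO</em> merely stands in for the file path.</p>

```python
import io
import pandas as pd

# Hypothetical three-row sample imitating the layout described above:
# semicolon-separated fields, comma used as the thousands separator.
sample_csv = io.StringIO(
    "submission_date;state;tot_cases\n"
    "03/01/2020;NY;1,250\n"
    "03/01/2020;FL;980\n"
    "03/02/2020;NY;2,400\n"
)

df = pd.read_csv(sample_csv, sep=';', thousands=',')

print(df.columns.tolist())       # names found in the header row
print(df['tot_cases'].tolist())  # "1,250" is parsed as the integer 1250
```

<p>Without <em>thousands = ','</em>, values such as “1,250” would be kept as text, and they could not be used as numbers in the heatmap.</p>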
<h2>Creating the arrays that will be used in the heatmap</h2>
<p>At this point, we have to edit the created DataFrame in order to extract just the information that will be used for the creation of the heatmap. </p>
<p>The first values that we extract are the names of the states in which the data have been recorded. To identify all the columns that make up the DataFrame, we can type “df.columns” to print out the header of the file. Among the columns present in the header, the one we are interested in is “state”, which contains the names of all the states involved in this chart. </p>
<p>Since the data are recorded on a daily basis, each row corresponds to the data collected on a single day in a specific state; as a result, the names of the states are repeated along this column. Since we do not want any repetition in our heatmap, we have to remove the duplicates from the array. </p>
<p>We proceed by defining a <a href="https://blog.finxter.com/how-to-convert-a-list-to-a-numpy-array/" target="_blank" rel="noreferrer noopener" title="How to Convert a List to a NumPy Array?">NumPy array</a> called “states” in which we store the values found in the “state” column of the DataFrame; in the same line of code, we apply the method <em>.drop_duplicates()</em> to that column to remove any duplicates. Since there are 60 distinct values in the “state” column, we limit our analysis to the first 40, so that the labels along the x-axis of the heatmap still fit in the limited window space.</p>
<pre class="EnlighterJSRAW" data-enlighter-language="generic" data-enlighter-theme="" data-enlighter-highlight="" data-enlighter-linenumbers="" data-enlighter-lineoffset="" data-enlighter-title="" data-enlighter-group="">#defining the array containing the states present in the study
states = np.array(df['state'].drop_duplicates())[:40] </pre>
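<p>As a small illustration of this step, the sketch below applies the same line to a made-up “state” column (the six values are hypothetical) and keeps only the first two distinct entries instead of 40. It also shows <em>Series.unique()</em>, which is an alternative not used in the article.</p>

```python
import numpy as np
import pandas as pd

# Hypothetical column with repeated state codes, one entry per daily record.
df = pd.DataFrame({'state': ['CO', 'FL', 'CO', 'AZ', 'FL', 'AZ']})

# Same pattern as in the article: drop the repetitions, then slice.
states = np.array(df['state'].drop_duplicates())[:2]
print(states)

# Alternative: Series.unique() also keeps the order of first appearance.
same_states = df['state'].unique()[:2]
print(same_states)
```

<p>In both cases the order of first appearance is preserved, so the columns of the heatmap follow the order in which the states occur in the file.</p>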
<p>The next step is to extract the number of total cases recorded on each day in each state. To do that, we use two nested for loops: for every state present in the “states” array, they build a list containing the number of total cases (one integer per day) and append it to another list called “overall_cases”, which needs to be defined before entering the <a href="https://blog.finxter.com/python-loops/" target="_blank" rel="noreferrer noopener" title="Python Loops">for loop</a>.</p>
<pre class="EnlighterJSRAW" data-enlighter-language="generic" data-enlighter-theme="" data-enlighter-highlight="" data-enlighter-linenumbers="" data-enlighter-lineoffset="" data-enlighter-title="" data-enlighter-group="">#extracting the total cases for each day and each country
overall_cases = []
</pre>
<p>As you can see in the following code, in the first for loop we iterate over the states that were previously stored in the “states” array; for each state, we define an empty list called “tot_cases” to which we will append the total cases recorded on each day.</p>
<pre class="EnlighterJSRAW" data-enlighter-language="generic" data-enlighter-theme="" data-enlighter-highlight="" data-enlighter-linenumbers="" data-enlighter-lineoffset="" data-enlighter-title="" data-enlighter-group="">for state in states:
    tot_cases = []
</pre>
<p>Once we are within the first for loop (meaning that we are dealing with a single state), we start a second for loop that scans every row of the DataFrame in order to find the total cases values stored for that particular state. This second for loop starts from index 0 and runs through all the entries of the “state” column of our DataFrame; we achieve this with the functions <em>range</em> and <em>len</em>.</p>
<pre class="EnlighterJSRAW" data-enlighter-language="generic" data-enlighter-theme="" data-enlighter-highlight="" data-enlighter-linenumbers="" data-enlighter-lineoffset="" data-enlighter-title="" data-enlighter-group="">    for i in range(len(df['state'])):</pre>
<p>Within this second <code>for</code> loop, we want to append to the <a href="https://blog.finxter.com/python-lists/" target="_blank" rel="noreferrer noopener" title="The Ultimate Guide to Python Lists">list</a> “tot_cases” only the values that belong to the state we are currently interested in (i.e. the one selected by the first for loop and identified by the value of the variable “state”); we do this with the following <em><a href="https://blog.finxter.com/if-then-else-in-one-line-python/" target="_blank" rel="noreferrer noopener" title="If-Then-Else in One Line Python [Video + Interactive Code Shell]">if statement</a>:</em></p>
<pre class="EnlighterJSRAW" data-enlighter-language="generic" data-enlighter-theme="" data-enlighter-highlight="" data-enlighter-linenumbers="" data-enlighter-lineoffset="" data-enlighter-title="" data-enlighter-group="">        if df['state'][i] == state:
            tot_cases.append(df['tot_cases'][i])
</pre>
<p>When we have finished appending the total cases for each day of a particular state to the “tot_cases” list, we exit the inner for loop and append this list to “overall_cases”, which thus becomes a list of lists. Here too we limit our analysis, this time to the first 30 days; otherwise we would not have enough space in our heatmap for all the 286 daily values present in the DataFrame.</p>
<pre class="EnlighterJSRAW" data-enlighter-language="generic" data-enlighter-theme="" data-enlighter-highlight="" data-enlighter-linenumbers="" data-enlighter-lineoffset="" data-enlighter-title="" data-enlighter-group="">    overall_cases.append(tot_cases[:30])</pre>
<p>In the next iteration, the code moves on to the second element of the “states” array, i.e. another state: it initializes an empty “tot_cases” list, enters the second <em>for loop</em> to append all the values recorded for that state on the different days and, once finished, appends the entire list to “overall_cases”. This procedure is repeated for all the states stored in the “states” array. At the end, we will have extracted all the values needed for generating our heatmap.&nbsp;</p>
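<p>The nested loops above can be sketched on a few made-up records, next to an equivalent selection with a boolean mask. The mask version is an alternative, not the method used in the article; the six rows and the name “n_days” (standing in for the 30-day limit) are assumptions made for this illustration.</p>

```python
import pandas as pd

# Hypothetical long-format records: one row per state per day.
df = pd.DataFrame({
    'state':     ['CO', 'FL', 'CO', 'FL', 'CO', 'FL'],
    'tot_cases': [0, 1, 2, 3, 5, 8],
})
states = df['state'].drop_duplicates().to_numpy()
n_days = 2  # stands in for the 30-day limit used in the article

# Nested-loop version, as in the article.
overall_cases = []
for state in states:
    tot_cases = []
    for i in range(len(df['state'])):
        if df['state'][i] == state:
            tot_cases.append(df['tot_cases'][i])
    overall_cases.append(tot_cases[:n_days])

# Boolean-mask version: pandas selects the rows of each state directly,
# so no inner Python loop over every row is needed.
overall_cases_mask = [
    df.loc[df['state'] == state, 'tot_cases'].head(n_days).tolist()
    for state in states
]
print(overall_cases_mask)
```

<p>Both versions produce one list per state; the mask version simply lets pandas do the row filtering, which tends to be faster on large files.</p>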
<h2>Creating the DataFrame for the heatmap</h2>
<p>As already introduced in the first part, we exploit the Seaborn function .<em>heatmap() </em>to generate our heatmap. </p>
<p>This function can take as input a pandas DataFrame that contains the rows, the columns and the values of each cell that we want to display in our plot. We hence generate a new pandas DataFrame (we call it “data”) from the values stored in the list “overall_cases”; before transposing, each row of this DataFrame refers to a specific state and each column to a specific day. </p>
<p>We then transpose this DataFrame by adding “.T” at the end of the line, so that each column corresponds to a state and we can insert the names of the states as the header of our DataFrame.</p>
<pre class="EnlighterJSRAW" data-enlighter-language="generic" data-enlighter-theme="" data-enlighter-highlight="" data-enlighter-linenumbers="" data-enlighter-lineoffset="" data-enlighter-title="" data-enlighter-group="">data = pd.DataFrame(overall_cases).T</pre>
<p>The names of the states were previously stored in the array “states”, so we can set the header of the DataFrame with the following code:</p>
<pre class="EnlighterJSRAW" data-enlighter-language="generic" data-enlighter-theme="" data-enlighter-highlight="" data-enlighter-linenumbers="" data-enlighter-lineoffset="" data-enlighter-title="" data-enlighter-group="">data.columns = states</pre>
<p>The DataFrame that will be used for generating the heatmap will have the following shape:</p>
<pre class="wp-block-preformatted">   CO  FL  AZ  SC  CT  NE  KY  WY  IA  ...  LA  ID  NV  GA  IN  AR  MD  NY  OR
0   0   0   0   0   0   0   0   0   0  ...   0   0   0   0   0   0   0   0   0
1   0   0   0   0   0   0   0   0   0  ...   0   0   0   0   0   0   0   0   0
2   0   0   0   0   0   0   0   0   0  ...   0   0   0   0   0   0   0   0   0
3   0   0   0   0   0   0   0   0   0  ...   0   0   0   0   0   0   0   0   0
4   0   0   1   0   0   0   0   0   0  ...   0   0   0   0   0   0   0   0   0
</pre>
<p>The row indexes represent the number of the day on which the data were recorded, while the column headers are the names of the states.</p>
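<p>The construction of the DataFrame can be followed on a small list of lists. The values below are hypothetical (two states, three days each); the intermediate name “by_state” is introduced only for this sketch.</p>

```python
import pandas as pd

# Hypothetical values: two states, three days each.
states = ['CO', 'FL']
overall_cases = [[0, 2, 5], [1, 3, 8]]

by_state = pd.DataFrame(overall_cases)  # rows = states, columns = days
data = by_state.T                       # rows = days, columns = states
data.columns = states

print(by_state.shape, data.shape)
print(data)

# If one state had fewer recorded days, pandas would pad the gap with NaN.
ragged = pd.DataFrame([[0, 2, 5], [1, 3]]).T
print(ragged.isna().sum().sum())
```

<p>The last lines are a side note: a NaN turns the affected values into floats, and Python's “d” format code does not accept floats, so it seems safer to make sure every state contributes the same number of days before annotating the heatmap with <em>fmt="d"</em>.</p>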
<h2>Generating the heatmap</h2>
<p>After generating the usual plot window with the typical matplotlib functions, we call the Seaborn function .<em>heatmap() </em>to generate the heatmap. </p>
<p>The mandatory input of this function is the pandas DataFrame that we created in the previous section. There are then multiple optional input parameters that can improve our heatmap:</p>
<ul>
<li><em>linewidths </em>adds a white contour around each cell to better separate the cells; we just have to specify its width; </li>
<li><em>xticklabels </em>controls the labels along the x-axis; if it is equal to True, all the column names of the DataFrame are displayed. </li>
<li>We can also choose the colormap of the heatmap by using <em>cmap </em>and specifying the name of an available colormap (“viridis” or “magma” are very fancy, but the Seaborn default one is really cool too); </li>
<li>finally, it is possible to display the numerical value of each cell by using the option <em>annot = True</em>; the numerical value will be displayed at the center of each cell, formatted according to <em>fmt</em> (“d” stands for integers). </li>
</ul>
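<p>The options listed above can be tried on a very small DataFrame. This is a minimal sketch under a few assumptions: the six values are made up, the “Agg” backend is selected only so that the code runs without a display, and the particular values chosen for <em>linewidths</em> and <em>cmap</em> are illustrative.</p>

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window needed
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical integer data: rows = days, columns = states.
data = pd.DataFrame({'CO': [0, 2, 5], 'FL': [1, 3, 8]})

fig, ax = plt.subplots()
sns.heatmap(
    data,
    annot=True,        # write each cell's value at its center
    fmt="d",           # format the annotations as integers
    linewidths=0.5,    # thin separator lines between the cells
    cmap="magma",      # name of an available colormap
    xticklabels=True,  # label every column
    ax=ax,
)

# The annotations are stored as text objects on the Axes.
labels = [t.get_text() for t in ax.texts]
print(labels)
plt.close(fig)
```

<p>Setting <em>annot=False</em> leaves the cells without text, which is handy when there are too many cells for the numbers to stay readable.</p>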
<p>The following lines contain the code for plotting the heatmap. One final observation regards the command .invert_yaxis(); since we plot the heatmap directly from a pandas DataFrame, the row index will be the “day n°”; hence it will start from 0 and increase as we go down along the rows. By adding .invert_yaxis() we reverse the y-axis, having day 0 at the bottom part of the heatmap.</p>
<pre class="EnlighterJSRAW" data-enlighter-language="generic" data-enlighter-theme="" data-enlighter-highlight="" data-enlighter-linenumbers="" data-enlighter-lineoffset="" data-enlighter-title="" data-enlighter-group="">#Plotting
fig = plt.figure()
ax = fig.subplots()
ax = sns.heatmap(data, annot = True, fmt="d", linewidths=0, cmap = 'viridis', xticklabels = True)
ax.invert_yaxis()
ax.set_xlabel('States')
ax.set_ylabel('Day n°')
plt.show() </pre>
<p>Figure 1 displays the heatmap obtained by this code snippet.</p>
<div class="wp-block-image">
<figure class="aligncenter size-large"><img loading="lazy" width="486" height="298" src="https://blog.finxter.com/wp-content/uploads/2020/12/image-100.png" alt="" class="wp-image-19577" srcset="https://blog.finxter.com/wp-content/uploads/2020/12/image-100.png 486w, https://blog.finxter.com/wp-content/uploads/2020/12/image-100-300x184.png 300w, https://blog.finxter.com/wp-content/uploads/2020/12/image-100-150x92.png 150w" sizes="(max-width: 486px) 100vw, 486px" /></figure>
</div>
<p><strong><em>Figure 1:</em></strong> Heatmap representing the total number of COVID-19 cases for the first 30 days of measurement (y-axis) in the different US states (x-axis).</p>
<p>As you can see in Figure 1, there are a lot of zeroes; this is because we decided to plot the data related to the first 30 days of measurement, in which the number of recorded cases was very low. If we decided to plot the results from all the days of measurement (from day 0 to 286), we would obtain the result displayed in Figure 2 (in this latter case, we set <em>annot</em> to False, since the numbers would have been too large for the cell size):</p>
<div class="wp-block-image">
<figure class="aligncenter size-large"><img loading="lazy" width="507" height="299" src="https://blog.finxter.com/wp-content/uploads/2020/12/image-102.png" alt="" class="wp-image-19580" srcset="https://blog.finxter.com/wp-content/uploads/2020/12/image-102.png 507w, https://blog.finxter.com/wp-content/uploads/2020/12/image-102-300x177.png 300w, https://blog.finxter.com/wp-content/uploads/2020/12/image-102-150x88.png 150w" sizes="(max-width: 507px) 100vw, 507px" /></figure>
</div>
<p><strong><em>Figure 2:</em></strong> Heatmap representing the total number of COVID-19 cases for the first 286 days of measurement (y-axis) in the different US states (x-axis); this time <em>annot = False</em>, since the cells are too small to accommodate the number of total cases (which becomes very large towards the upper part of the heatmap).</p>
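<p>A plot in the spirit of Figure 2 can be sketched with synthetic numbers. The data below are randomly generated cumulative counts for five made-up columns (“S1” to “S5”), not the CDC values, and the “Agg” backend is again an assumption that lets the code run without a display.</p>

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic cumulative counts (NOT the CDC data): 286 days x 5 made-up states.
rng = np.random.default_rng(0)
daily = rng.integers(0, 50, size=(286, 5))
data = pd.DataFrame(daily.cumsum(axis=0),
                    columns=['S1', 'S2', 'S3', 'S4', 'S5'])

fig, ax = plt.subplots()
# annot=False: with 286 rows the cells are too small to hold readable numbers.
sns.heatmap(data, annot=False, cmap='viridis', xticklabels=True, ax=ax)
ax.invert_yaxis()  # put day 0 at the bottom, as in the article
ax.set_xlabel('States')
ax.set_ylabel('Day n°')
plt.close(fig)
```

<p>Because the counts are cumulative, the colors can only move towards the bright end of the colormap as the day number grows, which matches the gradient described in the caption of Figure 2.</p>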
<p>The post <a href="https://blog.finxter.com/heatmaps-with-seaborn/" target="_blank" rel="noopener noreferrer">Creating Beautiful Heatmaps with Seaborn</a> first appeared on <a href="https://blog.finxter.com/" target="_blank" rel="noopener noreferrer">Finxter</a>.</p>
</div>


https://www.sickgaming.net/blog/2020/12/26/creating-beautiful-heatmaps-with-seaborn/