# Source code for mds_2025_helper_functions.eda
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
[docs]
def perform_eda(dataframe: pd.DataFrame, rows: int = 5, cols: int = 2) -> None:
    """
    Print summary reports and draw exploratory plots for a DataFrame.

    The function runs these steps in order:

    1. Dataset overview (``DataFrame.info``).
    2. Basic statistics for all columns (``DataFrame.describe``).
    3. Missing-values report, plus a heatmap when any value is missing.
    4. Correlation heatmap of the numeric columns (needs at least two).
    5. One plot per column: histogram for numeric columns, line plot of
       value counts for datetime columns, count plot for everything else.
    6. Scatterplots for every pair of numeric columns (needs at least two).
    7. IQR-based outlier counts for each numeric column.

    Parameters:
        dataframe (pd.DataFrame): The input dataset for EDA.
        rows (int): Minimum number of rows in the per-feature plot grid.
            More rows are added automatically when the DataFrame has more
            than ``rows * cols`` columns.
        cols (int): Number of columns in the plot grids.

    Returns:
        None. Reports are printed and figures are shown with ``plt.show()``.

    Raises:
        TypeError: If ``dataframe`` is not a pandas DataFrame.
        ValueError: If ``rows`` or ``cols`` is smaller than 1.

    Example:
        >>> import pandas as pd
        >>> from mds_2025_helper_functions.eda import perform_eda
        >>> data = {
        ...     'Age': [25, 32, 47, 51, 62],
        ...     'Salary': [50000, 60000, 120000, 90000, 85000],
        ...     'Department': ['HR', 'Finance', 'IT', 'Finance', 'HR'],
        ...     'JoiningDate': pd.to_datetime(['2015-01-01', '2016-07-15', '2017-03-12', '2018-06-01', '2019-08-19']),
        ...     'Bonus': [0, 5000, 12000, 7500, 7000]
        ... }
        >>> df = pd.DataFrame(data)
        >>> perform_eda(df, rows=2, cols=2)  # doctest: +SKIP
    """
    if not isinstance(dataframe, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame.")
    # Fail early with the same exception type matplotlib would raise later
    # for an impossible grid, before any report has been printed.
    if rows < 1 or cols < 1:
        raise ValueError("rows and cols must be positive integers.")

    print("===== Dataset Overview =====")
    # DataFrame.info() writes to stdout itself and returns None, so wrapping
    # it in print() would emit a stray "None" line.
    dataframe.info()

    # describe(include='all') raises ValueError on a frame without columns,
    # and there is nothing to plot either, so stop here.
    if dataframe.columns.empty:
        print("\nDataFrame has no columns; nothing further to analyze.")
        return

    print("\n===== Basic Statistics =====")
    print(dataframe.describe(include='all').transpose())

    _report_missing_values(dataframe)

    numeric_cols = dataframe.select_dtypes(include=[np.number]).columns
    _plot_correlation_heatmap(dataframe, numeric_cols)
    _plot_feature_grid(dataframe, numeric_cols, rows, cols)
    _plot_numeric_scatter_pairs(dataframe, numeric_cols, cols)
    _report_outliers(dataframe, numeric_cols)


def _report_missing_values(dataframe):
    """Print per-column missing-value counts and show a heatmap if any exist."""
    print("\n===== Missing Values Report =====")
    null_mask = dataframe.isnull()
    missing_values = null_mask.sum()
    if not missing_values.any():
        print("No missing values in the dataset.")
        return

    print(missing_values[missing_values > 0])
    plt.figure(figsize=(10, 6))
    sns.heatmap(null_mask, cbar=False, cmap="viridis")
    plt.title("Missing Values Heatmap")
    plt.show()


def _plot_correlation_heatmap(dataframe, numeric_cols):
    """Show a lower-triangle correlation heatmap of the numeric columns."""
    if len(numeric_cols) <= 1:
        print("Not enough numeric columns for correlation heatmap.")
        return

    corr = dataframe[numeric_cols].corr()
    # Mask the diagonal and upper triangle: the matrix is symmetric, so
    # those cells only repeat information.
    mask = np.triu(np.ones_like(corr, dtype=bool))
    plt.figure(figsize=(12, 10))
    sns.heatmap(corr, mask=mask, annot=True, fmt=".2f", cmap="coolwarm", square=True)
    plt.title("Correlation Heatmap")
    plt.show()


def _plot_feature_grid(dataframe, numeric_cols, rows, cols):
    """Draw one distribution/count plot per column on a shared grid."""
    print("\n===== Feature Visualizations =====")
    total_features = len(dataframe.columns)
    # Honour the requested row count but grow the grid when it would be too
    # small to give every column its own axes (ceil division).
    grid_rows = max(rows, -(-total_features // cols))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so
    # ravel() is always available.
    fig, axes = plt.subplots(
        grid_rows,
        cols,
        figsize=(cols * 6, grid_rows * 4),
        tight_layout=True,
        squeeze=False,
    )
    axes = axes.ravel()

    for ax, feature in zip(axes, dataframe.columns):
        series = dataframe[feature]
        # Use the same notion of "numeric" as the correlation, scatter and
        # outlier steps instead of an int64/float64-only dtype check.
        if feature in numeric_cols:
            sns.histplot(series, kde=True, bins=20, ax=ax)
            ax.set_title(f"Distribution of {feature}")
            ax.set_xlabel(feature)
            ax.set_ylabel("Frequency")
        elif pd.api.types.is_datetime64_any_dtype(series):
            series.value_counts().sort_index().plot(kind="line", marker="o", ax=ax)
            ax.set_title(f"Time Series of {feature}")
            ax.set_xlabel(feature)
            ax.set_ylabel("Count")
        else:
            order = series.value_counts().index.tolist()
            # seaborn deprecates `palette` without `hue`; assigning the x
            # variable to `hue` with legend=False is its documented
            # replacement. hue_order keeps colours in bar order.
            sns.countplot(
                x=series,
                hue=series,
                order=order,
                hue_order=order,
                palette="viridis",
                legend=False,
                ax=ax,
            )
            ax.tick_params(axis='x', rotation=45)
            ax.set_title(f"Count Plot for {feature}")
            ax.set_xlabel(feature)
            ax.set_ylabel("Count")

    # Remove the axes that did not receive a feature.
    for unused_ax in axes[total_features:]:
        fig.delaxes(unused_ax)
    plt.show()


def _plot_numeric_scatter_pairs(dataframe, numeric_cols, cols):
    """Draw a scatterplot for every unordered pair of numeric columns."""
    print("\n===== Scatterplots for Numeric Features =====")
    if len(numeric_cols) <= 1:
        print("Not enough numeric columns for scatterplots.")
        return

    pairs = [
        (first, second)
        for idx, first in enumerate(numeric_cols)
        for second in numeric_cols[idx + 1:]
    ]
    grid_rows = -(-len(pairs) // cols)  # ceil division
    # squeeze=False: a single pair with cols=1 would otherwise yield a bare
    # Axes object without ravel().
    fig, axes = plt.subplots(
        grid_rows,
        cols,
        figsize=(cols * 6, grid_rows * 4),
        tight_layout=True,
        squeeze=False,
    )
    axes = axes.ravel()

    for ax, (col_x, col_y) in zip(axes, pairs):
        sns.scatterplot(x=dataframe[col_x], y=dataframe[col_y], ax=ax, alpha=0.7)
        ax.set_title(f"{col_x} vs {col_y}")
        ax.set_xlabel(col_x)
        ax.set_ylabel(col_y)

    # Remove the trailing axes of the last, partially filled row.
    for unused_ax in axes[len(pairs):]:
        fig.delaxes(unused_ax)
    plt.show()


def _report_outliers(dataframe, numeric_cols):
    """Print, per numeric column, how many values fall outside 1.5 * IQR."""
    print("\n===== Outliers Report =====")
    for col in numeric_cols:
        values = dataframe[col]
        q1 = values.quantile(0.25)
        q3 = values.quantile(0.75)
        iqr = q3 - q1
        is_outlier = (values < q1 - 1.5 * iqr) | (values > q3 + 1.5 * iqr)
        print(f"{col}: {int(is_outlier.sum())} potential outliers")