An interactive ML tool for predicting heart disease

In a previous blog post ("Modeling the UCI Heart Disease dataset"), I trained a model to predict the presence of heart disease. So I have a model — now what?

Machine learning models like this can be put to work generating predictions on new inputs, and they’re great for simulations as well. Let’s say we want to know the likelihood of heart disease for a 60-year-old male with a cholesterol value of 244 and a resting blood pressure of 88, or for a 44-year-old female reporting atypical chest pain. We might also want to know how incremental changes in cholesterol affect the likelihood of heart disease. Maybe I’m interested in lowering my heart disease risk, and I want to know specifically which areas or biomarkers to focus on improving.

For this post, I created a live web tool that accepts user input for any of the features the model was trained on. Using the trained model, the tool generates a predicted probability of heart disease. As the inputs change, the probability updates in real time.

Try it out live by clicking the “launch binder” button.

Binder

The code follows.

Initialize the model

First, we reload the data and re-fit the model that we trained in a previous blog post.

# Imports
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import make_column_transformer
from sklearn.linear_model import LogisticRegression

# Load the raw dataset from the working directory
df = pd.read_csv('heart.csv')

# Feature groups, keyed by how each kind of column is meant to be encoded
bins = ['sex', 'fbs', 'exang', 'thal', 'restecg']           # binary
cats = ['cp', 'slope']                                       # nominal
ords = ['ca']                                                # ordinal
nums = ['age', 'oldpeak', 'trestbps', 'chol', 'thalach']     # continuous
target = ['target']

# Swap the 0/1 target labels and turn the integer codes of the categorical
# columns into readable labels (each replace uses the column's original values)
df = df.assign(
    target=df['target'].replace({0: 1, 1: 0}),
    cp=df['cp'].replace({0: 'Asympt.', 1: 'Atypical', 2: 'Non', 3: 'Typical'}),
    restecg=df['restecg'].replace({0: 'LV hyper', 1: 'Normal', 2: 'ST-T wave'}),
    slope=df['slope'].replace({0: 'down', 1: 'up', 2: 'flat'}),
    thal=df['thal'].replace({0: 'NA', 1: 'Fixed', 2: 'Normal', 3: 'Revers.'}),
)

# Merge sparse levels so restecg and thal each become Normal / Abnormal
df = df.assign(
    restecg=df['restecg'].replace(
        {'Normal': 'Normal', 'LV hyper': 'Abnormal', 'ST-T wave': 'Abnormal'}),
    thal=df['thal'].replace(
        {'NA': 'Normal', 'Normal': 'Normal', 'Fixed': 'Abnormal', 'Revers.': 'Abnormal'}),
)

# Train-test split: hold out a stratified 20% test set with a fixed seed.
# NOTE: the whole frame (including the 'target' column) is passed as X.
X_train, X_test, y_train, y_test = train_test_split(
    df, df['target'], test_size=0.2, random_state=42, stratify=df['target'],
)

# Feature encoding
clt = make_column_transformer(
    (StandardScaler(), nums),
    (OneHotEncoder(), cats)
)

clt.fit(X_train)
X_train_transformed = clt.transform(X_train)
X_test_transformed = clt.transform(X_test)

# Fit model
lr = LogisticRegression()
lr.fit(X_train_transformed, y_train)
# NOTE(review): these lines appear to be the notebook's displayed output of
# `lr.fit(...)` above (the fitted estimator's repr), not code. If executed,
# they only construct an unused, unfitted LogisticRegression.
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_iter=100,
                   multi_class='auto', n_jobs=None, penalty='l2',
                   random_state=None, solver='lbfgs', tol=0.0001, verbose=0,
                   warm_start=False)

Interactive model predictions

Then, using ipywidgets, we build a user interface that collects the inputs, encodes them with the fitted column transformer (clt.transform()), and passes them to lr.predict_proba() to estimate the probability of heart disease.

# Initialize a one-row dataframe of typical default inputs for the widgets:
# the training-set median for continuous features, and the training-set mode
# (most frequent value; mode()[0] picks the smallest on ties) for binary,
# nominal and ordinal features.
init = {
    'age': [X_train['age'].median()],
    'sex': [X_train['sex'].mode()[0]],
    'cp': [X_train['cp'].mode()[0]],
    'trestbps': [X_train['trestbps'].median()],
    'chol': [X_train['chol'].median()],
    'fbs': [X_train['fbs'].mode()[0]],
    'restecg': [X_train['restecg'].mode()[0]],
    # Max heart rate is continuous like the other `nums`, so use the median;
    # the mode of a continuous column is an arbitrary, tie-prone value.
    'thalach': [X_train['thalach'].median()],
    'exang': [X_train['exang'].mode()[0]],
    'oldpeak': [X_train['oldpeak'].median()],
    'slope': [X_train['slope'].mode()[0]],
    'ca': [X_train['ca'].mode()[0]],
    'thal': [X_train['thal'].mode()[0]],
    # Placeholder mirroring the 'target' column that X_train (the column
    # transformer's fitting input) contains; it is not a feature.
    'target': [np.nan],
}
init_df = pd.DataFrame(data=init)
# Define the user widgets

import ipywidgets as widgets

# One input widget per model feature, each starting at its default from the
# one-row `init_df`. Every default is read as a scalar (`[0]`). Passing the
# one-element Series itself (as the sliders previously did) relies on
# int()/float() of a Series, which recent pandas deprecates.
# Dropdown option values match the encodings used in the training data.
age = widgets.IntSlider(
    value=init_df['age'][0],
    description='Age:'
)
sex = widgets.Dropdown(
    options=[('Female', 0), ('Male', 1)],
    value=init_df['sex'][0],
    description='Sex:'
)
cp = widgets.Dropdown(
    options=['Asympt.', 'Atypical', 'Non', 'Typical'],
    value=init_df['cp'][0],
    description='Chest pain:'
)
trestbps = widgets.FloatSlider(
    value=init_df['trestbps'][0],
    description='Resting BP:',
    min=60,
    max=240,
    step=0.5,
)
chol = widgets.IntSlider(
    value=init_df['chol'][0],
    description='Cholesterol:',
    min=50,
    max=600
)
fbs = widgets.Dropdown(
    options=[('<= 120 mg/dl', 0), ('> 120 mg/dl', 1)],
    value=init_df['fbs'][0],
    description='Fasting BS:'
)
restecg = widgets.Dropdown(
    options=['Normal', 'Abnormal'],
    value=init_df['restecg'][0],
    description='Resting ECG:'
)
thalach = widgets.IntSlider(
    value=init_df['thalach'][0],
    description='Max HR:',
    min=60,
    max=220
)
exang = widgets.Dropdown(
    options=[('No', 0), ('Yes', 1)],
    value=init_df['exang'][0],
    description='Ex. angina:'
)
oldpeak = widgets.FloatSlider(
    value=init_df['oldpeak'][0],
    description='Old peak:',
    min=0,
    max=10,
    step=0.1,
)
slope = widgets.Dropdown(
    options=['down', 'up', 'flat'],
    value=init_df['slope'][0],
    description='ST slope:'
)
ca = widgets.Dropdown(
    options=[0, 1, 2, 3],
    value=init_df['ca'][0],
    description='# vessels:'
)
thal = widgets.Dropdown(
    options=['Normal', 'Abnormal'],
    value=init_df['thal'][0],
    description='Thalium test:'
)
# Define the update/output function

from IPython.display import clear_output, display, HTML
out = widgets.Output()


def on_update(_):
    """Recompute and print the heart-disease probability for the current inputs.

    Reads every input widget and builds a one-row DataFrame with the same
    columns the column transformer ``clt`` was fitted on. The row is encoded
    with ``clt``, and ``lr.predict_proba`` produces the prediction. The
    inputs and the predicted probability (as a rounded percentage) are
    printed into the ``out`` Output widget, replacing the previous text.

    The single positional argument is the change notification passed by
    ``observe`` (or any placeholder for a manual call); it is ignored.
    """
    with out:
        # Current widget values as a one-row frame. Named `row` so it does not
        # shadow the module-level `inputs` widget list.
        row = {
            'age': [age.value],
            'sex': [sex.value],
            'cp': [cp.value],
            'trestbps': [trestbps.value],
            'chol': [chol.value],
            'fbs': [fbs.value],
            'restecg': [restecg.value],
            'thalach': [thalach.value],
            'exang': [exang.value],
            'oldpeak': [oldpeak.value],
            'slope': [slope.value],
            'ca': [ca.value],
            'thal': [thal.value],
            # Placeholder for the 'target' column present when clt was fitted
            'target': [np.nan],
        }
        row_df = pd.DataFrame(data=row)
        pred = lr.predict_proba(clt.transform(row_df))

        # wait=True defers clearing until new output arrives, avoiding flicker
        clear_output(wait=True)
        # Column 1 holds the probability of class 1 (lr.classes_ is sorted),
        # which after the target recoding is the heart-disease class.
        result = ('inputs: ' + str(row)
                  + '\n\n\nprobability of heart disease: '
                  + str(round(pred[0][1] * 100))
                  + '%\n\n')
        print(result)


inputs = [age, sex, cp, trestbps, chol, fbs,
          restecg, thalach, exang, oldpeak, slope,
          ca, thal]

# React only to changes of each widget's value. Observing every trait would
# re-run the prediction for auxiliary trait changes too (e.g. a Dropdown's
# `index` and `label`), causing several redundant redraws per user action.
for widget in inputs:
    widget.observe(on_update, names='value')

# Run once so the output shows a prediction for the default inputs
on_update('')
%%HTML
<!-- Give widget Output areas a minimum height of 120px. Presumably this keeps
     the panel from jumping while the printed prediction is cleared and
     reprinted; confirm against the rendered notebook. -->
<style>
    .widget-output {
        min-height: 120px
    }
</style>
# Display widgets: the prediction output on top, then the input controls.
# The VBox is the cell's final expression, so the notebook renders it.
widgets.VBox(
    [out]
    + [age, sex, cp, trestbps, chol, fbs, restecg,
       thalach, exang, oldpeak, slope, ca, thal]
)
VBox(children=(Output(outputs=({'name': 'stdout', 'text': "inputs: {'age': [55], 'sex': [1], 'cp': ['Asympt.']…