Prediction and Insight through Machine Learning
Published:
Prediction and insight from black-box machine learning techniques¶
Introduction¶
In this project, I delved into the application of advanced machine learning techniques to train prediction models using the popular Kaggle dataset known as the Stroke Prediction Dataset. The dataset comprises 5110 observations (rows) indicating whether a specific patient experienced a stroke or not, along with 10 additional metrics (columns). These metrics encompass both personal and health-related features. The primary goal of working with this dataset was to develop a classification model that could accurately predict whether a patient had a stroke based on healthcare metrics. To achieve this, I employed two machine learning algorithms, namely Random Forest and Neural Net, and conducted a comparative analysis of their prediction accuracies. Furthermore, I conducted a thorough exploration of both machine learning techniques to uncover insights from these seemingly complex black-box models.
This notebook is outlined as follows:
- Data Cleaning and Exploratory Data Analysis
- Neural Network
- Random Forest
- Feature Importance
- Conclusion
1. Data Cleaning and Exploratory Data Analysis (EDA)¶
First, I loaded the necessary Python modules and read in the data for this analysis.
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn import tree
import shap
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
df = pd.read_csv("/Users/tim/Desktop/Current/Prediction/Data/healthcare-dataset-stroke-data.csv")
Then, I looked at some basic summary statistics for each feature.
summary_stats_continuous = df.describe()
summary_stats_categorical = df.describe(include='object')
print(summary_stats_continuous)
print(summary_stats_categorical)
                 id          age  hypertension  heart_disease  \
count   5110.000000  5110.000000   5110.000000    5110.000000
mean   36517.829354    43.226614      0.097456       0.054012
std    21161.721625    22.612647      0.296607       0.226063
min       67.000000     0.080000      0.000000       0.000000
25%    17741.250000    25.000000      0.000000       0.000000
50%    36932.000000    45.000000      0.000000       0.000000
75%    54682.000000    61.000000      0.000000       0.000000
max    72940.000000    82.000000      1.000000       1.000000

       avg_glucose_level          bmi       stroke
count        5110.000000  4909.000000  5110.000000
mean          106.147677    28.893237     0.048728
std            45.283560     7.854067     0.215320
min            55.120000    10.300000     0.000000
25%            77.245000    23.500000     0.000000
50%            91.885000    28.100000     0.000000
75%           114.090000    33.100000     0.000000
max           271.740000    97.600000     1.000000
        gender ever_married work_type Residence_type   smoking_status
count     5110         5110      5110           5110             5110
unique       3            2         5              2                4
top     Female          Yes   Private          Urban     never smoked
freq      2994         3353      2925           2596             1892
The bmi feature had 201 missing values (only 4909 of the 5110 rows were populated), and there were a few ways to handle them. Because the affected rows were a small share of the data and plenty of observations remained, I removed them.
df = df.dropna()
df= df.reset_index(drop=True)
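Dropping rows is only one of the options mentioned above. As a minimal sketch of an alternative, the snippet below compares dropping with median imputation on a small hypothetical frame (the `toy` data and its values are made up for illustration and are not taken from the stroke dataset).

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame standing in for the stroke data.
toy = pd.DataFrame({
    "age": [34, 61, 45, 78, 52],
    "bmi": [22.5, np.nan, 31.0, np.nan, 27.5],
})

# Count missing values before deciding how to handle them.
n_missing = toy["bmi"].isna().sum()

# Option used in this post: drop incomplete rows.
dropped = toy.dropna().reset_index(drop=True)

# Alternative: keep every row and fill gaps with the median of the observed values.
imputed = toy.assign(bmi=toy["bmi"].fillna(toy["bmi"].median()))

print(n_missing, len(dropped), len(imputed))
```

Imputation keeps all rows at the cost of inventing values. In a modelling workflow the median would ideally be computed on the training rows only.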
Next, I wanted to get an idea of what kinds of relationships exist between the features of this dataset. I grouped the features into three classes: numeric, binary, and categorical.
X_numeric = ["age", "bmi", "avg_glucose_level"]
X_binary = ["hypertension", "heart_disease"]
X_cat = ["gender", "ever_married", "work_type", "Residence_type", "smoking_status"]
First, I explored the numeric features.
# Stroke and age
sns.boxplot(x='stroke', y= "age", data = df)
plt.title(f"Relationship between stroke and age")
plt.show()
# stroke and glucose
sns.boxplot(x='stroke', y= "avg_glucose_level", data = df)
plt.title(f"Relationship between stroke and glucose")
plt.show()
# stroke and bmi
sns.boxplot(x='stroke', y= "bmi", data = df)
plt.title(f"Relationship between stroke and bmi")
plt.show()
Older patients seem to make up the majority of stroke victims. Additionally, stroke patients tend to have higher BMIs and higher average glucose levels.
Next, I wanted to see the relationships between these numeric features.
scatter_matrix = sns.pairplot(df[X_numeric])
for i, (ax_row, feature_row) in enumerate(zip(scatter_matrix.axes, X_numeric)):
    for j, ax in enumerate(ax_row):
        if i != j:
            correlation_coefficient = df[[feature_row, X_numeric[j]]].corr().iloc[0, 1]
            ax.annotate(f'Corr: {correlation_coefficient:.2f}', xy=(0.5, 0.99), xycoords='axes fraction', ha='center', va='center', fontsize=8, color='red')
plt.show()
There were no obvious relationships between the features. I didn’t think there would be any issue of multicollinearity in future models.
Next, the binary features.
pd.crosstab(df["stroke"], df["hypertension"], margins=True, margins_name='Total')
hypertension | 0 | 1 | Total |
---|---|---|---|
stroke | |||
0 | 4309 | 391 | 4700 |
1 | 149 | 60 | 209 |
Total | 4458 | 451 | 4909 |
pd.crosstab(df["stroke"], df["heart_disease"], margins=True, margins_name='Total')
heart_disease | 0 | 1 | Total |
---|---|---|---|
stroke | |||
0 | 4497 | 203 | 4700 |
1 | 169 | 40 | 209 |
Total | 4666 | 243 | 4909 |
pd.crosstab(df["hypertension"], df["heart_disease"], margins=True, margins_name='Total')
heart_disease | 0 | 1 | Total |
---|---|---|---|
hypertension | |||
0 | 4273 | 185 | 4458 |
1 | 393 | 58 | 451 |
Total | 4666 | 243 | 4909 |
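Interestingly, most of the patients who had strokes did not have prior heart disease (169 of 209). Likewise, the majority of patients who had strokes did not have hypertension (149 of 209). Both conditions are rare in the sample overall, though, so these raw counts do not by themselves show that the conditions are unimportant.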
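Raw counts can hide differences in rates when one group is much smaller than the other. As a rough sketch, the snippet below turns the stroke-by-hypertension crosstab into a stroke rate per group. The counts are copied from the table above; the helper name `stroke_rate` is just an illustrative choice.

```python
# Counts copied from the stroke-by-hypertension crosstab:
# (no stroke, stroke) for each hypertension status.
counts = {
    "no hypertension": (4309, 149),
    "hypertension": (391, 60),
}

def stroke_rate(no_stroke, stroke):
    """Share of patients in a group who had a stroke."""
    return stroke / (no_stroke + stroke)

rates = {group: stroke_rate(*c) for group, c in counts.items()}
for group, rate in rates.items():
    print(f"{group}: {rate:.1%}")
```

This works out to roughly 3.3% for patients without hypertension and 13.3% for patients with it. The same view is available directly from pandas with `pd.crosstab(..., normalize='columns')`.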
Finally, I checked out the categorical features.
pd.crosstab(df["stroke"], df["gender"], margins=True, margins_name='Total')
gender | Female | Male | Other | Total |
---|---|---|---|---|
stroke | ||||
0 | 2777 | 1922 | 1 | 4700 |
1 | 120 | 89 | 0 | 209 |
Total | 2897 | 2011 | 1 | 4909 |
pd.crosstab(df['gender'], df['stroke'], normalize='index').plot(kind='bar', stacked=True)
plt.show()
pd.crosstab(df['ever_married'], df['stroke'], normalize='index').plot(kind='bar', stacked=True)
plt.show()
sns.countplot(x='stroke', hue='work_type', data=df)
plt.show()
pd.crosstab(df['Residence_type'], df['stroke'], normalize='index').plot(kind='bar', stacked=True)
plt.show()
sns.countplot(x='stroke', hue='smoking_status', data=df)
plt.show()
Reviewing the categorical features, the sex distribution among stroke patients matched that of the overall sample, so gender did not appear to be a determining factor. A majority of stroke patients were married, and individuals who had never worked had no recorded strokes. The proportions of rural and urban patients likewise showed no clear relationship. For smoking status no clear pattern emerged, though the large number of patients with unknown smoking status likely contributed to the lack of clarity.
After completing the Exploratory Data Analysis (EDA), a better understanding was gained regarding the available features for predicting stroke occurrences. The decision was made to incorporate all these features into the machine learning algorithms since there didn't seem to be any downside to including them, even if their relationships with strokes were weak.
Model prep¶
Before training the algorithms, I did some data manipulation to aid in the process.
I centered and standardized the numeric features.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(df[X_numeric])
X_scaled = pd.DataFrame(X_scaled, columns = X_numeric)
I one-hot encoded the categorical variables. This process widened the table and turned each level within a category into its own binary column, with the first level of each category dropped to serve as the baseline.
enc = OneHotEncoder(sparse_output=False, drop='first')
encoded_columns = enc.fit_transform(df[X_cat])
X_encoded = pd.DataFrame(encoded_columns, columns = enc.get_feature_names_out(X_cat))
X_format = pd.concat([X_scaled, X_encoded, df[X_binary]], axis=1)
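To make the widening concrete, here is a small sketch of one-hot encoding on a hypothetical four-patient sample. It uses `pd.get_dummies` rather than the `OneHotEncoder` above purely to keep the example short. The `toy` frame is an assumption for illustration, and `drop_first=True` plays the same role as `drop='first'`.

```python
import pandas as pd

# Hypothetical four-patient sample with one categorical column.
toy = pd.DataFrame({
    "smoking_status": ["smokes", "never smoked", "formerly smoked", "never smoked"]
})

# The first level in sorted order ("formerly smoked") is dropped and becomes
# the implicit baseline: a row of all zeros.
encoded = pd.get_dummies(toy, columns=["smoking_status"], drop_first=True, dtype=int)
print(encoded)
```

One column with three levels becomes two binary columns, and the dropped level is recoverable as the case where both are zero.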
I kept the originally binary features as they were. Then I split the data into training and testing sets.
y = df["stroke"]
X_train, X_test, y_train, y_test = train_test_split(X_format, y, test_size=0.2, random_state=10)
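In the steps above the scaler and encoder are fit on the full dataset before the split, so the test rows influence the scaling statistics. A common alternative is to split first and fit the preprocessing on the training rows only. The sketch below shows that pattern on hypothetical random data (the `toy` frame, its columns, and the `stratify=y` option are illustrative assumptions, not what this notebook ran).

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Hypothetical toy data with the same kinds of columns as the stroke dataset.
rng = np.random.default_rng(0)
n = 200
toy = pd.DataFrame({
    "age": rng.uniform(1, 82, n),
    "bmi": rng.normal(29, 8, n),
    "gender": rng.choice(["Female", "Male"], n),
    "hypertension": rng.integers(0, 2, n),
    "stroke": rng.integers(0, 2, n),
})

X = toy.drop(columns="stroke")
y = toy["stroke"]

# Split first; stratify keeps the class ratio similar in both sets.
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.2, random_state=10, stratify=y
)

prep = ColumnTransformer(
    [
        ("num", StandardScaler(), ["age", "bmi"]),
        ("cat", OneHotEncoder(drop="first"), ["gender"]),
    ],
    remainder="passthrough",  # binary columns pass through unchanged
    sparse_threshold=0,       # always return a dense array
)

X_tr_prep = prep.fit_transform(X_tr)  # statistics learned from training rows only
X_te_prep = prep.transform(X_te)      # the same statistics reused on the test rows
print(X_tr_prep.shape, X_te_prep.shape)
```

With a dataset of this size the practical difference is likely small, but fitting on the training rows keeps the test set a cleaner estimate of out-of-sample performance.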
2. Neural Network¶
I first fit a neural network using the tensorflow module.
nn = tf.keras.models.Sequential()
nn.add(tf.keras.layers.Dense(32, activation='relu', input_dim=X_train.shape[1]))
nn.add(tf.keras.layers.Dense(16, activation='relu'))
nn.add(tf.keras.layers.Dense(1, activation='sigmoid'))
The neural network included two hidden layers with 32 and 16 neurons, respectively. Both hidden layers used the ReLU activation function, and the final classification layer used a sigmoid function.
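To give a sense of the size of this network, the sketch below counts its trainable parameters by hand. It assumes 16 input features: 3 numeric, 2 binary, and 11 one-hot columns, the last figure inferred from the unique-value counts in the summary statistics and from the assumption that no category level vanished when rows were dropped. The helper `dense_params` is a hypothetical name.

```python
def dense_params(n_inputs, n_units):
    """Weights plus one bias per unit for a fully connected layer."""
    return n_inputs * n_units + n_units

# Assumption: 16 input features (3 numeric + 11 one-hot columns + 2 binary).
layer_sizes = [16, 32, 16, 1]
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)
```

Under that assumption the layers hold 544, 528, and 17 parameters, 1089 in total. Calling `nn.summary()` on the fitted Keras model would report the actual figures.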
nn.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
For training, the Adam optimizer was used, and the loss criterion was binary cross-entropy (common for binary classification). The metric tracked was accuracy.
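For readers unfamiliar with the loss, binary cross-entropy averages -[y log(p) + (1 - y) log(1 - p)] over the observations, where y is the true label and p the predicted probability of a stroke. The following is a plain-Python sketch of that formula, not Keras's internal implementation. The function name and the clipping constant `eps` are illustrative choices.

```python
import math

def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy; eps keeps log() away from zero."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)

# Confident and correct predictions give a small loss...
confident_right = binary_cross_entropy([1, 0], [0.9, 0.1])
# ...while confident and wrong predictions are penalized heavily.
confident_wrong = binary_cross_entropy([1, 0], [0.1, 0.9])
print(confident_right, confident_wrong)
```

The loss rewards well-calibrated probabilities rather than just correct labels, which is why it can keep falling while accuracy stays flat.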
nn.fit(X_train, y_train, epochs=10, batch_size=32, validation_data=(X_test, y_test))
Epoch 1/10
123/123 [==============================] - 0s 1ms/step - loss: 0.1310 - accuracy: 0.9585 - val_loss: 0.1462 - val_accuracy: 0.9572
Epoch 2/10
123/123 [==============================] - 0s 940us/step - loss: 0.1299 - accuracy: 0.9587 - val_loss: 0.1461 - val_accuracy: 0.9572
Epoch 3/10
123/123 [==============================] - 0s 986us/step - loss: 0.1294 - accuracy: 0.9585 - val_loss: 0.1489 - val_accuracy: 0.9582
Epoch 4/10
123/123 [==============================] - 0s 993us/step - loss: 0.1288 - accuracy: 0.9585 - val_loss: 0.1470 - val_accuracy: 0.9572
Epoch 5/10
123/123 [==============================] - 0s 905us/step - loss: 0.1287 - accuracy: 0.9585 - val_loss: 0.1468 - val_accuracy: 0.9572
Epoch 6/10
123/123 [==============================] - 0s 926us/step - loss: 0.1280 - accuracy: 0.9585 - val_loss: 0.1473 - val_accuracy: 0.9572
Epoch 7/10
123/123 [==============================] - 0s 1ms/step - loss: 0.1278 - accuracy: 0.9585 - val_loss: 0.1465 - val_accuracy: 0.9572
Epoch 8/10
123/123 [==============================] - 0s 985us/step - loss: 0.1283 - accuracy: 0.9587 - val_loss: 0.1483 - val_accuracy: 0.9572
Epoch 9/10
123/123 [==============================] - 0s 957us/step - loss: 0.1266 - accuracy: 0.9587 - val_loss: 0.1496 - val_accuracy: 0.9552
Epoch 10/10
123/123 [==============================] - 0s 968us/step - loss: 0.1258 - accuracy: 0.9590 - val_loss: 0.1488 - val_accuracy: 0.9572
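A quick look at the training log shows that after the final epoch the in-sample accuracy was about 95.9% and the out-of-sample accuracy about 95.7%. This is a high accuracy, but strokes are rare in this dataset (209 of 4909 patients), so accuracy alone says little about how well the strokes themselves are detected.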
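Because strokes are rare here, it helps to compare the reported accuracy with a majority-class baseline. A minimal sketch, using the class counts from the crosstabs in the EDA section (these are counts for the full cleaned dataset, not for the test split specifically):

```python
# Class counts after dropping rows with missing bmi (from the EDA crosstabs).
n_no_stroke = 4700
n_stroke = 209

# Accuracy of a trivial model that always predicts the majority class.
baseline_accuracy = n_no_stroke / (n_no_stroke + n_stroke)
print(f"majority-class baseline: {baseline_accuracy:.4f}")
```

A model that never predicts a stroke would already score about 0.957, so accuracy near that level is best read alongside the confusion matrix.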
We can use a confusion matrix to have a better look at the predictions made on the testing data.
predictions = nn.predict(X_test)
predicted_labels = np.round(predictions).flatten()
true_labels = y_test
conf_matrix = confusion_matrix(true_labels, predicted_labels)
plt.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
classes = ['No Stroke', 'Stroke'] # Replace with your actual class labels
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)
# Adding annotations
for i in range(len(classes)):
    for j in range(len(classes)):
        plt.text(j, i, str(conf_matrix[i, j]), ha='center', va='center')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
31/31 [==============================] - 0s 435us/step
The majority of predictions indicated no stroke, and the majority of actual values also turned out to be no stroke. However, it is noteworthy that most of the incorrect predictions were for patients who ultimately experienced a stroke. Despite the overall accuracy being satisfactory, caution should be exercised when considering its use in a production setting, given that the consequences of predicting no stroke and encountering one are considerably more severe than the alternative.
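The asymmetry described above is usually quantified with recall and precision for the stroke class. The sketch below computes them from the four cells of a confusion matrix. The counts are hypothetical placeholders chosen only for illustration, NOT the values from the plot, and `summarize` is an illustrative helper name.

```python
def summarize(tn, fp, fn, tp):
    """Accuracy, recall and precision for the positive (stroke) class."""
    total = tn + fp + fn + tp
    return {
        "accuracy": (tn + tp) / total,
        "recall": tp / (tp + fn) if (tp + fn) else 0.0,
        "precision": tp / (tp + fp) if (tp + fp) else 0.0,
    }

# Hypothetical counts, NOT the values from the confusion matrix above.
metrics = summarize(tn=935, fp=5, fn=37, tp=5)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

With these made-up counts, accuracy is about 0.957 while recall is only about 0.119, which is the pattern the paragraph above warns about. The `classification_report` function imported at the top of the notebook reports the same per-class quantities for real predictions.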
3. Random Forest¶
As stated, the other machine learning algorithm I used was Random Forest. Random Forest is an ensemble flavor of decision tree which has proven to be a very powerful machine learning technique for predictive purposes.
To fit the Random Forest model, I used the same features prepared for the neural net.
rf = RandomForestClassifier(random_state = 10)
rf.fit(X_train, y_train)
RandomForestClassifier(random_state=10)
y_pred_test = rf.predict(X_test)
y_pred_train = rf.predict(X_train)
print("in sample accuracy:", round(accuracy_score(y_train, y_pred_train), 2))
print("out of sample accuracy:", round(accuracy_score(y_test, y_pred_test), 2))
in sample accuracy: 1.0
out of sample accuracy: 0.96
For the random forest model, the in-sample accuracy was 100%, while the out-of-sample accuracy was 96%. Compared with the neural network, the in-sample accuracy is higher, but the out-of-sample accuracy is essentially the same. The in-sample figure is computed on the training data, which fully grown trees can fit almost perfectly, so it says little about how the model generalizes.
Rather than going straight to the confusion matrix, as I did for the neural network, I first visualized one of the trees to provide context for the model. Without going too deep into the Random Forest algorithm, it aggregates the votes of many unpruned decision trees to produce a final prediction. A single tree is therefore not very informative about the data on its own. Nevertheless, I present one to illustrate how complex a decision tree is after fitting.
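Each node of a fitted tree is annotated with a Gini impurity, which for class proportions p_k is 1 minus the sum of p_k squared. As a quick check of that definition, the sketch below recomputes the value for the root node of rf.estimators_[0], which reports value = [3746, 181] and gini = 0.088. The function name `gini_impurity` is an illustrative choice.

```python
def gini_impurity(class_counts):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    total = sum(class_counts)
    return 1.0 - sum((c / total) ** 2 for c in class_counts)

root = gini_impurity([3746, 181])  # root node of rf.estimators_[0]
pure = gini_impurity([681, 0])     # a leaf containing one class only
print(round(root, 3), pure)
```

The root comes out at about 0.088, matching the tree output, and a single-class leaf has impurity 0. The low root impurity reflects the class imbalance rather than anything the tree has learned.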
tree.plot_tree(rf.estimators_[0])
(Output: plot_tree returns one matplotlib Text annotation per node of the plotted tree. The root node reads x[0] <= 1.092, gini = 0.088, samples = 2532, value = [3746, 181].)
'x[1] <= -0.674\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.5183673469387755, 0.475, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.5265306122448979, 0.475, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.5306122448979592, 0.525, 'gini = 0.0\nsamples = 8\nvalue = [14, 0]'), Text(0.5428571428571428, 0.575, 'x[15] <= 0.5\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.5387755102040817, 0.525, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.5469387755102041, 0.525, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.5551020408163265, 0.625, 'x[1] <= -1.069\ngini = 0.375\nsamples = 2\nvalue = [1, 3]'), Text(0.5510204081632653, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.5591836734693878, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [0, 3]'), Text(0.6082908163265306, 0.725, 'x[7] <= 0.5\ngini = 0.297\nsamples = 86\nvalue = [104, 23]'), Text(0.5795918367346938, 0.675, 'x[0] <= 1.137\ngini = 0.165\nsamples = 43\nvalue = [60, 6]'), Text(0.5714285714285714, 0.625, 'x[2] <= -0.695\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.5673469387755102, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.5755102040816327, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.5877551020408164, 0.625, 'x[0] <= 1.536\ngini = 0.144\nsamples = 41\nvalue = [59, 5]'), Text(0.5836734693877551, 0.575, 'gini = 0.0\nsamples = 22\nvalue = [34, 0]'), Text(0.5918367346938775, 0.575, 'x[3] <= 0.5\ngini = 0.278\nsamples = 19\nvalue = [25, 5]'), Text(0.5795918367346938, 0.525, 'x[0] <= 1.669\ngini = 0.231\nsamples = 8\nvalue = [13, 2]'), Text(0.5755102040816327, 0.475, 'gini = 0.0\nsamples = 5\nvalue = [11, 0]'), Text(0.5836734693877551, 0.475, 'x[12] <= 0.5\ngini = 0.5\nsamples = 3\nvalue = [2, 2]'), Text(0.5795918367346938, 0.425, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.5877551020408164, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [0, 2]'), Text(0.6040816326530613, 0.525, 'x[0] <= 1.58\ngini = 0.32\nsamples = 11\nvalue = [12, 3]'), Text(0.6, 0.475, 
'x[12] <= 0.5\ngini = 0.5\nsamples = 5\nvalue = [3, 3]'), Text(0.5959183673469388, 0.425, 'x[8] <= 0.5\ngini = 0.48\nsamples = 4\nvalue = [3, 2]'), Text(0.5918367346938775, 0.375, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.6, 0.375, 'x[1] <= -0.438\ngini = 0.444\nsamples = 2\nvalue = [1, 2]'), Text(0.5959183673469388, 0.325, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.6040816326530613, 0.325, 'gini = 0.0\nsamples = 1\nvalue = [0, 2]'), Text(0.6040816326530613, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.6081632653061224, 0.475, 'gini = 0.0\nsamples = 6\nvalue = [9, 0]'), Text(0.6369897959183674, 0.675, 'x[1] <= -1.056\ngini = 0.402\nsamples = 43\nvalue = [44, 17]'), Text(0.6163265306122448, 0.625, 'x[2] <= -0.391\ngini = 0.444\nsamples = 2\nvalue = [1, 2]'), Text(0.6122448979591837, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.6204081632653061, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [0, 2]'), Text(0.6576530612244897, 0.625, 'x[0] <= 1.536\ngini = 0.383\nsamples = 41\nvalue = [43, 15]'), Text(0.6285714285714286, 0.575, 'x[11] <= 0.5\ngini = 0.239\nsamples = 25\nvalue = [31, 5]'), Text(0.6204081632653061, 0.525, 'x[1] <= -0.407\ngini = 0.142\nsamples = 20\nvalue = [24, 2]'), Text(0.6163265306122448, 0.475, 'x[5] <= 0.5\ngini = 0.32\nsamples = 9\nvalue = [8, 2]'), Text(0.6122448979591837, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.6204081632653061, 0.425, 'x[2] <= -0.281\ngini = 0.198\nsamples = 8\nvalue = [8, 1]'), Text(0.6163265306122448, 0.375, 'gini = 0.0\nsamples = 4\nvalue = [5, 0]'), Text(0.6244897959183674, 0.375, 'x[12] <= 0.5\ngini = 0.375\nsamples = 4\nvalue = [3, 1]'), Text(0.6204081632653061, 0.325, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.6285714285714286, 0.325, 'gini = 0.0\nsamples = 3\nvalue = [3, 0]'), Text(0.6244897959183674, 0.475, 'gini = 0.0\nsamples = 11\nvalue = [16, 0]'), Text(0.636734693877551, 0.525, 'x[15] <= 0.5\ngini = 0.42\nsamples = 5\nvalue = [7, 3]'), 
Text(0.6326530612244898, 0.475, 'x[2] <= -0.099\ngini = 0.49\nsamples = 4\nvalue = [4, 3]'), Text(0.6285714285714286, 0.425, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.636734693877551, 0.425, 'x[2] <= 1.018\ngini = 0.48\nsamples = 2\nvalue = [2, 3]'), Text(0.6326530612244898, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [0, 3]'), Text(0.6408163265306123, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [2, 0]'), Text(0.6408163265306123, 0.475, 'gini = 0.0\nsamples = 1\nvalue = [3, 0]'), Text(0.686734693877551, 0.575, 'x[11] <= 0.5\ngini = 0.496\nsamples = 16\nvalue = [12, 10]'), Text(0.6714285714285714, 0.525, 'x[1] <= -0.286\ngini = 0.492\nsamples = 11\nvalue = [7, 9]'), Text(0.6612244897959184, 0.475, 'x[0] <= 1.713\ngini = 0.397\nsamples = 7\nvalue = [3, 8]'), Text(0.6530612244897959, 0.425, 'x[12] <= 0.5\ngini = 0.245\nsamples = 4\nvalue = [1, 6]'), Text(0.6489795918367347, 0.375, 'x[5] <= 0.5\ngini = 0.375\nsamples = 3\nvalue = [1, 3]'), Text(0.6448979591836734, 0.325, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.6530612244897959, 0.325, 'gini = 0.0\nsamples = 2\nvalue = [0, 3]'), Text(0.6571428571428571, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [0, 3]'), Text(0.6693877551020408, 0.425, 'x[5] <= 0.5\ngini = 0.5\nsamples = 3\nvalue = [2, 2]'), Text(0.6653061224489796, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.673469387755102, 0.375, 'x[2] <= 0.382\ngini = 0.444\nsamples = 2\nvalue = [1, 2]'), Text(0.6693877551020408, 0.325, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.6775510204081633, 0.325, 'gini = 0.0\nsamples = 1\nvalue = [0, 2]'), Text(0.6816326530612244, 0.475, 'x[12] <= 0.5\ngini = 0.32\nsamples = 4\nvalue = [4, 1]'), Text(0.6775510204081633, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [2, 0]'), Text(0.6857142857142857, 0.425, 'x[15] <= 0.5\ngini = 0.444\nsamples = 3\nvalue = [2, 1]'), Text(0.6816326530612244, 0.375, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.689795918367347, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), 
Text(0.7020408163265306, 0.525, 'x[2] <= -0.166\ngini = 0.278\nsamples = 5\nvalue = [5, 1]'), Text(0.6979591836734694, 0.475, 'gini = 0.0\nsamples = 2\nvalue = [3, 0]'), Text(0.7061224489795919, 0.475, 'x[3] <= 0.5\ngini = 0.444\nsamples = 3\nvalue = [2, 1]'), Text(0.7020408163265306, 0.425, 'x[2] <= 1.218\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.6979591836734694, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.7061224489795919, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.710204081632653, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.7591836734693878, 0.775, 'x[3] <= 0.5\ngini = 0.458\nsamples = 41\nvalue = [49, 27]'), Text(0.7346938775510204, 0.725, 'x[11] <= 0.5\ngini = 0.483\nsamples = 26\nvalue = [26, 18]'), Text(0.7183673469387755, 0.675, 'x[2] <= -0.867\ngini = 0.499\nsamples = 16\nvalue = [14, 13]'), Text(0.7142857142857143, 0.625, 'gini = 0.0\nsamples = 3\nvalue = [4, 0]'), Text(0.7224489795918367, 0.625, 'x[12] <= 0.5\ngini = 0.491\nsamples = 13\nvalue = [10, 13]'), Text(0.7183673469387755, 0.575, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.726530612244898, 0.575, 'x[7] <= 0.5\ngini = 0.472\nsamples = 11\nvalue = [8, 13]'), Text(0.7183673469387755, 0.525, 'x[1] <= -0.528\ngini = 0.473\nsamples = 6\nvalue = [5, 8]'), Text(0.7142857142857143, 0.475, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.7224489795918367, 0.475, 'x[10] <= 0.5\ngini = 0.397\nsamples = 4\nvalue = [3, 8]'), Text(0.7183673469387755, 0.425, 'x[1] <= -0.426\ngini = 0.198\nsamples = 3\nvalue = [1, 8]'), Text(0.7142857142857143, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [0, 5]'), Text(0.7224489795918367, 0.375, 'x[0] <= 1.27\ngini = 0.375\nsamples = 2\nvalue = [1, 3]'), Text(0.7183673469387755, 0.325, 'gini = 0.0\nsamples = 1\nvalue = [0, 3]'), Text(0.726530612244898, 0.325, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.726530612244898, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [2, 0]'), Text(0.7346938775510204, 0.525, 'x[1] <= 
-0.738\ngini = 0.469\nsamples = 5\nvalue = [3, 5]'), Text(0.7306122448979592, 0.475, 'gini = 0.0\nsamples = 3\nvalue = [0, 5]'), Text(0.7387755102040816, 0.475, 'gini = 0.0\nsamples = 2\nvalue = [3, 0]'), Text(0.7510204081632653, 0.675, 'x[10] <= 0.5\ngini = 0.415\nsamples = 10\nvalue = [12, 5]'), Text(0.7428571428571429, 0.625, 'x[1] <= -0.579\ngini = 0.444\nsamples = 2\nvalue = [2, 1]'), Text(0.7387755102040816, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.746938775510204, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [2, 0]'), Text(0.7591836734693878, 0.625, 'x[2] <= 1.535\ngini = 0.408\nsamples = 8\nvalue = [10, 4]'), Text(0.7551020408163265, 0.575, 'x[0] <= 1.647\ngini = 0.165\nsamples = 7\nvalue = [10, 1]'), Text(0.7510204081632653, 0.525, 'x[8] <= 0.5\ngini = 0.375\nsamples = 4\nvalue = [3, 1]'), Text(0.746938775510204, 0.475, 'x[0] <= 1.403\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.7428571428571429, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.7510204081632653, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.7551020408163265, 0.475, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.7591836734693878, 0.525, 'gini = 0.0\nsamples = 3\nvalue = [7, 0]'), Text(0.763265306122449, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [0, 3]'), Text(0.7836734693877551, 0.725, 'x[2] <= 2.573\ngini = 0.404\nsamples = 15\nvalue = [23, 9]'), Text(0.7795918367346939, 0.675, 'x[10] <= 0.5\ngini = 0.08\nsamples = 12\nvalue = [23, 1]'), Text(0.7755102040816326, 0.625, 'x[11] <= 0.5\ngini = 0.219\nsamples = 6\nvalue = [7, 1]'), Text(0.7714285714285715, 0.575, 'gini = 0.0\nsamples = 4\nvalue = [6, 0]'), Text(0.7795918367346939, 0.575, 'x[2] <= 2.27\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.7755102040816326, 0.525, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.7836734693877551, 0.525, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.7836734693877551, 0.625, 'gini = 0.0\nsamples = 6\nvalue = [16, 0]'), Text(0.7877551020408163, 0.675, 
'gini = 0.0\nsamples = 3\nvalue = [0, 8]'), Text(0.8776785714285714, 0.875, 'x[1] <= 0.236\ngini = 0.248\nsamples = 227\nvalue = [301, 51]'), Text(0.8173469387755102, 0.825, 'x[0] <= 1.137\ngini = 0.089\nsamples = 84\nvalue = [123, 6]'), Text(0.8040816326530612, 0.775, 'x[7] <= 0.5\ngini = 0.444\nsamples = 3\nvalue = [2, 1]'), Text(0.8, 0.725, 'x[12] <= 0.5\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.7959183673469388, 0.675, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.8040816326530612, 0.675, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.8081632653061225, 0.725, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.8306122448979592, 0.775, 'x[2] <= 3.159\ngini = 0.076\nsamples = 81\nvalue = [121, 5]'), Text(0.8204081632653061, 0.725, 'x[1] <= -0.031\ngini = 0.063\nsamples = 79\nvalue = [119, 4]'), Text(0.8122448979591836, 0.675, 'x[2] <= 1.567\ngini = 0.236\nsamples = 16\nvalue = [19, 3]'), Text(0.8081632653061225, 0.625, 'x[3] <= 0.5\ngini = 0.095\nsamples = 14\nvalue = [19, 1]'), Text(0.8040816326530612, 0.575, 'x[1] <= -0.069\ngini = 0.198\nsamples = 7\nvalue = [8, 1]'), Text(0.8, 0.525, 'gini = 0.0\nsamples = 2\nvalue = [4, 0]'), Text(0.8081632653061225, 0.525, 'x[8] <= 0.5\ngini = 0.32\nsamples = 5\nvalue = [4, 1]'), Text(0.8040816326530612, 0.475, 'x[1] <= -0.05\ngini = 0.444\nsamples = 3\nvalue = [2, 1]'), Text(0.8, 0.425, 'x[14] <= 0.5\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.7959183673469388, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.8040816326530612, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.8081632653061225, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.8122448979591836, 0.475, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.8122448979591836, 0.575, 'gini = 0.0\nsamples = 7\nvalue = [11, 0]'), Text(0.8163265306122449, 0.625, 'gini = 0.0\nsamples = 2\nvalue = [0, 2]'), Text(0.8285714285714286, 0.675, 'x[1] <= 0.058\ngini = 0.02\nsamples = 63\nvalue = [100, 1]'), 
Text(0.8244897959183674, 0.625, 'x[2] <= -0.589\ngini = 0.051\nsamples = 24\nvalue = [37, 1]'), Text(0.8204081632653061, 0.575, 'x[1] <= 0.039\ngini = 0.198\nsamples = 6\nvalue = [8, 1]'), Text(0.8163265306122449, 0.525, 'gini = 0.0\nsamples = 5\nvalue = [8, 0]'), Text(0.8244897959183674, 0.525, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.8285714285714286, 0.575, 'gini = 0.0\nsamples = 18\nvalue = [29, 0]'), Text(0.8326530612244898, 0.625, 'gini = 0.0\nsamples = 39\nvalue = [63, 0]'), Text(0.8408163265306122, 0.725, 'x[5] <= 0.5\ngini = 0.444\nsamples = 2\nvalue = [2, 1]'), Text(0.8367346938775511, 0.675, 'gini = 0.0\nsamples = 1\nvalue = [2, 0]'), Text(0.8448979591836735, 0.675, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.9380102040816326, 0.825, 'x[10] <= 0.5\ngini = 0.322\nsamples = 143\nvalue = [178, 45]'), Text(0.8994897959183673, 0.775, 'x[2] <= 2.832\ngini = 0.427\nsamples = 62\nvalue = [67, 30]'), Text(0.8683673469387755, 0.725, 'x[1] <= 0.249\ngini = 0.353\nsamples = 54\nvalue = [64, 19]'), Text(0.8642857142857143, 0.675, 'gini = 0.0\nsamples = 1\nvalue = [0, 3]'), Text(0.8724489795918368, 0.675, 'x[8] <= 0.5\ngini = 0.32\nsamples = 53\nvalue = [64, 16]'), Text(0.8510204081632653, 0.625, 'x[3] <= 0.5\ngini = 0.237\nsamples = 35\nvalue = [44, 7]'), Text(0.8408163265306122, 0.575, 'x[0] <= 1.713\ngini = 0.305\nsamples = 23\nvalue = [26, 6]'), Text(0.8326530612244898, 0.525, 'x[0] <= 1.358\ngini = 0.238\nsamples = 21\nvalue = [25, 4]'), Text(0.8285714285714286, 0.475, 'x[0] <= 1.159\ngini = 0.444\nsamples = 7\nvalue = [8, 4]'), Text(0.8244897959183674, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.8326530612244898, 0.425, 'x[2] <= -0.016\ngini = 0.397\nsamples = 6\nvalue = [8, 3]'), Text(0.8285714285714286, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [0, 2]'), Text(0.8367346938775511, 0.375, 'x[0] <= 1.314\ngini = 0.198\nsamples = 5\nvalue = [8, 1]'), Text(0.8326530612244898, 0.325, 'gini = 0.0\nsamples = 4\nvalue = [8, 0]'), 
Text(0.8408163265306122, 0.325, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.8367346938775511, 0.475, 'gini = 0.0\nsamples = 14\nvalue = [17, 0]'), Text(0.8489795918367347, 0.525, 'x[11] <= 0.5\ngini = 0.444\nsamples = 2\nvalue = [1, 2]'), Text(0.8448979591836735, 0.475, 'gini = 0.0\nsamples = 1\nvalue = [0, 2]'), Text(0.8530612244897959, 0.475, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.8612244897959184, 0.575, 'x[15] <= 0.5\ngini = 0.1\nsamples = 12\nvalue = [18, 1]'), Text(0.8571428571428571, 0.525, 'gini = 0.0\nsamples = 7\nvalue = [11, 0]'), Text(0.8653061224489796, 0.525, 'x[14] <= 0.5\ngini = 0.219\nsamples = 5\nvalue = [7, 1]'), Text(0.8612244897959184, 0.475, 'x[1] <= 0.447\ngini = 0.278\nsamples = 4\nvalue = [5, 1]'), Text(0.8571428571428571, 0.425, 'gini = 0.0\nsamples = 3\nvalue = [5, 0]'), Text(0.8653061224489796, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.8693877551020408, 0.475, 'gini = 0.0\nsamples = 1\nvalue = [2, 0]'), Text(0.8938775510204081, 0.625, 'x[0] <= 1.491\ngini = 0.428\nsamples = 18\nvalue = [20, 9]'), Text(0.8775510204081632, 0.575, 'x[12] <= 0.5\ngini = 0.133\nsamples = 11\nvalue = [13, 1]'), Text(0.8734693877551021, 0.525, 'gini = 0.0\nsamples = 5\nvalue = [6, 0]'), Text(0.8816326530612245, 0.525, 'x[14] <= 0.5\ngini = 0.219\nsamples = 6\nvalue = [7, 1]'), Text(0.8775510204081632, 0.475, 'gini = 0.0\nsamples = 3\nvalue = [5, 0]'), Text(0.8857142857142857, 0.475, 'x[1] <= 0.95\ngini = 0.444\nsamples = 3\nvalue = [2, 1]'), Text(0.8816326530612245, 0.425, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.889795918367347, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.9102040816326531, 0.575, 'x[1] <= 0.593\ngini = 0.498\nsamples = 7\nvalue = [7, 8]'), Text(0.9061224489795918, 0.525, 'x[15] <= 0.5\ngini = 0.486\nsamples = 6\nvalue = [7, 5]'), Text(0.9020408163265307, 0.475, 'x[2] <= -0.027\ngini = 0.494\nsamples = 4\nvalue = [4, 5]'), Text(0.8979591836734694, 0.425, 'gini = 0.0\nsamples = 2\nvalue 
= [0, 5]'), Text(0.9061224489795918, 0.425, 'gini = 0.0\nsamples = 2\nvalue = [4, 0]'), Text(0.9102040816326531, 0.475, 'gini = 0.0\nsamples = 2\nvalue = [3, 0]'), Text(0.9142857142857143, 0.525, 'gini = 0.0\nsamples = 1\nvalue = [0, 3]'), Text(0.9306122448979591, 0.725, 'x[7] <= 0.5\ngini = 0.337\nsamples = 8\nvalue = [3, 11]'), Text(0.926530612244898, 0.675, 'gini = 0.0\nsamples = 1\nvalue = [2, 0]'), Text(0.9346938775510204, 0.675, 'x[3] <= 0.5\ngini = 0.153\nsamples = 7\nvalue = [1, 11]'), Text(0.9306122448979591, 0.625, 'x[2] <= 2.911\ngini = 0.219\nsamples = 3\nvalue = [1, 7]'), Text(0.926530612244898, 0.575, 'x[2] <= 2.865\ngini = 0.444\nsamples = 2\nvalue = [1, 2]'), Text(0.9224489795918367, 0.525, 'gini = 0.0\nsamples = 1\nvalue = [0, 2]'), Text(0.9306122448979591, 0.525, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.9346938775510204, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [0, 5]'), Text(0.9387755102040817, 0.625, 'gini = 0.0\nsamples = 4\nvalue = [0, 4]'), Text(0.976530612244898, 0.775, 'x[15] <= 0.5\ngini = 0.21\nsamples = 81\nvalue = [111, 15]'), Text(0.9653061224489796, 0.725, 'x[3] <= 0.5\ngini = 0.19\nsamples = 72\nvalue = [101, 12]'), Text(0.9551020408163265, 0.675, 'x[5] <= 0.5\ngini = 0.237\nsamples = 54\nvalue = [69, 11]'), Text(0.9469387755102041, 0.625, 'x[14] <= 0.5\ngini = 0.397\nsamples = 8\nvalue = [8, 3]'), Text(0.9428571428571428, 0.575, 'gini = 0.0\nsamples = 6\nvalue = [8, 0]'), Text(0.9510204081632653, 0.575, 'gini = 0.0\nsamples = 2\nvalue = [0, 3]'), Text(0.963265306122449, 0.625, 'x[1] <= 0.275\ngini = 0.205\nsamples = 46\nvalue = [61, 8]'), Text(0.9591836734693877, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [0, 2]'), Text(0.9673469387755103, 0.575, 'x[11] <= 0.5\ngini = 0.163\nsamples = 45\nvalue = [61, 6]'), Text(0.9530612244897959, 0.525, 'x[12] <= 0.5\ngini = 0.078\nsamples = 34\nvalue = [47, 2]'), Text(0.9489795918367347, 0.475, 'gini = 0.0\nsamples = 11\nvalue = [16, 0]'), Text(0.9571428571428572, 0.475, 'x[1] <= 
1.898\ngini = 0.114\nsamples = 23\nvalue = [31, 2]'), Text(0.9530612244897959, 0.425, 'x[7] <= 0.5\ngini = 0.061\nsamples = 22\nvalue = [31, 1]'), Text(0.9489795918367347, 0.375, 'x[14] <= 0.5\ngini = 0.1\nsamples = 12\nvalue = [18, 1]'), Text(0.9448979591836735, 0.325, 'x[0] <= 1.624\ngini = 0.117\nsamples = 10\nvalue = [15, 1]'), Text(0.9408163265306122, 0.275, 'gini = 0.0\nsamples = 7\nvalue = [12, 0]'), Text(0.9489795918367347, 0.275, 'x[2] <= -0.1\ngini = 0.375\nsamples = 3\nvalue = [3, 1]'), Text(0.9448979591836735, 0.225, 'x[1] <= 0.638\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.9408163265306122, 0.175, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.9489795918367347, 0.175, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.9530612244897959, 0.225, 'gini = 0.0\nsamples = 1\nvalue = [2, 0]'), Text(0.9530612244897959, 0.325, 'gini = 0.0\nsamples = 2\nvalue = [3, 0]'), Text(0.9571428571428572, 0.375, 'gini = 0.0\nsamples = 10\nvalue = [13, 0]'), Text(0.9612244897959183, 0.425, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.9816326530612245, 0.525, 'x[0] <= 1.292\ngini = 0.346\nsamples = 11\nvalue = [14, 4]'), Text(0.9775510204081632, 0.475, 'x[1] <= 0.669\ngini = 0.5\nsamples = 5\nvalue = [4, 4]'), Text(0.9693877551020408, 0.425, 'x[1] <= 0.568\ngini = 0.375\nsamples = 2\nvalue = [1, 3]'), Text(0.9653061224489796, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.9734693877551021, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [0, 3]'), Text(0.9857142857142858, 0.425, 'x[8] <= 0.5\ngini = 0.375\nsamples = 3\nvalue = [3, 1]'), Text(0.9816326530612245, 0.375, 'gini = 0.0\nsamples = 2\nvalue = [3, 0]'), Text(0.9897959183673469, 0.375, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.9857142857142858, 0.475, 'gini = 0.0\nsamples = 6\nvalue = [10, 0]'), Text(0.9755102040816327, 0.675, 'x[13] <= 0.5\ngini = 0.059\nsamples = 18\nvalue = [32, 1]'), Text(0.9714285714285714, 0.625, 'gini = 0.0\nsamples = 16\nvalue = [31, 0]'), Text(0.9795918367346939, 
0.625, 'x[2] <= -0.329\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.9755102040816327, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.9836734693877551, 0.575, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.9877551020408163, 0.725, 'x[0] <= 1.58\ngini = 0.355\nsamples = 9\nvalue = [10, 3]'), Text(0.9836734693877551, 0.675, 'gini = 0.0\nsamples = 6\nvalue = [9, 0]'), Text(0.9918367346938776, 0.675, 'x[2] <= 0.426\ngini = 0.375\nsamples = 3\nvalue = [1, 3]'), Text(0.9877551020408163, 0.625, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.9959183673469387, 0.625, 'gini = 0.0\nsamples = 2\nvalue = [0, 3]')]
Now, let's move into the confusion matrix and see the specifics of the predictions.
y_pred = rf.predict(X_test)
conf_matrix = confusion_matrix(y_test, y_pred)
plt.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
classes = ['No Stroke', 'Stroke']
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)
for i in range(len(classes)):
    for j in range(len(classes)):
        plt.text(j, i, str(conf_matrix[i, j]), ha='center', va='center')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
The confusion matrix shows that the random forest's predictions closely mirror those of the neural network. The random forest also shares the neural net's main weakness: when it is wrong, it tends to classify stroke patients as non-stroke patients (i.e., false negatives, or type II errors).
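To make the error types concrete, here is a minimal sketch of how the type I and type II error rates can be read off a 2x2 confusion matrix. The counts below are made up for illustration and are not the results from this notebook; the helper name `error_breakdown` is likewise hypothetical. The layout follows scikit-learn's `confusion_matrix` convention (rows are true labels, columns are predicted labels), and the sketch assumes both classes appear in the test set.

```python
# Illustrative counts only (NOT the notebook's results).
# Rows = true label, columns = predicted label, as in sklearn's confusion_matrix.
conf_matrix_example = [
    [950, 10],  # true "No Stroke": 950 true negatives, 10 false positives
    [45, 5],    # true "Stroke":    45 false negatives, 5 true positives
]


def error_breakdown(cm):
    """Return accuracy and error rates for a 2x2 confusion matrix."""
    (tn, fp), (fn, tp) = cm
    total = tn + fp + fn + tp
    return {
        "accuracy": (tn + tp) / total,
        "false_positive_rate": fp / (fp + tn),  # type I error rate
        "false_negative_rate": fn / (fn + tp),  # type II error rate
        "recall": tp / (tp + fn),               # share of stroke cases caught
    }


metrics = error_breakdown(conf_matrix_example)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

With counts like these, accuracy stays high while most stroke cases are missed, which is why accuracy alone can hide a false-negative problem on an imbalanced dataset.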
One way to address the false negatives would be to tune the models' hyperparameters. I trained both models with generic preset parameters and, to keep this analysis simple, did not include an optimization step in this notebook.
It's worth noting that, during the course of this analysis, I did write code to perform a grid search over the parameters used for training the models, with the search aimed at reducing false negatives. However, the decrease in type II error was minimal, and it came with a substantial increase in both overall error and type I error. For these reasons, I decided to stick with the machine learning algorithms as trained and presented in this notebook.
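For readers who want to try a similar search, the following is a rough sketch of a grid search that targets false negatives. It is not the code I used: the data is synthetic (generated with `make_classification`), the parameter grid is an arbitrary assumption, and `scoring="recall"` is just one reasonable way to tell scikit-learn to favor catching the positive class.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic, imbalanced stand-in data (about 10% positives); NOT the stroke dataset.
X_demo, y_demo = make_classification(
    n_samples=400, n_features=8, weights=[0.9, 0.1], random_state=0
)

# Hypothetical grid; the values are placeholders, not tuned recommendations.
param_grid = {
    "n_estimators": [25, 50],
    "max_depth": [3, None],
    "class_weight": [None, "balanced"],
}

# scoring="recall" ranks candidates by recall on the positive class,
# so fewer false negatives means a higher score.
search = GridSearchCV(
    RandomForestClassifier(random_state=0),
    param_grid,
    scoring="recall",
    cv=3,
)
search.fit(X_demo, y_demo)

best_params = search.best_params_
best_recall = search.best_score_
print(best_params, round(best_recall, 3))
```

Optimizing recall alone usually trades false negatives for false positives, so it is worth checking overall accuracy and the type I error of the selected model as well.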
4. Feature importance¶
What insight can we get from the trained machine learning algorithms?
Both neural networks and random forests are essentially black-box techniques: their internal decision process is hard to inspect directly. Even so, there are now methods that show how a trained model is influenced by the features it is given, and feature importance is the most prominent of these. The specific approach I used for the trained algorithms is SHAP (SHapley Additive exPlanations) values. SHAP values, rooted in game theory, assign an importance value to each feature in a model. These values can also be visualized to show how much each feature contributes to the model's predictions relative to the others.
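To give a feel for the game-theory idea behind SHAP, here is a small sketch that computes exact Shapley values by brute force: each feature's value is its marginal contribution averaged over every possible ordering of the features. The toy scoring function `toy_value` and its numbers are invented for illustration; the `shap` library relies on model-specific algorithms rather than this kind of enumeration, which becomes infeasible beyond a handful of features.

```python
from itertools import permutations


def shapley_values(features, value_fn):
    """Exact Shapley values: average marginal contribution over all orderings."""
    totals = {f: 0.0 for f in features}
    orderings = list(permutations(features))
    for order in orderings:
        present = set()
        for f in order:
            before = value_fn(frozenset(present))
            present.add(f)
            after = value_fn(frozenset(present))
            totals[f] += after - before
    return {f: total / len(orderings) for f, total in totals.items()}


# Hypothetical toy "model": the score obtained when only `known` features are used.
def toy_value(known):
    score = 0.0
    if "age" in known:
        score += 0.30
    if "bmi" in known:
        score += 0.10
    if "age" in known and "glucose" in known:
        score += 0.20  # interaction: glucose only helps once age is known
    return score


phi = shapley_values(["age", "bmi", "glucose"], toy_value)
for feature, value in phi.items():
    print(f"{feature}: {value:.2f}")
```

In this toy setup the interaction bonus is split evenly between the two features that produce it, and the values add up to the difference between the full-feature score and the empty score, which is the additive property SHAP is named for.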
Neural Net SHAP values¶
explainer = shap.DeepExplainer(nn, np.array(X_train))
shap_values = explainer.shap_values(np.array(X_test))
shap.summary_plot(shap_values, np.array(X_test), feature_names=X_test.columns)
The bar plot above shows the mean SHAP values for each feature, calculated from the testing dataset and the neural net. On their own, these values do not explain how well the trained algorithm predicts stroke patients. However, given that the tested accuracy exceeded 95%, the SHAP values serve as an indication of which features contribute most to the high prediction accuracy. The SHAP values show that age, BMI, and the patient's smoking history were the three features with the largest impact on predictions for the testing dataset, with age the most influential of the three.
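As a rough illustration of what goes into such a bar plot: bar-style SHAP summaries typically rank features by the mean absolute SHAP value across samples, so that positive and negative contributions do not cancel out. The sketch below does that aggregation by hand on a tiny made-up matrix; the numbers and the helper name `mean_abs_by_feature` are hypothetical and are not taken from the models above.

```python
# Hypothetical per-sample SHAP values (rows = samples, columns = features).
feature_names = ["age", "bmi", "avg_glucose_level"]
shap_matrix = [
    [0.20, -0.05, 0.01],
    [-0.30, 0.10, -0.02],
    [0.10, -0.15, 0.03],
]


def mean_abs_by_feature(matrix, names):
    """Rank features by mean absolute SHAP value, largest first."""
    n_samples = len(matrix)
    means = {
        name: sum(abs(row[j]) for row in matrix) / n_samples
        for j, name in enumerate(names)
    }
    return sorted(means.items(), key=lambda item: item[1], reverse=True)


ranking = mean_abs_by_feature(shap_matrix, feature_names)
for name, value in ranking:
    print(f"{name}: {value:.3f}")
```

Taking absolute values matters here: a feature that pushes some predictions up and others down would look unimportant under a plain mean even though it strongly affects individual predictions.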
Random Forest SHAP values¶
explainer = shap.Explainer(rf)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test, feature_names=X_test.columns)
The bar plot above shows the mean SHAP values for each feature, calculated from the testing dataset and the random forest. As with the neural net, age was the most impactful feature, but a few other features weighed more heavily for the random forest than they did for the neural net. For instance, average glucose level and hypertension were the second and third most important features in obtaining the high prediction accuracy.
5. Conclusion¶
In the process of training and testing the two selected machine learning algorithms on the stroke patient dataset, I obtained prediction models with accuracies exceeding 95%. Furthermore, I identified key features contributing to this high accuracy, including age, BMI, average glucose level, and the presence of hypertension. Notably, a common trend in both trained algorithms was the occurrence of predominantly false negatives when errors arose. This finding raises concerns about deploying either of these algorithms into production within the healthcare field, as false negatives can have severe consequences. Careful consideration and potential improvements are necessary before implementing these models in a real-world healthcare setting.
Given the significant impact of features such as age, BMI, average glucose level, and the presence of hypertension on the accuracy of the prediction models, an avenue for improvement lies in expanding data collection for these critical features. Increasing the volume and diversity of data related to these key variables could provide the algorithms with a more comprehensive understanding of their relationships with stroke occurrences. This expanded dataset could help capture nuanced patterns and variations, leading to more robust and generalizable models. Additionally, incorporating additional relevant features or exploring interactions among existing features might further enhance the algorithms' predictive capabilities. The emphasis should be on obtaining a richer and more representative dataset that aligns closely with the critical factors influencing stroke prediction. Regular updates and refinements based on continuously growing datasets could contribute to the algorithms' adaptability and improved performance over time.