Train ML algorithms with Python scikit-learn

One of the most widely used Python libraries for machine learning is scikit-learn. It provides implementations of a broad range of algorithms and is relatively simple to use out of the box. Let’s look at an example of logistic regression with the classic ‘iris’ dataset.

# Import libraries
import pandas as pd
import seaborn as sns

import sklearn.metrics  # import the metrics submodule explicitly; it is used below
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

Load dataset

# Load sample dataset into a DataFrame
X = sns.load_dataset('iris')
X
     sepal_length  sepal_width  petal_length  petal_width    species
0             5.1          3.5           1.4          0.2     setosa
1             4.9          3.0           1.4          0.2     setosa
2             4.7          3.2           1.3          0.2     setosa
3             4.6          3.1           1.5          0.2     setosa
4             5.0          3.6           1.4          0.2     setosa
..            ...          ...           ...          ...        ...
145           6.7          3.0           5.2          2.3  virginica
146           6.3          2.5           5.0          1.9  virginica
147           6.5          3.0           5.2          2.0  virginica
148           6.2          3.4           5.4          2.3  virginica
149           5.9          3.0           5.1          1.8  virginica

150 rows × 5 columns

# Pop the target column 'species' and store in a distinct Series
y = X.pop('species')

Split into train and test sets

# Use the `train_test_split()` function from sklearn to hold out a random 20% test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
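
Because the split above is random, the test set (and therefore the scores further down) will differ from run to run. A minimal sketch of a reproducible, class-balanced split follows. It assumes the copy of iris bundled with scikit-learn (`load_iris`) rather than the seaborn download, and the `random_state=42` and `stratify=y` arguments are illustrative choices, not part of the original example.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Assumption: use the iris copy bundled with scikit-learn (no download needed).
# Here the target is encoded as integers 0, 1, 2 instead of species names.
X, y = load_iris(return_X_y=True, as_frame=True)

# random_state fixes the shuffle so the split is the same on every run;
# stratify=y keeps the class proportions equal in the train and test sets.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(X_train.shape, X_test.shape)
print(y_test.value_counts().sort_index())
```

With 150 rows and three equally sized classes, stratifying should leave the same number of samples per class in the test set, which makes per-class metrics easier to compare.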

Fit model

# Instantiate a logistic regression model with default parameters
model = LogisticRegression()

# Fit model on training set
model.fit(X_train, y_train)

# Predict on test set
y_pred = model.predict(X_test)
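
Besides hard class labels, a fitted logistic regression can also report class probabilities. The sketch below is one way to inspect them. It assumes the iris copy bundled with scikit-learn, and `max_iter=1000` and `random_state=0` are illustrative settings (the higher iteration cap is there only to reduce the chance of a convergence warning), not the defaults used above.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Assumption: bundled iris data, fixed seed for a repeatable split
X, y = load_iris(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# Assumption: raise the iteration cap above the default
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)

# One row per test sample, one column per class (ordered as in model.classes_)
proba = model.predict_proba(X_test)
y_pred = model.predict(X_test)

print(model.classes_)
print(proba[:3].round(3))
```

Each row of `proba` sums to 1, and `predict()` returns the class whose column holds the largest probability.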

Evaluate model performance

# Classification report
print(sklearn.metrics.classification_report(y_test, y_pred))
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00         6
  versicolor       0.83      1.00      0.91        10
   virginica       1.00      0.86      0.92        14

    accuracy                           0.93        30
   macro avg       0.94      0.95      0.94        30
weighted avg       0.94      0.93      0.93        30
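
# Plot confusion matrix
# (requires scikit-learn 1.0+; the older `plot_confusion_matrix` function was removed in 1.2)
sklearn.metrics.ConfusionMatrixDisplay.from_estimator(model, X_test, y_test)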
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x7fb3d0117510>

[Figure: confusion matrix plot for the test set]
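
The precision, recall, and accuracy figures in the classification report can be derived from the confusion matrix. The following sketch shows the arithmetic. It assumes the iris copy bundled with scikit-learn, and `random_state=0` and `max_iter=1000` are illustrative settings, so the numbers will not match the report printed above.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, precision_score, recall_score
from sklearn.model_selection import train_test_split

# Assumption: bundled iris data, fixed seed, raised iteration cap
X, y = load_iris(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
y_pred = model.predict(X_test)

# Rows are true classes, columns are predicted classes
cm = confusion_matrix(y_test, y_pred)

# The diagonal holds the correctly classified samples of each class
correct = cm.diagonal()
precision = correct / cm.sum(axis=0)  # correct / everything predicted as that class
recall = correct / cm.sum(axis=1)     # correct / everything truly in that class
accuracy = correct.sum() / cm.sum()   # all correct / all samples

print(cm)
print(precision.round(2), recall.round(2), round(accuracy, 2))
```

The hand-computed values should agree with `precision_score(y_test, y_pred, average=None)`, `recall_score(y_test, y_pred, average=None)`, and `model.score(X_test, y_test)`. Note that the column-wise division assumes every class is predicted at least once; otherwise the precision for that class is undefined.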