from abc import ABC
from scipy import sparse
from sconce.trainer import Trainer
from matplotlib import pyplot as plt
import seaborn as sn
import numpy as np
# Public names exported by a star-import of this module.
__all__ = [
    'ClassifierTrainer',
    'SingleClassImageClassifierMixin',
    'SingleClassImageClassifierTrainer',
]
class SingleClassImageClassifierMixin(ABC):
    """
    Mixin with evaluation and plotting helpers for models that assign a single class to each image.

    The methods here rely on ``self.test_data_generator`` and ``self._run_model_on_generator(...)``, neither of which
    is defined in this class; the class this is mixed into must provide them.
    """

    def get_confusion_matrix(self, data_generator=None, cache_results=True):
        """
        Run the model on a data generator and count (predicted class, true class) pairs.

        Arguments:
            data_generator: the data generator to evaluate on. Defaults to ``self.test_data_generator``.
            cache_results (bool): forwarded to ``self._run_model_on_generator``.

        Returns:
            A square 2D ``uint32`` array with one row and one column per model output class. Entry ``[p, t]`` is the
            number of samples with true class ``t`` for which the model predicted class ``p``.
        """
        if data_generator is None:
            data_generator = self.test_data_generator

        run_model_results = self._run_model_on_generator(data_generator,
                                                         cache_results=cache_results)
        targets = run_model_results['targets']
        outputs = np.asarray(run_model_results['outputs'])
        predicted_targets = np.argmax(outputs, axis=1)

        # Fix the shape explicitly.  Without it, the matrix size is inferred from the largest index that happens to
        # occur, which yields a too-small (possibly non-square) matrix when some class is never seen or predicted,
        # and fails outright when there are no samples at all.
        num_classes = outputs.shape[1]
        matrix = sparse.coo_matrix(
            (np.ones(len(targets)), (predicted_targets, targets)),
            shape=(num_classes, num_classes),
            dtype='uint32').toarray()  # duplicate (row, col) entries are summed into counts
        return matrix

    def get_classification_accuracy(self, data_generator=None,
                                    cache_results=True):
        """
        Compute the fraction of samples that the model classified correctly.

        Arguments:
            data_generator: the data generator to evaluate on. Defaults to ``self.test_data_generator``.
            cache_results (bool): forwarded to :py:meth:`get_confusion_matrix`.

        Returns:
            The trace of the confusion matrix (number of correct predictions) divided by
            ``data_generator.num_samples``.
        """
        if data_generator is None:
            data_generator = self.test_data_generator

        matrix = self.get_confusion_matrix(data_generator=data_generator,
                                           cache_results=cache_results)
        # Correct predictions sit on the diagonal (predicted == true).
        num_correct = np.trace(matrix)
        return num_correct / data_generator.num_samples

    def plot_confusion_matrix(self, data_generator=None, cache_results=True, **heatmap_kwargs):
        """
        Plot the confusion matrix as an annotated seaborn heatmap.

        Arguments:
            data_generator: the data generator to evaluate on. Defaults to ``self.test_data_generator``. Its
                ``real_dataset.classes`` are used as tick labels on both axes.
            cache_results (bool): forwarded to :py:meth:`get_confusion_matrix`.
            **heatmap_kwargs: extra keyword arguments for ``seaborn.heatmap``; these override the defaults set here.

        Returns:
            The matplotlib axes that the heatmap was drawn on.
        """
        if data_generator is None:
            data_generator = self.test_data_generator

        matrix = self.get_confusion_matrix(data_generator=data_generator,
                                           cache_results=cache_results)
        dataset = data_generator.real_dataset

        defaults = {'cmap': 'YlGnBu', 'annot': True, 'fmt': 'd',
                    'xticklabels': dataset.classes,
                    'yticklabels': dataset.classes}
        ax = sn.heatmap(matrix, **{**defaults, **heatmap_kwargs})

        ax.xaxis.set_ticklabels(ax.xaxis.get_ticklabels(), rotation=0)
        ax.yaxis.set_ticklabels(ax.yaxis.get_ticklabels(),
                                rotation=0, ha='right')
        # Columns index the true class and rows the predicted class (see get_confusion_matrix).
        ax.set_xlabel('True')
        ax.set_ylabel('Predicted')
        return ax

    def plot_samples(self, predicted_class,
                     true_class=None,
                     data_generator=None,
                     sort_by='rising predicted class score',
                     num_samples=7,
                     num_cols=7,
                     figure_width=15,
                     image_height=3,
                     cache_results=True):
        """
        Plot samples of the dataset where the given <predicted_class> was predicted by the model.

        Arguments:
            predicted_class (int or string): the class string or the index of the class that was predicted by the
                model.
            true_class (int or string): the class string or the index of the class that the image actually belongs
                to.  Defaults to <predicted_class>, i.e. correctly classified samples.
            data_generator (:py:class:`~sconce.data_generators.base.SingleClassImageDataGenerator`): the data
                generator to use to find the samples.  Defaults to ``self.test_data_generator``.
            sort_by (string): one of the sort_by strings, see note below.
            num_samples (int): the number of sample images to plot.
            num_cols (int): the number of columns to plot, one image per column.
            figure_width (float): the size, in matplotlib-inches, for the width of the whole figure.
            image_height (float): the size, in matplotlib-inches, for the height of a single image.
            cache_results (bool): keep the results in memory to make subsequent plots faster.  Beware, that on large
                datasets (like imagenet) this can cause your system to run out of memory.

        Returns:
            The matplotlib figure containing the plotted samples.

        Raises:
            KeyError: if <sort_by> is not one of the supported strings, or if a class string is not present in the
                dataset's ``class_to_idx`` mapping.

        Note:
            The sort_by strings supported are:
                - "rising predicted class score": samples are plotted in order of the lowest predicted class score to
                    the highest predicted class score.
                - "falling predicted class score": samples are plotted in order of the highest predicted class score
                    to the lowest predicted class score.
                - "rising true class score": samples are plotted in order of the lowest true class score to
                    the highest true class score.
                - "falling true class score": samples are plotted in order of the highest true class score to
                    the lowest true class score.
        """
        if data_generator is None:
            data_generator = self.test_data_generator
        dataset = data_generator.real_dataset

        predicted_class = self._convert_to_class_index(predicted_class, dataset)
        true_class = self._convert_to_class_index(true_class, dataset, default=predicted_class)

        sort_fns = {
            'rising predicted class score': lambda p, t: np.argsort(p),
            'falling predicted class score': lambda p, t: np.argsort(p)[::-1],
            'rising true class score': lambda p, t: np.argsort(t),
            'falling true class score': lambda p, t: np.argsort(t)[::-1],
        }
        # Resolve the sort function before running the model, so an unsupported sort_by fails immediately
        # instead of after a potentially expensive pass over the data.
        sort_fn = sort_fns[sort_by]

        run_model_results = self._run_model_on_generator(data_generator,
                                                         cache_results=cache_results)
        images = run_model_results['inputs']
        targets = run_model_results['targets']
        outputs = run_model_results['outputs']

        predicted_targets = np.argmax(outputs, axis=1)
        keep_idxs = ((targets == true_class) &
                     (predicted_targets == predicted_class))

        kept_images = np.asarray(images[keep_idxs])
        # Scores are exp(output); this presumably means the model emits log-probabilities -- confirm against
        # the model in use.
        predicted_class_scores = np.asarray(np.exp(outputs[keep_idxs, predicted_class]))
        true_class_scores = np.asarray(np.exp(outputs[keep_idxs, true_class]))

        sort_key = sort_fn(predicted_class_scores, true_class_scores)
        sorted_kept_images = kept_images[sort_key]
        sorted_predicted_class_scores = predicted_class_scores[sort_key]
        sorted_true_class_scores = true_class_scores[sort_key]

        if num_samples < len(kept_images):
            print(f'Showing only the first {num_samples} of '
                  f'{len(kept_images)} images')
        num_samples = min(num_samples, len(kept_images))

        # Ceiling division; keep at least one row so that the figure never gets a zero height when no sample
        # matches the requested classes.
        num_rows = max(1, -(-num_samples // num_cols))
        fig = plt.figure(figsize=(figure_width, image_height * num_rows))

        for i in range(num_samples):
            image = sorted_kept_images[i]
            predicted_class_score = sorted_predicted_class_scores[i]
            true_class_score = sorted_true_class_scores[i]

            if image.shape[0] == 1:
                # greyscale image: drop the single leading channel axis
                image = image[0]
                cmap = 'gray'
            else:
                # color channels present: move the leading channel axis to the end for imshow
                image = image.transpose(1, 2, 0)
                cmap = None

            ax = fig.add_subplot(num_rows, num_cols, i + 1)
            ax.imshow(image, cmap=cmap)
            if true_class != predicted_class:
                # Misclassified samples show both the predicted (p) and the true (t) class score.
                ax.set_title('p: %2.1f%%\nt: %2.1f%%' % (
                    predicted_class_score * 100, true_class_score * 100))
            else:
                ax.set_title('%2.1f%%' % (predicted_class_score * 100))
            ax.axis('off')

        plt.tight_layout()
        fig.subplots_adjust(wspace=0.05)
        return fig

    def _convert_to_class_index(self, _class, dataset, default=None):
        """
        Normalize a class given as an index or as a class string to a class index.

        Arguments:
            _class: an integer class index, a key of ``dataset.class_to_idx``, or None.
            dataset: an object with a ``class_to_idx`` mapping.
            default: the value to return when <_class> is None.

        Returns:
            <default> if <_class> is None, <_class> itself if it is an integer (python or numpy), otherwise
            ``dataset.class_to_idx[_class]``.
        """
        if _class is None:
            return default
        # numpy integer scalars (e.g. values taken from np.argmax) are not instances of ``int``, so accept
        # them explicitly rather than mistaking them for class names.
        if isinstance(_class, (int, np.integer)):
            return _class
        return dataset.class_to_idx[_class]
class SingleClassImageClassifierTrainer(Trainer, SingleClassImageClassifierMixin):
    """
    A trainer for image classifier models that assign exactly one class to each image.

    It combines :py:class:`Trainer` with the helper methods of :py:class:`SingleClassImageClassifierMixin` and adds
    nothing of its own.

    New in 0.10.0 (Used to be called ClassifierTrainer)
    """
class ClassifierTrainer(Trainer, SingleClassImageClassifierMixin):
    """
    Warning:
        This class has been deprecated in favor of :py:class:`~sconce.trainers.SingleClassImageClassifierTrainer`
        and will be removed soon.  It will continue to work for now, but please update your code accordingly.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the underlying trainer, then print a deprecation notice."""
        super().__init__(*args, **kwargs)
        notice = ('WARNING: ClassifierTrainer is deprecated as of 0.10.0, and will be removed soon. Use '
                  '"SingleClassImageClassifierTrainer" instead.')
        print(notice)