Image denoising using dictionary learning
An example comparing the effect of reconstructing noisy fragments of Lena using online Dictionary Learning and various transform methods.
The dictionary is fitted on the non-distorted left half of the image, and subsequently used to reconstruct the right half.
A common practice for evaluating the results of image denoising is to look at the difference between the reconstruction and the original image. If the reconstruction is perfect, this difference will look like Gaussian noise.
It can be seen from the plots that the result of Orthogonal Matching Pursuit (OMP) with two non-zero coefficients is a bit less biased than when keeping only one (the edges look less prominent). It is, in addition, closer to the ground truth in Frobenius norm.
The result of Least Angle Regression is much more strongly biased: the difference is reminiscent of the local intensity value of the original image.
Thresholding is clearly not useful for denoising, but it is here to show that it can produce a suggestive output with very high speed, and thus be useful for other tasks such as object classification, where performance is not necessarily related to visualisation.
Python source code: plot_image_denoising.py
# Echo the module docstring when the example is run as a script.
print(__doc__)
from time import time
import pylab as pl
import scipy as sp
import numpy as np
from sklearn.decomposition import MiniBatchDictionaryLearning
from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.feature_extraction.image import reconstruct_from_patches_2d
###############################################################################
# Load Lena image and extract patches
lena = sp.lena() / 256.0

# Downsample for higher speed: sum the four pixels of every 2x2 block, then
# divide by four to get their average.
lena = lena[::2, ::2] + lena[1::2, ::2] + lena[::2, 1::2] + lena[1::2, 1::2]
lena /= 4.0
height, width = lena.shape

# Distort the right half of the image.
# The image is indexed [row, column], i.e. (height, width): the right half is
# selected by splitting the *columns* at width // 2, and the noise block must
# have `height` rows.  `//` keeps the split index an integer.
print('Distorting image...')
distorted = lena.copy()
distorted[:, width // 2:] += 0.075 * np.random.randn(height,
                                                     width - width // 2)

# Extract all clean patches from the left half of the image
print('Extracting clean patches...')
t0 = time()
patch_size = (7, 7)
data = extract_patches_2d(distorted[:, :width // 2], patch_size)
# Flatten every patch into one row, then centre and scale each pixel position
# (column) across all patches.
data = data.reshape(data.shape[0], -1)
data -= np.mean(data, axis=0)
data /= np.std(data, axis=0)
print('done in %.2fs.' % (time() - t0))
###############################################################################
# Learn the dictionary from clean patches
print('Learning the dictionary... ')
t0 = time()
dico = MiniBatchDictionaryLearning(n_atoms=100, alpha=1e-2, n_iter=500)
V = dico.fit(data).components_
dt = time() - t0
print('done in %.2fs.' % dt)

# Show (up to) the first 100 learned atoms on a 10 x 10 grid, each reshaped
# back to the patch geometry.
pl.figure(figsize=(4.2, 4))
for cell, atom in enumerate(V[:100], 1):
    pl.subplot(10, 10, cell)
    atom_image = atom.reshape(patch_size)
    pl.imshow(atom_image, cmap=pl.cm.gray_r, interpolation='nearest')
    pl.xticks(())
    pl.yticks(())

train_summary = 'Train time %.1fs on %d patches' % (dt, len(data))
pl.suptitle('Dictionary learned from Lena patches\n' + train_summary,
            fontsize=16)
# Positional arguments: left, bottom, right, top, wspace, hspace.
pl.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)
###############################################################################
# Display the distorted image
def show_with_diff(image, reference, title):
    """Plot ``image`` next to its residual against ``reference``.

    A new figure with two panels is created.  The left panel shows
    ``image`` in gray levels clipped to [0, 1].  The right panel shows
    ``image - reference`` on a diverging colormap clipped to [-0.5, 0.5];
    its title reports the square root of the summed squared residual.
    ``title`` is used as the figure-level title.
    """
    residual = image - reference
    residual_norm = np.sqrt(np.sum(residual ** 2))

    def _hide_ticks():
        # Remove the tick marks from both axes of the current panel.
        pl.xticks(())
        pl.yticks(())

    pl.figure(figsize=(5, 3.3))

    # Left panel: the image itself.
    pl.subplot(1, 2, 1)
    pl.title('Image')
    pl.imshow(image, vmin=0, vmax=1, cmap=pl.cm.gray, interpolation='nearest')
    _hide_ticks()

    # Right panel: signed difference to the reference.
    pl.subplot(1, 2, 2)
    pl.title('Difference (norm: %.2f)' % residual_norm)
    pl.imshow(residual, vmin=-0.5, vmax=0.5, cmap=pl.cm.PuOr,
              interpolation='nearest')
    _hide_ticks()

    pl.suptitle(title, size=16)
    pl.subplots_adjust(0.02, 0.02, 0.98, 0.79, 0.02, 0.2)
# Baseline figure: the noisy image compared against the clean original.
show_with_diff(image=distorted, reference=lena, title='Distorted image')
###############################################################################
# Extract noisy patches and reconstruct them using the dictionary
print('Extracting noisy patches... ')
t0 = time()
# The image is indexed [row, column], i.e. (height, width), so the noisy right
# half is a *column* range starting at width // 2 (integer division).
right_half = slice(width // 2, None)
right_half_shape = (height, width - width // 2)
data = extract_patches_2d(distorted[:, right_half], patch_size)
# One flattened patch per row; remove the per-pixel mean and keep it so it
# can be added back after reconstruction.
data = data.reshape(data.shape[0], -1)
intercept = np.mean(data, axis=0)
data -= intercept
print('done in %.2fs.' % (time() - t0))

# (figure title, transform_algorithm, extra estimator parameters)
transform_algorithms = [
    ('Orthogonal Matching Pursuit\n1 atom', 'omp',
     {'transform_n_nonzero_coefs': 1}),
    ('Orthogonal Matching Pursuit\n2 atoms', 'omp',
     {'transform_n_nonzero_coefs': 2}),
    ('Least-angle regression\n5 atoms', 'lars',
     {'transform_n_nonzero_coefs': 5}),
    ('Thresholding\n alpha=0.1', 'threshold', {'transform_alpha': .1})]

reconstructions = {}
for title, transform_algorithm, kwargs in transform_algorithms:
    print('%s ... ' % title)
    # Start from the clean image; only the right half is overwritten below.
    reconstructions[title] = lena.copy()
    t0 = time()
    dico.set_params(transform_algorithm=transform_algorithm, **kwargs)
    code = dico.transform(data)
    # Rebuild each patch as a linear combination of dictionary atoms.
    patches = np.dot(code, V)

    # NOTE(review): for 'threshold' the patches are min/max-rescaled both
    # before and after the intercept is added back; kept unchanged here --
    # confirm that the double rescaling is intended.
    if transform_algorithm == 'threshold':
        patches -= patches.min()
        patches /= patches.max()

    patches += intercept
    patches = patches.reshape(len(data), *patch_size)
    if transform_algorithm == 'threshold':
        patches -= patches.min()
        patches /= patches.max()

    # The target image size is (rows, columns) of the right half.
    reconstructions[title][:, right_half] = reconstruct_from_patches_2d(
        patches, right_half_shape)
    dt = time() - t0
    print('done in %.2fs.' % dt)
    show_with_diff(reconstructions[title], lena,
                   title + ' (time: %.1fs)' % dt)

pl.show()