You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1104 lines
33 KiB

# -*- coding: utf-8 -*-
# Copyright (C) 2015-2017 by Brendt Wohlberg <brendt@ieee.org>
# All rights reserved. BSD 3-clause License.
# This file is part of the SPORCO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.
"""Utility functions"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from future.utils import PY2
from builtins import range
from builtins import object
from timeit import default_timer as timer
import os
import imghdr
import io
import platform
import multiprocessing as mp
import itertools
import collections
import socket
if PY2:
import urllib2 as urlrequest
import urllib2 as urlerror
else:
import urllib.request as urlrequest
import urllib.error as urlerror
import numpy as np
from scipy import misc
import scipy.ndimage.interpolation as sni
import sporco.linalg as sla
import sporco.plot as spl
__author__ = """Brendt Wohlberg <brendt@ieee.org>"""
import warnings
def plot(*args, **kwargs):
    """Deprecated alias for :func:`sporco.plot.plot`.

    Issues a :class:`PendingDeprecationWarning` and then forwards all
    positional and keyword arguments unchanged to ``sporco.plot.plot``,
    returning whatever that function returns.
    """

    # stacklevel=2 attributes the warning to the line that called this
    # wrapper, so that users can locate the deprecated usage in their code
    warnings.warn("sporco.util.plot is deprecated: use sporco.plot.plot",
                  PendingDeprecationWarning, stacklevel=2)
    return spl.plot(*args, **kwargs)
def surf(*args, **kwargs):
    """Deprecated alias for :func:`sporco.plot.surf`.

    Issues a :class:`PendingDeprecationWarning` and then forwards all
    positional and keyword arguments unchanged to ``sporco.plot.surf``,
    returning whatever that function returns.
    """

    # stacklevel=2 attributes the warning to the line that called this
    # wrapper, so that users can locate the deprecated usage in their code
    warnings.warn("sporco.util.surf is deprecated: use sporco.plot.surf",
                  PendingDeprecationWarning, stacklevel=2)
    return spl.surf(*args, **kwargs)
def imview(*args, **kwargs):
    """Deprecated alias for :func:`sporco.plot.imview`.

    Issues a :class:`PendingDeprecationWarning` and then forwards all
    positional and keyword arguments unchanged to ``sporco.plot.imview``,
    returning whatever that function returns.
    """

    # stacklevel=2 attributes the warning to the line that called this
    # wrapper, so that users can locate the deprecated usage in their code
    warnings.warn("sporco.util.imview is deprecated: use sporco.plot.imview",
                  PendingDeprecationWarning, stacklevel=2)
    return spl.imview(*args, **kwargs)
# Python 2/3 unicode literal compatibility
if PY2:
    import codecs


def u(x):
    """Python 2/3 compatible definition of utf8 literals.

    Under Python 2 the argument is decoded from UTF-8; under Python 3
    it is returned unchanged.
    """

    if PY2:
        return x.decode('utf8')
    return x
def ntpl2array(ntpl):
    """
    Convert a :func:`collections.namedtuple` object to a :class:`numpy.ndarray`
    object that can be saved using :func:`numpy.savez`.

    Parameters
    ----------
    ntpl : collections.namedtuple object
      Named tuple object to be converted to ndarray

    Returns
    -------
    arr : ndarray
      Array representation of input named tuple: a 3-element object array
      holding the vertically stacked field values, the tuple of field
      names, and the name of the named tuple class
    """

    # The three components have different shapes, so the container must be
    # an object array. It is filled element by element because implicit
    # construction of ragged arrays via np.asarray is an error in recent
    # NumPy versions.
    arr = np.empty(3, dtype=object)
    arr[0] = np.vstack([col for col in ntpl])
    arr[1] = ntpl._fields
    arr[2] = ntpl.__class__.__name__
    return arr
def array2ntpl(arr):
    """
    Rebuild the :func:`collections.namedtuple` object that was converted
    to a :class:`numpy.ndarray` by :func:`ntpl2array`.

    Parameters
    ----------
    arr : ndarray
      Array representation of named tuple constructed by :func:`ntpl2array`

    Returns
    -------
    ntpl : collections.namedtuple object
      Named tuple object with the same name and fields as the original named
      tuple object provided to :func:`ntpl2array`
    """

    # Entry 2 is the class name, entry 1 the field names, entry 0 the values
    typename = arr[2]
    fieldnames = arr[1]
    ntpl_type = collections.namedtuple(typename, fieldnames)
    values = tuple(arr[0])
    return ntpl_type(*values)
def transpose_ntpl_list(lst):
    """Transpose a list of named tuple objects (of the same type) into a
    named tuple of lists.

    Parameters
    ----------
    lst : list of collections.namedtuple object
      List of named tuple objects of the same type

    Returns
    -------
    ntpl : collections.namedtuple object or None
      Named tuple object with each entry consisting of a list of the
      corresponding fields of the named tuple objects in list ``lst``,
      or None if ``lst`` is empty
    """

    # The empty case must be handled before lst[0] is inspected, since
    # there is no first entry from which to determine the tuple type
    if len(lst) == 0:
        return None
    cls = collections.namedtuple(lst[0].__class__.__name__, lst[0]._fields)
    # zip(*lst) groups the k-th field of every entry together
    return cls(*[list(fld) for fld in zip(*lst)])
def solve_status_str(hdrtxt, fwiter=4, fpothr=2):
    """Build the header string, row format string, and separator length
    used for the status display of an iterative solver.

    Parameters
    ----------
    hdrtxt : tuple of strings
      Tuple of field header strings
    fwiter : int, optional (default 4)
      Number of characters in iteration count integer field
    fpothr : int, optional (default 2)
      Precision of other float field

    Returns
    -------
    hdrstr : string
      Complete header string
    fmtstr : string
      Complete print formatting string for numeric values
    nsep : integer
      Number of characters in separator string
    """

    # Number of columns following the iteration count column
    nothr = len(hdrtxt) - 1
    # Width of those columns is determined by the requested precision
    fwothr = fpothr + 6

    # Header: left-justified iteration column followed by the
    # left-justified remaining columns
    colfmt = "%%-%ds " % (fwothr+1)
    hdrstr = ("%-*s" % (fwiter+2, hdrtxt[0])) + \
             "".join([colfmt % (h,) for h in hdrtxt[1:]])

    # Row format: integer iteration count followed by one exponent-format
    # float field per remaining column
    fltfmt = " %%%d.%de" % (fwothr, fpothr)
    fmtstr = ("%%%dd" % (fwiter)) + fltfmt * nothr

    # Length of separator string
    nsep = fwiter + (fwothr + 2)*nothr

    return hdrstr, fmtstr, nsep
def tiledict(D, sz=None):
    """Construct an image allowing visualization of dictionary content.

    Parameters
    ----------
    D : array_like
      Dictionary matrix/array.
    sz : tuple
      Size of each block in dictionary.

    Returns
    -------
    im : ndarray
      Image tiled with dictionary entries.
    """

    # A standard 2D (non-convolutional) dictionary is reshaped so that
    # each column becomes a block of size sz
    if D.ndim == 2:
        D = D.reshape((sz + (D.shape[1],)))
        sz = None
    dsz = D.shape
    multichannel = (D.ndim == 4)

    # Index of the atom axis of D, and of the atom count entry in each
    # element of sz
    if multichannel:
        axisM = 3
        szni = 3
    else:
        axisM = 2
        szni = 2

    # Construct dictionary atom size vector if not provided
    if sz is None:
        sz = np.tile(np.array(dsz[0:2]).reshape([2, 1]), (1, D.shape[axisM]))
    else:
        sz = np.array(sum(tuple((x[0:2],) * x[szni] for x in sz), ())).T
    # Compute the maximum atom dimensions
    mxsz = np.amax(sz, 1)

    # Shift and scale values to [0, 1]
    D = D - D.min()
    D = D / D.max()

    # Grid of Vr x Vc tiles, separated by single-pixel borders
    N = dsz[axisM]
    Vr = int(np.floor(np.sqrt(N)))
    Vc = int(np.ceil(N/float(Vr)))
    imshp = (Vr*mxsz[0] + Vr-1, Vc*mxsz[1] + Vc-1)
    if multichannel:
        imshp = imshp + (dsz[2],)
    im = np.ones(imshp)

    # Atom k is placed at grid row k // Vc and grid column k % Vc
    for k in range(N):
        l, m = divmod(k, Vc)
        r = mxsz[0]*l + l
        c = mxsz[1]*m + m
        h = sz[0, k]
        w = sz[1, k]
        if multichannel:
            im[r:(r+h), c:(c+w), :] = D[0:h, 0:w, :, k]
        else:
            im[r:(r+h), c:(c+w)] = D[0:h, 0:w, k]

    return im
def imageblocks(imgs, blksz):
    """Extract all blocks of specified size from an image or list of images.

    Parameters
    ----------
    imgs: array_like or tuple of array_like
      Single image or tuple of images from which to extract blocks
    blksz : tuple of two ints
      Size of the blocks

    Returns
    -------
    blks : ndarray
      Array of extracted blocks, stacked along the final axis
    """

    # See http://stackoverflow.com/questions/16774148 and
    # sklearn.feature_extraction.image.extract_patches_2d
    if not isinstance(imgs, tuple):
        imgs = (imgs,)

    # Start from an empty stack so that the result is well defined (and
    # of float dtype) even when no blocks are appended
    parts = [np.array([]).reshape(blksz + (0,))]
    for im in imgs:
        Nr, Nc = im.shape
        nr, nc = blksz
        # Strided view containing every (nr, nc) window of the image
        windows = np.lib.stride_tricks.as_strided(
            np.ascontiguousarray(im),
            shape=(Nr-nr+1, Nc-nc+1, nr, nc),
            strides=im.itemsize*np.array([Nc, 1, Nc, 1]))
        # Copy out of the view, flatten the window grid, and move the
        # block index to the final axis
        flat = np.ascontiguousarray(windows).reshape((-1, nr, nc))
        parts.append(flat.transpose(1, 2, 0))

    return np.dstack(parts)
def rgb2gray(rgb):
    """Convert an RGB image (or images) to grayscale.

    The grayscale value is the weighted sum 0.299 R + 0.587 G + 0.114 B
    of the colour channels on axis 2.

    Parameters
    ----------
    rgb : ndarray
      RGB image as Nr x Nc x 3 or Nr x Nc x 3 x K array

    Returns
    -------
    gry : ndarray
      Grayscale image as Nr x Nc or Nr x Nc x K array. The dtype of a
      floating point input is preserved; other inputs give a
      floating point result.
    """

    # The weights must sum to one so that white maps to white. They are
    # only cast to the input dtype when it is inexact, since casting them
    # to an integer dtype would truncate all of them to zero.
    if np.issubdtype(rgb.dtype, np.inexact):
        wdtype = rgb.dtype
    else:
        wdtype = np.float32
    w = np.array([0.299, 0.587, 0.114], dtype=wdtype)
    # Place the weights on axis 2, with singleton axes elsewhere so that
    # they broadcast against both Nr x Nc x 3 and Nr x Nc x 3 x K inputs
    w = w.reshape((1, 1, 3) + (1,)*(rgb.ndim - 3))
    return np.sum(w * rgb, axis=2)
def complex_randn(*args):
    """Return a complex array whose real and imaginary parts are
    independent samples drawn from a standard normal distribution.

    Parameters
    ----------
    d0, d1, ..., dn: int
      Dimensions of the random array

    Returns
    -------
    a : ndarray
      Random array of shape (d0, d1, ..., dn)
    """

    # The real part is drawn first, then the imaginary part
    re = np.random.randn(*args)
    im = np.random.randn(*args)
    return re + 1j*im
def spnoise(s, frc, smn=0.0, smx=1.0):
    """Return a copy of an image with salt & pepper noise imposed on it.

    Parameters
    ----------
    s : ndarray
      Input image
    frc : float
      Desired fraction of pixels corrupted by noise
    smn : float, optional (default 0.0)
      Lower value for noise (pepper)
    smx : float, optional (default 1.0)
      Upper value for noise (salt)

    Returns
    -------
    sn : ndarray
      Noisy image
    """

    # One uniform sample on [-1, 1) per pixel: the lowest frc/2 of the
    # range selects pepper pixels and the highest frc/2 selects salt pixels
    draw = np.random.uniform(-1.0, 1.0, s.shape)
    pepper = draw < frc - 1.0
    salt = draw > 1.0 - frc

    noisy = s.copy()
    noisy[pepper] = smn
    noisy[salt] = smx
    return noisy
def tikhonov_filter(s, lmbda, npd=16):
    r"""Lowpass filter based on Tikhonov regularization.

    Lowpass filter image(s) and return low and high frequency
    components, consisting of the lowpass filtered image and its
    difference with the input image. The lowpass filter is equivalent to
    Tikhonov regularization with `lmbda` as the regularization parameter
    and a discrete gradient as the operator in the regularization term,
    i.e. the lowpass component is the solution to

    .. math::
      \mathrm{argmin}_\mathbf{x} \; (1/2) \left\|\mathbf{x} - \mathbf{s}
      \right\|_2^2 + (\lambda / 2) \sum_i \| G_i \mathbf{x} \|_2^2 \;\;,

    where :math:`\mathbf{s}` is the input image, :math:`\lambda` is the
    regularization parameter, and :math:`G_i` is an operator that
    computes the discrete gradient along image axis :math:`i`. Once the
    lowpass component :math:`\mathbf{x}` has been computed, the highpass
    component is just :math:`\mathbf{s} - \mathbf{x}`.

    Parameters
    ----------
    s : array_like
      Input image or array of images.
    lmbda : float
      Regularization parameter controlling lowpass filtering.
    npd : int, optional (default=16)
      Number of samples to pad at image boundaries.

    Returns
    -------
    sl : array_like
      Lowpass image or array of images.
    sh : array_like
      Highpass image or array of images.
    """

    # Shape of the two leading (image) axes after padding
    pshp = (s.shape[0]+2*npd, s.shape[1]+2*npd)
    # Number of axes following the two image axes
    nxtr = s.ndim - 2

    # Frequency domain representations of the row and column gradient
    # filters
    grv = np.array([-1.0, 1.0]).reshape([2, 1])
    gcv = np.array([-1.0, 1.0]).reshape([1, 2])
    Gr = sla.fftn(grv, pshp, (0, 1))
    Gc = sla.fftn(gcv, pshp, (0, 1))

    # Frequency domain form of the operator to be inverted, with
    # singleton axes appended so that it broadcasts over any extra axes
    A = 1.0 + lmbda*np.conj(Gr)*Gr + lmbda*np.conj(Gc)*Gc
    if s.ndim > 2:
        A = A[(slice(None),)*2 + (np.newaxis,)*nxtr]

    # Pad the image axes, solve in the frequency domain, and crop the
    # padding from the result
    sp = np.pad(s, ((npd, npd),)*2 + ((0, 0),)*nxtr, 'symmetric')
    spf = sla.fftn(sp, axes=(0, 1))
    slp = np.real(sla.ifftn(spf / A, axes=(0, 1)))
    sl = slp[npd:(slp.shape[0]-npd), npd:(slp.shape[1]-npd)]
    sh = s - sl

    return sl.astype(s.dtype), sh.astype(s.dtype)
def idle_cpu_count(mincpu=1):
    """Estimate number of idle CPUs, for use by multiprocessing code
    needing to determine how many processes can be run without excessive
    load. The estimate is the CPU count minus the integer part of the
    one-minute load average reported by :func:`os.getloadavg`. That
    function is only available under a Unix OS; where it is unavailable
    or fails, the load is assumed to be zero.

    Parameters
    ----------
    mincpu : int
      Minimum number of CPUs to report, independent of actual estimate

    Returns
    -------
    idle : int
      Estimate of number of idle CPUs
    """

    # mp.cpu_count is available under both Python 2 and 3, and signals an
    # undeterminable count by raising rather than by returning None
    try:
        ncpu = mp.cpu_count()
    except NotImplementedError:
        return mincpu
    # os.getloadavg does not exist on some platforms (AttributeError), and
    # raises OSError if the load average is unobtainable
    try:
        load = os.getloadavg()[0]
    except (AttributeError, OSError):
        load = 0.0
    idle = int(ncpu - np.floor(load))
    return max(mincpu, idle)
def grid_search(fn, grd, fmin=True, nproc=None):
    """Perform a grid search for optimal parameters of a specified
    function. In the simplest case the function returns a float value,
    and a single optimum value and corresponding parameter values are
    identified. If the function returns a tuple of values, each of
    these is taken to define a separate function on the search grid,
    with optimum function values and corresponding parameter values
    being identified for each of them. On all platforms except Windows
    (where ``mp.Pool`` usage has some limitations), the computation
    of the function at the grid points is computed in parallel.

    **Warning:** This function will hang if `fn` makes use of :mod:`pyfftw`
    with multi-threading enabled (the
    `bug <https://github.com/pyFFTW/pyFFTW/issues/135>`_ has been reported).
    When using the FFT functions in :mod:`sporco.linalg`, multi-threading
    can be disabled by including the following code::

      import sporco.linalg
      sporco.linalg.pyfftw_threads = 1

    Parameters
    ----------
    fn : function
      Function to be evaluated. It should take a tuple of parameter values as
      an argument, and return a float value or a tuple of float values.
    grd : tuple of array_like
      A tuple providing an array of sample points for each axis of the grid
      on which the search is to be performed.
    fmin : bool, optional (default True)
      Determine whether optimal function values are selected as minima or
      maxima. If `fmin` is True then minima are selected.
    nproc : int or None, optional (default None)
      Number of processes to run in parallel. If None, the number of
      CPUs of the system is used.

    Returns
    -------
    sprm : ndarray
      Optimal parameter values on each axis. If `fn` is multi-valued,
      `sprm` is a matrix with rows corresponding to parameter values
      and columns corresponding to function values.
    sfvl : float or ndarray
      Optimum function value or values
    fvmx : ndarray
      Function value(s) on search grid
    sidx : tuple of int or tuple of ndarray
      Indices of optimal values on parameter grid
    """

    slct = np.argmin if fmin else np.argmax

    # Array versions of the grid axes, used only for determining the grid
    # shape and for indexing, so that any array_like axis is supported
    gax = [np.asarray(a) for a in grd]
    gshp = [a.size for a in gax]

    fprm = itertools.product(*grd)
    if platform.system() == 'Windows':
        fval = list(map(fn, fprm))
    else:
        if nproc is None:
            nproc = mp.cpu_count()
        pool = mp.Pool(processes=nproc)
        # Ensure that the worker processes are released even if fn raises
        try:
            fval = pool.map(fn, fprm)
        finally:
            pool.close()
            pool.join()

    if isinstance(fval[0], (tuple, list, np.ndarray)):
        # Multi-valued fn: the final axis of fvmx indexes the function
        # values, and an optimum is located separately for each of them
        nfnv = len(fval[0])
        fvmx = np.reshape(fval, gshp + [nfnv,])
        sidx = np.unravel_index(slct(fvmx.reshape((-1, nfnv)), axis=0),
                                fvmx.shape[0:-1]) + (np.arange(nfnv),)
        sprm = np.array([gax[k][sidx[k]] for k in range(len(gax))])
        sfvl = tuple(fvmx[sidx])
    else:
        fvmx = np.reshape(fval, gshp)
        sidx = np.unravel_index(slct(fvmx), fvmx.shape)
        sprm = np.array([gax[k][sidx[k]] for k in range(len(gax))])
        sfvl = fvmx[sidx]

    return sprm, sfvl, fvmx, sidx
def convdicts():
    """Access a set of example learned convolutional dictionaries.

    The dictionaries are read from the file ``data/convdict.npz`` located
    in the directory containing this module.

    Returns
    -------
    cdd : dict
      A dict associating description strings with dictionaries represented
      as ndarrays

    Examples
    --------
    Print the dict keys to obtain the identifiers of the available
    dictionaries

    >>> from sporco import util
    >>> cd = util.convdicts()
    >>> print(cd.keys())
    ['G:12x12x72', 'G:8x8x16,12x12x32,16x16x48', ...]

    Select a specific example dictionary using the corresponding identifier

    >>> D = cd['G:8x8x96']
    """

    pth = os.path.join(os.path.dirname(__file__), 'data', 'convdict.npz')
    npz = np.load(pth)
    # Every array is read into an ordinary dict before the archive is
    # closed, so that the underlying file handle is not left open
    try:
        cdd = {k: npz[k] for k in npz.keys()}
    finally:
        npz.close()
    return cdd
def netgetdata(url, maxtry=3, timeout=10):
    """
    Get content of a file via a URL.

    The download is attempted up to `maxtry` times. Only failures due to
    a timeout are retried; any other download error is raised immediately.

    Parameters
    ----------
    url : string
      URL of the file to be downloaded
    maxtry : int, optional (default 3)
      Maximum number of download retries
    timeout : int, optional (default 10)
      Timeout in seconds for blocking operations

    Returns
    -------
    str : io.BytesIO
      Buffered I/O stream

    Raises
    ------
    urlerror.URLError (urllib2.URLError in Python 2,
    urllib.error.URLError in Python 3)
      If the file cannot be downloaded
    socket.timeout
      If every attempt timed out while reading the response
    ValueError
      If `maxtry` is not greater than zero
    """

    # Exception raised if the loop body is never executed
    err = ValueError('maxtry parameter should be greater than zero')
    for ntry in range(maxtry):
        try:
            rspns = urlrequest.urlopen(url, timeout=timeout)
            # Release the connection whether or not the read succeeds
            try:
                cntnt = rspns.read()
            finally:
                rspns.close()
            break
        except urlerror.URLError as e:
            err = e
            # Only timeouts are worth retrying
            if not isinstance(e.reason, socket.timeout):
                raise
        except socket.timeout as e:
            # A timeout while reading the response is raised directly
            # rather than wrapped in a URLError
            err = e
    else:
        # All attempts failed: raise the most recent error
        raise err

    return io.BytesIO(cntnt)
class ExampleImages(object):
    """Access a set of example images.

    The image files are discovered by walking a base directory. Each
    image is identified by its path relative to that directory, and images
    are grouped by the relative path of the directory containing them
    (the empty string for images in the base directory itself).
    """

    def __init__(self, scaled=False, dtype=None, zoom=None, gray=False,
                 pth=None):
        """Initialise an ExampleImages object.

        Parameters
        ----------
        scaled : bool, optional (default False)
          Flag indicating whether images should be on the range [0,...,255]
          with np.uint8 dtype (False), or on the range [0,...,1] with
          np.float32 dtype (True)
        dtype : data-type or None, optional (default None)
          Desired data type of images. If `scaled` is True and `dtype` is an
          integer type, the output data type is np.float32
        zoom : float or None, optional (default None)
          Optional support rescaling factor to apply to the images
        gray : bool, optional (default False)
          Flag indicating whether RGB images should be converted to grayscale
        pth : string or None (default None)
          Path to directory containing image files. If the value is None the
          path points to a set of example images that are included with the
          package.
        """

        self.scaled = scaled
        self.dtype = dtype
        self.zoom = zoom
        self.gray = gray
        if pth is None:
            self.bpth = os.path.join(os.path.dirname(__file__), 'data')
        else:
            self.bpth = pth
        # List of all images, and dict mapping group name to list of images
        self.imglst = []
        self.grpimg = {}
        for dirpath, dirnames, filenames in os.walk(self.bpth):
            # Group name is the directory path relative to the base path.
            # os.path.relpath is used rather than slicing off a prefix of
            # fixed length so that the result does not depend on whether
            # the base path has a trailing separator.
            prnpth = os.path.relpath(dirpath, self.bpth)
            if prnpth == os.curdir:
                prnpth = ''
            for f in filenames:
                fpth = os.path.join(dirpath, f)
                # Only files recognised as images are recorded
                if imghdr.what(fpth) is not None:
                    gpth = os.path.join(prnpth, f)
                    self.imglst.append(gpth)
                    self.grpimg.setdefault(prnpth, []).append(gpth)

    def images(self):
        """Get list of available images.

        Returns
        -------
        nlst : list
          A list of names of available images
        """

        return self.imglst

    def groups(self):
        """Get list of available image groups.

        Returns
        -------
        grp : list
          A list of names of available image groups
        """

        return list(self.grpimg.keys())

    def groupimages(self, grp):
        """Get list of available images in specified group.

        Parameters
        ----------
        grp : str
          Name of image group

        Returns
        -------
        nlst : list
          A list of names of available images in the specified group
        """

        return self.grpimg[grp]

    def image(self, fname, group=None, scaled=None, dtype=None, idxexp=None,
              zoom=None, gray=None):
        """Get named image.

        Parameters
        ----------
        fname : string
          Filename of image
        group : string or None, optional (default None)
          Name of image group
        scaled : bool or None, optional (default None)
          Flag indicating whether images should be on the range [0,...,255]
          with np.uint8 dtype (False), or on the range [0,...,1] with
          np.float32 dtype (True). If the value is None, scaling behaviour
          is determined by the `scaled` parameter passed to the object
          initializer, otherwise that selection is overridden.
        dtype : data-type or None, optional (default None)
          Desired data type of images. If `scaled` is True and `dtype` is an
          integer type, the output data type is np.float32. If the value is
          None, the data type is determined by the `dtype` parameter passed to
          the object initializer, otherwise that selection is overridden.
        idxexp : index expression or None, optional (default None)
          An index expression selecting, for example, a cropped region of
          the requested image. This selection is applied *before* any
          `zoom` rescaling so the expression does not need to be modified when
          the zoom factor is changed.
        zoom : float or None, optional (default None)
          Optional rescaling factor to apply to the images. If the value is
          None, support rescaling behaviour is determined by the `zoom`
          parameter passed to the object initializer, otherwise that selection
          is overridden.
        gray : bool or None, optional (default None)
          Flag indicating whether RGB images should be converted to grayscale.
          If the value is None, behaviour is determined by the `gray`
          parameter passed to the object initializer. Images that are
          already two-dimensional are returned without conversion.

        Returns
        -------
        img : ndarray
          Image array

        Raises
        ------
        IOError
          If the image is not accessible
        """

        # Options that are not specified fall back on the values
        # provided to the initializer
        if scaled is None:
            scaled = self.scaled
        if dtype is None:
            if self.dtype is None:
                dtype = np.uint8
            else:
                dtype = self.dtype
        # A scaled image cannot be represented in an integer type
        if scaled and np.issubdtype(dtype, np.integer):
            dtype = np.float32
        if zoom is None:
            zoom = self.zoom
        if gray is None:
            gray = self.gray
        if group is None:
            pth = os.path.join(self.bpth, fname)
        else:
            pth = os.path.join(self.bpth, group, fname)

        try:
            img = np.asarray(misc.imread(pth), dtype=dtype)
        except IOError:
            raise IOError('Could not access image %s in group %s' %
                          (fname, group))

        if scaled:
            img /= 255.0
        if idxexp is not None:
            img = img[idxexp]
        if zoom is not None:
            if img.ndim == 2:
                img = sni.zoom(img, zoom)
            else:
                # Only the two leading (spatial) axes are rescaled
                img = sni.zoom(img, (zoom,)*2 + (1,)*(img.ndim-2))
        # An image without a channel axis is already grayscale, and
        # rgb2gray requires the channels on axis 2
        if gray and img.ndim > 2:
            img = rgb2gray(img)

        return img
class Timer(object):
    """Timer class supporting multiple independent labelled timers.

    The timer is based on the relative time returned by
    :func:`timeit.default_timer`.
    """

    def __init__(self, labels=None, dfltlbl='main', alllbl='all'):
        """Initialise timer object.

        Parameters
        ----------
        labels : string or list, optional (default None)
          Specify the label(s) of the timer(s) to be initialised to zero.
        dfltlbl : string, optional (default 'main')
          Set the default timer label to be used when methods are
          called without specifying a label
        alllbl : string, optional (default 'all')
          Set the label string that will be used to denote all timer labels
        """

        # Initialise current and accumulated time dictionaries: t0 maps
        # each label to the start time of a running timer (None if it is
        # stopped), and td maps each label to the accumulated time
        self.t0 = {}
        self.td = {}
        # Record default label and string indicating all labels
        self.dfltlbl = dfltlbl
        self.alllbl = alllbl
        # Initialise dictionary entries for labels to be created
        # immediately
        if labels is not None:
            if not isinstance(labels, (list, tuple)):
                labels = [labels,]
            for lbl in labels:
                self.td[lbl] = 0.0
                self.t0[lbl] = None

    def start(self, labels=None):
        """Start specified timer(s).

        Parameters
        ----------
        labels : string or list, optional (default None)
          Specify the label(s) of the timer(s) to be started. If it is
          ``None``, start the default timer with label specified by the
          ``dfltlbl`` parameter of :meth:`__init__`.
        """

        # Default label is self.dfltlbl
        if labels is None:
            labels = self.dfltlbl
        # If label is not a list or tuple, create a singleton list
        # containing it
        if not isinstance(labels, (list, tuple)):
            labels = [labels,]
        # Iterate over specified label(s)
        t = timer()
        for lbl in labels:
            # On first call to start for a label, set its accumulator to zero
            if lbl not in self.td:
                self.td[lbl] = 0.0
                self.t0[lbl] = None
            # Record the time at which start was called for this lbl if
            # it isn't already running
            if self.t0[lbl] is None:
                self.t0[lbl] = t

    def stop(self, labels=None):
        """Stop specified timer(s).

        Parameters
        ----------
        labels : string or list, optional (default None)
          Specify the label(s) of the timer(s) to be stopped. If it is
          ``None``, stop the default timer with label specified by the
          ``dfltlbl`` parameter of :meth:`__init__`. If it is equal to
          the string specified by the ``alllbl`` parameter of
          :meth:`__init__`, stop all timers.

        Raises
        ------
        KeyError
          If a specified label does not correspond to an existing timer
        """

        # Get current time
        t = timer()
        # Default label is self.dfltlbl
        if labels is None:
            labels = self.dfltlbl
        # All timers are affected if label is equal to self.alllbl,
        # otherwise only the timer(s) specified by label
        if labels == self.alllbl:
            labels = list(self.t0)
        elif not isinstance(labels, (list, tuple)):
            labels = [labels,]
        # Iterate over specified label(s)
        for lbl in labels:
            if lbl not in self.t0:
                raise KeyError('Unrecognized timer key %s' % lbl)
            # If self.t0[lbl] is None, the corresponding timer is
            # already stopped, so no action is required
            if self.t0[lbl] is not None:
                # Increment time accumulator from the elapsed time
                # since most recent start call
                self.td[lbl] += t - self.t0[lbl]
                # Set start time to None to indicate timer is not running
                self.t0[lbl] = None

    def reset(self, labels=None):
        """Reset specified timer(s).

        A timer that is reset is stopped and its accumulated time is set
        to zero.

        Parameters
        ----------
        labels : string or list, optional (default None)
          Specify the label(s) of the timer(s) to be reset. If it is
          ``None``, reset the default timer with label specified by the
          ``dfltlbl`` parameter of :meth:`__init__`. If it is equal to
          the string specified by the ``alllbl`` parameter of
          :meth:`__init__`, reset all timers.

        Raises
        ------
        KeyError
          If a specified label does not correspond to an existing timer
        """

        # Default label is self.dfltlbl
        if labels is None:
            labels = self.dfltlbl
        # All timers are affected if label is equal to self.alllbl,
        # otherwise only the timer(s) specified by label
        if labels == self.alllbl:
            labels = list(self.t0)
        elif not isinstance(labels, (list, tuple)):
            labels = [labels,]
        # Iterate over specified label(s)
        for lbl in labels:
            if lbl not in self.t0:
                raise KeyError('Unrecognized timer key %s' % lbl)
            # Set start time to None to indicate timer is not running
            self.t0[lbl] = None
            # Set time accumulator to zero
            self.td[lbl] = 0.0

    def elapsed(self, label=None, total=True):
        """Get elapsed time since timer start.

        Parameters
        ----------
        label : string, optional (default None)
          Specify the label of the timer for which the elapsed time is
          required. If it is ``None``, the default timer with label
          specified by the ``dfltlbl`` parameter of :meth:`__init__`
          is selected.
        total : bool, optional (default True)
          If ``True`` return the total elapsed time since the first
          call of :meth:`start` for the selected timer, otherwise
          return the elapsed time since the most recent call of
          :meth:`start` for which there has not been a corresponding
          call to :meth:`stop`.

        Returns
        -------
        dlt : float
          Elapsed time

        Raises
        ------
        KeyError
          If `label` is not the default label and does not correspond to
          an existing timer
        """

        # Get current time
        t = timer()
        # Default label is self.dfltlbl
        if label is None:
            label = self.dfltlbl
        # Return 0.0 if default timer selected and it is not initialised
        if label == self.dfltlbl and label not in self.t0:
            return 0.0
        # Raise exception if timer with specified label does not exist
        if label not in self.t0:
            raise KeyError('Unrecognized timer key %s' % label)
        # If total flag is True return sum of accumulated time from
        # previous start/stop calls and current start call, otherwise
        # return just the time since the current start call
        te = 0.0
        if self.t0[label] is not None:
            te = t - self.t0[label]
        if total:
            te += self.td[label]

        return te

    def labels(self):
        """Get a list of timer labels.

        Returns
        -------
        lbl : list
          List of timer labels
        """

        # Return a list, as documented, rather than a view of the
        # internal dictionary
        return list(self.t0)

    def __str__(self):
        """Return string representation of object.

        The representation consists of a table with the following columns:

          * Timer label
          * Accumulated time from past start/stop calls
          * Time since current start call, or 'Stopped' if timer is not
            currently running
        """

        # Get current time
        t = timer()
        # Length of label field, calculated from max label length
        lfldln = max([len(lbl) for lbl in self.t0] + [len(self.dfltlbl),]) + 2
        # Header string for table of timers
        s = '%-*s Accum. Current\n' % (lfldln, 'Label')
        s += '-' * (lfldln + 25) + '\n'
        # Construct table of timer details
        for lbl in sorted(self.t0):
            td = self.td[lbl]
            if self.t0[lbl] is None:
                ts = ' Stopped'
            else:
                ts = ' %.2e s' % (t - self.t0[lbl])
            s += '%-*s %.2e s %s\n' % (lfldln, lbl, td, ts)

        return s
class ContextTimer(object):
    """A wrapper class for :class:`Timer` that enables its use as a
    context manager.

    For example, instead of

    >>> t = Timer()
    >>> t.start()
    >>> do_something()
    >>> t.stop()
    >>> elapsed = t.elapsed()

    one can use

    >>> t = Timer()
    >>> with ContextTimer(t):
    ...   do_something()
    >>> elapsed = t.elapsed()
    """

    def __init__(self, timer=None, label=None, action='StartStop'):
        """Initialise context manager timer wrapper.

        Parameters
        ----------
        timer : :class:`Timer` object, optional (default None)
          Specify the timer object to be used as a context manager. If
          ``None``, a new :class:`Timer` object is constructed.
        label : string, optional (default None)
          Specify the label of the timer to be used. If it is ``None``,
          start the default timer.
        action : string, optional (default 'StartStop')
          Specify actions to be taken on context entry and exit. If
          the value is 'StartStop', start the timer on entry and stop
          on exit; if it is 'StopStart', stop the timer on entry and
          start it on exit.

        Raises
        ------
        ValueError
          If `action` is not one of the two recognised values
        """

        if action not in ['StartStop', 'StopStart']:
            raise ValueError('Unrecognized action %s' % action)
        self.timer = Timer() if timer is None else timer
        self.label = label
        self.action = action

    def __enter__(self):
        """Start the timer (or stop it, if the action is 'StopStart')
        and return this ContextTimer instance.
        """

        if self.action == 'StartStop':
            on_entry = self.timer.start
        else:
            on_entry = self.timer.stop
        on_entry(self.label)
        return self

    def __exit__(self, type, value, traceback):
        """Stop the timer (or start it, if the action is 'StopStart')
        and return True if no exception was raised within the 'with'
        block, otherwise return False.
        """

        if self.action == 'StartStop':
            on_exit = self.timer.stop
        else:
            on_exit = self.timer.start
        on_exit(self.label)
        # A False return value allows any exception to propagate
        return not type

    def elapsed(self, total=True):
        """Return the elapsed time for the timer.

        Parameters
        ----------
        total : bool, optional (default True)
          If ``True`` return the total elapsed time since the first
          call of :meth:`start` for the selected timer, otherwise
          return the elapsed time since the most recent call of
          :meth:`start` for which there has not been a corresponding
          call to :meth:`stop`.

        Returns
        -------
        dlt : float
          Elapsed time
        """

        return self.timer.elapsed(self.label, total=total)