Source code for spike.Algo.CS_transformations

#!/usr/bin/env python 
# encoding: utf-8

"""
Authors: Marc-André
Last modified: 2011/11/18.

adapted by Lionel on 2013-3-6.
Copyright (c) 2011 IGBMC. All rights reserved.
"""

from __future__ import print_function
import numpy as np
#import scipy.fftpack as fft
import numpy.fft as fft # numpy scipy seem equivalent for ifft and fft.. 


[docs]class transformations(object): """ this class contains methods which are tools to generate transform and ttransform. ttrans form data to image trans from image to data. """ def __init__(self, size_image, size_mesure, sampling = None, debug=0): """ size_image and size_mesure are sizes of x and s space all other fields are meant to be overloaded after creation direct transform refers to S => X // image => Data transform """ self.size_image = size_image self.size_mesure = size_mesure #print "data size is ",self.size_mesure #print "image size is ",self.size_image self.pre_ft = self.Id # applied before FT (in the direct transform) self.tpre_ft = self.Id # its transpose (so applied after FT in ttransform) self.post_ft = self.Id # applied after FT (in the direct transform) self.tpost_ft = self.Id # its transpose (so applied before FT in ttransform) self.ft = fft.ifft # trans: from image to data self.tft = fft.fft # ttrans: from data to image. self.sampling = sampling # sampling vector if not none self.debug = debug
[docs] def report(self): "dumps content" for i in dir(self): if not i.startswith('_') : print(i, getattr(self,i))
[docs] def Id(self, x): return x
[docs] def check(self): if self.debug: print(""" size_image: %d - size_mesure: %d sampling %s """%(self.size_image, self.size_mesure, str(self.sampling))) if self.sampling is not None: assert(len(self.sampling) == self.size_mesure) assert(max(self.sampling) <= self.size_image)
# assert( (self.pre_ft == self.Id and self.tpre_ft == self.Id) or \ # (self.pre_ft != self.Id and self.tpre_ft != self.Id) ) # if one, then both ! # assert( (self.post_ft == self.Id and self.tpost_ft == self.Id) or \ # (self.post_ft != self.Id and self.tpost_ft != self.Id) ) # if one, then both !
[docs] def zerofilling(self,x): # eventually zerofill # xx = np.zeros(self.size_image, dtype = 'complex') xx = np.zeros(self.size_image, dtype=x.dtype) xx[:len(x)] = x[:] x = xx return x
[docs] def sample(self, x): """ apply a sampling function - using self.sampling """ #print self.sampling return x[self.sampling]
[docs] def tsample(self, x): """ transpose of the sampling function """ # xx = np.zeros(self.size_image,'complex128') xx = np.zeros(self.size_image, dtype=x.dtype) #print xx.dtype, x.dtype, self.sampling.dtype xx[self.sampling] = x return xx
[docs] def transform(self, s): """ transform to data. Passing from s (image) to x (data) pre_ft() : s->s ft() : s->x - fft.ifft by default - should not change size post_ft() : x->x - typically : broadening, sampling, truncating, etc... """ if self.debug: print('entering trans', s.shape, s.dtype) if self.pre_ft != self.Id: s = self.pre_ft(s) if self.debug: print('trans pre_ft', s.shape, s.dtype) x = self.ft(s) if self.post_ft != self.Id: x = self.post_ft(x) if self.debug: print('trans post_ft', x.shape, x.dtype) if self.sampling is not None: x = self.sample(x) if self.debug: print('trans sample', x.shape, x.dtype) if self.size_mesure != len(x): # eventually truncate x = x[0:self.size_mesure] if self.debug: print('trans trunc', x.shape, x.dtype) if self.debug: print('exiting trans', x.shape, x.dtype) return x
[docs] def ttransform(self, x): """ the transpose of transform Passing from x to s (data to image) """ if self.debug: print('entering ttrans', x.shape, x.dtype) if self.sampling is not None: if self.debug: print('ttrans sample') x = self.tsample(x) elif self.size_image != len(x): # eventually zerofill if self.debug: print('ttrans zerofill',len(x),self.size_image) x = self.zerofilling(x) if self.tpost_ft != self.Id: if self.debug: print('ttrans tpost_ft') x = self.tpost_ft(x) s = self.tft(x) if self.tpre_ft != self.Id: if self.debug: print('ttrans tpre_ft') s = self.tpre_ft(s) if self.debug: print('exiting ttrans', s.shape, s.dtype) return s
[docs]def sampling_load(addr_sampling_file): ''' Loads a sampling protocole from a list of indices stored in a file named addr_sampling_file returns an nparray with the sampling scheme. i.e. if b is a full dataset, b[sampling] is the sampled one ''' with open(addr_sampling_file, 'r') as F: # print "### reads the sampling file ", addr_sampling_file param = read_param(F) F.seek(0) sampling = read_data(F) # print "sampling[0], sampling[len(sampling)/2], sampling[-1]",sampling[0], sampling[len(sampling)/2], sampling[-1] return sampling, param
[docs]def read_data(F): ''' Reads data from the sampling file, used by sampling_load() ''' data = [] for l in F: if not l.startswith("#"): if l.strip() == "": continue data.append(int(l)) return np.array(data)
[docs]def read_param(F): ''' Reads the sampling parameters. used by sampling_load() ''' """ given F, an opend file , retrieve all parameters found in file header read_param returns values in a plain dictionnary """ dic = {} for l in F: # print l.rstrip() if not l.startswith("#"): break v = l.rstrip().split(':') # remove trailing chars and split around : if len(v)<2: # comment lines pass #print l else: entry = v[0][1:].strip() dic[entry] = v[1].lstrip() # metadata lines return dic
if __name__ == "__main__": tr = transformations(2000, 1000) tr.report()