ReconstructPosition

PURPOSE

ReconstructPosition - Bayesian reconstruction of positions from spike trains.

SYNOPSIS

function [stats,lambda,Px] = ReconstructPosition(positions,spikes,phases,varargin)

DESCRIPTION

ReconstructPosition - Bayesian reconstruction of positions from spike trains.

 Instantaneous positions are reconstructed using a Bayesian algorithm.
 Instantaneous population firing rates can be estimated either over fixed time
 windows, or over fractions of the theta cycle (or of any other brain rhythm).
 Similarly, positions will be reconstructed either over time or phase windows.
 The model is first trained on a subset of the data, then tested on the rest.
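
 For reference, the Bayesian step can be written out as below. This only restates the comments in the source listing (Poisson firing, independence across units). The symbols are chosen here for illustration: \tau is the window length (dt in the code), \lambda_i(x) the firing map of unit i, n_i its spike count in the window, and N the number of units.

```latex
\begin{align}
P(n_i \mid x) &= \frac{\bigl(\tau\,\lambda_i(x)\bigr)^{n_i}}{n_i!}\; e^{-\tau\,\lambda_i(x)} \\
P(\mathbf{n} \mid x) &= \prod_{i=1}^{N} P(n_i \mid x) \\
P(x \mid \mathbf{n}) &= \frac{P(\mathbf{n} \mid x)\,P(x)}{\sum_{x'} P(\mathbf{n} \mid x')\,P(x')}
\end{align}
```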

 USAGE

    [stats,lambda,Px] = ReconstructPosition(positions,spikes,phases,<options>)

    positions      linear or two-dimensional position samples, in [0..1]
    spikes         list of (t,ID) couples (obtained via e.g. GetSpikes,
                   using numbered output)
    phases         optional unwrapped phase samples of the LFP (see Phase)
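
    As a rough illustration of the expected input layout, the sketch below builds synthetic inputs in plain MATLAB. All values (durations, rates, the idealised 8 Hz rhythm) are made up, and the [t x] / [t phase] column order is an assumption based on the usual samples convention.

```matlab
% Synthetic inputs, for illustration only (all values are made up)
rng(0);                                    % reproducible random numbers
t = (0:0.02:60)';                          % 60 s sampled at 50 Hz
x = 0.5 + 0.5*sin(2*pi*t/20);              % linear position, in [0..1]
positions = [t x];                         % one row per sample: [t x]

nUnits = 5;
spikes = zeros(0,2);
for id = 1:nUnits
    ts = sort(60*rand(100,1));             % 100 random spike times for this unit
    spikes = [spikes ; ts id*ones(100,1)]; % rows are (t,ID) couples
end
spikes = sortrows(spikes,1);               % order all spikes by time

phases = [t 2*pi*8*t];                     % unwrapped phase of an idealised 8 Hz rhythm
```

    Real data would of course come from the recording (e.g. via GetSpikes and Phase) rather than from random numbers.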

    =========================================================================
     Properties    Values
    -------------------------------------------------------------------------
     'training'    time interval over which the model should be trained
                   (see NOTE below for defaults)
     'window'      length of the time or phase window (default = 0.020 s for
                   time, and pi/3 for phases)
     'type'        two letters (one for X and one for Y) indicating which
                   coordinates are linear ('l') and which are circular ('c')
                   - for 1D data, only one letter is used (default 'll')
     'nBins'       firing curve or map resolution (default = [200 200])
     'mode'        perform only training ('train'), only reconstruction
                   ('test'), or both ('both', default)
     'lambda'      to provide previously generated model ('test' mode)
     'Px'          to provide previously generated model ('test' mode)
    =========================================================================
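
    To make the 'window' property more concrete, here is a minimal sketch in plain MATLAB of how fixed time windows and phase windows can be laid out. It loosely mirrors the window computation in the source listing, but uses interp1 instead of the toolbox helper, and the 8 Hz rhythm is an assumption chosen for illustration.

```matlab
% Fixed time windows of 20 ms between t = 0 s and t = 1 s
w = 0.020;
edges = (0:w:1)';
timeWindows = [edges(1:end-1) edges(2:end)];

% Phase windows of pi/3 for an idealised 8 Hz rhythm (unwrapped phase = 2*pi*8*t)
t = (0:0.001:1.05)';
phase = 2*pi*8*t;
wp = pi/3;
startPhase = ceil(phase(1)/(2*pi))*2*pi;     % start of the first full cycle
stopPhase = floor(phase(end)/(2*pi))*2*pi;   % end of the last full cycle
phaseEdges = (startPhase:wp:stopPhase)';
timeEdges = interp1(phase,t,phaseEdges);     % time at which each phase edge is reached
phaseWindows = [timeEdges(1:end-1) timeEdges(2:end)];
```

    With a perfectly regular rhythm every phase window has the same duration; with real LFP phases the durations vary from cycle to cycle.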

   OUTPUT

     stats.positions     real position across time or phase windows
     stats.spikes        cell firing vector across time or phase windows
     stats.estimations   estimated position across time or phase windows
     stats.errors        estimation error across time or phase windows
     stats.average       average estimation error in each phase window
     stats.windows       time windows (possibly computed from phases)
     stats.phases        phase windows (empty for fixed time windows)

     lambda              firing map for each unit
     Px                  occupancy probability map
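
     For 1D data, stats.estimations holds one probability distribution over position bins per window (bins x windows). One possible way to turn it into a single decoded position per window is to take the most probable bin, as sketched below on a toy array. The bin-centre formula assumes that the bins tile [0..1] evenly.

```matlab
% Toy posterior standing in for stats.estimations: 4 position bins x 3 windows
estimations = [0.1 0.7 0.25 ; 0.6 0.1 0.25 ; 0.2 0.1 0.25 ; 0.1 0.1 0.25];
nBins = size(estimations,1);

[pmax,bin] = max(estimations,[],1);   % most probable bin in each window
decoded = (bin-0.5)/nBins;            % bin centres, assuming bins tile [0..1] evenly
```

     The third window has a flat posterior (as produced for windows without spikes), so its decoded position is arbitrary; checking pmax against 1/nBins is one way to flag such windows.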

   NOTE

     Positions and spikes are interpreted differently depending on the mode:

      - For 'train', all positions and spikes are used to train the model
      - For 'test', all positions and spikes are used to test the model, and
        positions are optional (e.g. for reconstruction during sleep)
      - For 'both', the optional parameter 'training' can be used to indicate
        the training interval (default = first half of the position data)
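
     The separate 'train' and 'test' modes might be combined as follows, e.g. to train on exploration data and reconstruct during sleep. This is only a sketch: it assumes FMAToolbox is on the path, and runPositions, runSpikes and sleepSpikes are hypothetical variables formatted like positions and spikes above.

```matlab
% Sketch only: runPositions, runSpikes and sleepSpikes are hypothetical variables

% 1) Train on exploration data (all positions and spikes are used)
[~,lambda,Px] = ReconstructPosition(runPositions,runSpikes,'mode','train');

% 2) Reconstruct during sleep: no positions are available, so pass an empty matrix
stats = ReconstructPosition([],sleepSpikes,'mode','test','lambda',lambda,'Px',Px);
```

     Unit IDs need to refer to the same cells in both spike lists, since lambda is indexed by ID.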

SUBFUNCTIONS

function data = logfactorial(data)

SOURCE CODE

function [stats,lambda,Px] = ReconstructPosition(positions,spikes,phases,varargin)

%ReconstructPosition - Bayesian reconstruction of positions from spike trains.
%
% Instantaneous positions are reconstructed using a Bayesian algorithm.
% Instantaneous population firing rates can be estimated either over fixed time
% windows, or over fractions of the theta cycle (or of any other brain rhythm).
% Similarly, positions will be reconstructed either over time or phase windows.
% The model is first trained on a subset of the data, then tested on the rest.
%
% USAGE
%
%    [stats,lambda,Px] = ReconstructPosition(positions,spikes,phases,<options>)
%
%    positions      linear or two-dimensional positions <a href="matlab:help samples">samples</a>, in [0..1]
%    spikes         list of (t,ID) couples (obtained via e.g. <a href="matlab:help GetSpikes">GetSpikes</a>,
%                   using numbered output)
%    phases         optional unwrapped phase <a href="matlab:help samples">samples</a> of the LFP (see <a href="matlab:help Phase">Phase</a>)
%
%    =========================================================================
%     Properties    Values
%    -------------------------------------------------------------------------
%     'training'    time interval over which the model should be trained
%                   (see NOTE below for defaults)
%     'window'      length of the time or phase window (default = 0.020 s for
%                   time, and pi/3 for phases)
%     'type'        two letters (one for X and one for Y) indicating which
%                   coordinates are linear ('l') and which are circular ('c')
%                   - for 1D data, only one letter is used (default 'll')
%     'nBins'       firing curve or map resolution (default = [200 200])
%     'mode'        perform only training ('train'), only reconstruction
%                   ('test'), or both ('both', default)
%     'lambda'      to provide previously generated model ('test' mode)
%     'Px'          to provide previously generated model ('test' mode)
%    =========================================================================
%
%   OUTPUT
%
%     stats.positions     real position across time or phase windows
%     stats.spikes        cell firing vector across time or phase windows
%     stats.estimations   estimated position across time or phase windows
%     stats.errors        estimation error across time or phase windows
%     stats.average       average estimation error in each phase window
%     stats.windows       time windows (possibly computed from phases)
%     stats.phases        phase windows (empty for fixed time windows)
%
%     lambda              firing map for each unit
%     Px                  occupancy probability map
%
%   NOTE
%
%     Positions and spikes are interpreted differently depending on the mode:
%
%      - For 'train', all positions and spikes are used to train the model
%      - For 'test', all positions and spikes are used to test the model, and
%        positions are optional (e.g. for reconstruction during sleep)
%      - For 'both', the optional parameter 'training' can be used to indicate
%        the training interval (default = first half of the position data)
%

% Copyright (C) 2012-2015 by Michaël Zugaro, (C) 2012 by Karim El Kanbi (initial, non-vectorized implementation),
% (C) 2015 by Céline Drieu (separate training vs test), (C) 2015 by Ralitsa Todorova (log-exp fix)
%
% This program is free software; you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation; either version 3 of the License, or
% (at your option) any later version.

% Defaults
wt = 0.020; % default time window
wp = pi/3; % default phase window
window = [];
nBins = 200;
training = 0.5;
type = '';
nDimensions = 1;
mode = 'both';

% Optional parameter 'phases'
if nargin == 2,
    phases = [];
elseif nargin >= 3 && ischar(phases),
    varargin = {phases,varargin{:}};
    phases = [];
end

% Check number of parameters
if nargin < 2 || mod(length(varargin),2) ~= 0,
    builtin('error','Incorrect number of parameters (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
end

% Check parameter sizes
if ~isempty(positions) && ~isdmatrix(positions),
    builtin('error','Incorrect positions (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
end
if ~isdmatrix(spikes,'@2') && ~isdmatrix(spikes,'@3'),
    builtin('error','Incorrect spikes (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
end
if size(positions,2) >= 3,
    nDimensions = 2;
end
if ~isempty(phases) && ~isdmatrix(phases),
    builtin('error','Incorrect value for property ''phases'' (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
end

% Parse parameters
for i = 1:2:length(varargin),
    if ~ischar(varargin{i}),
        builtin('error',['Parameter ' num2str(i+2) ' is not a property (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).']);
    end
    switch(lower(varargin{i})),
        case 'training',
            training = varargin{i+1};
            if ~isdvector(training,'<'),
                builtin('error','Incorrect value for property ''training'' (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
            end
        case 'window',
            window = varargin{i+1};
            if ~isdscalar(window,'>0'),
                builtin('error','Incorrect value for property ''window'' (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
            end
        case 'show',
            show = varargin{i+1};
            if ~isastring(show,'on','off'),
                builtin('error','Incorrect value for property ''show'' (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
            end
        case 'nbins',
            nBins = varargin{i+1};
            % Accept a positive integer, or a pair of positive integers
            % (the original test rejected valid integer scalars)
            if ~isiscalar(nBins,'>0') && ~isivector(nBins,'>0','#2'),
                builtin('error','Incorrect value for property ''nBins'' (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
            end
        case 'type',
            type = lower(varargin{i+1});
            if (nDimensions == 1 && ~isastring(type,'cc','cl','lc','ll')) || (nDimensions == 2 && ~isastring(type,'ccl','cll','lcl','lll','ccc','clc','lcc','llc')),
                builtin('error','Incorrect value for property ''type'' (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
            end
        case 'mode',
            mode = lower(varargin{i+1});
            if ~isastring(mode,'both','train','test'),
                builtin('error','Incorrect value for property ''mode'' (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
            end
        case 'lambda',
            lambda = varargin{i+1};
            if ~isnumeric(lambda) || length(size(lambda)) ~= 3,
                builtin('error','Incorrect value for property ''lambda'' (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
            end
        case 'px',
            Px = varargin{i+1};
            if ~isdvector(Px),
                builtin('error','Incorrect value for property ''Px'' (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
            end
        otherwise,
            builtin('error',['Unknown property ''' num2str(varargin{i}) ''' (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).']);
    end
end

% Defaults (window)
if isempty(window),
    if isempty(phases),
        window = wt;
    else
        window = wp;
        if ~isiscalar((2*pi)/window),
            builtin('error',['Incorrect phase window: not an integer fraction of 2pi (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).']);
        end
    end
end
% Defaults (training)
if isastring(mode,'train','both') && isdscalar(training),
    training = [-Inf positions(1,1)+training*(positions(end,1)-positions(1,1))];
end
% Defaults (type)
if isempty(type),
    if nDimensions == 2,
        type = 'lll';
    else
        type = 'll';
    end
end
% Defaults (nBins)
% (the original test 'length(nBins) > 2' could never select nBins(2))
nBinsX = nBins(1);
if nDimensions == 2,
    if length(nBins) >= 2,
        nBinsY = nBins(2);
    else
        nBinsY = nBinsX;
    end
else
    nBinsY = 1;
end
% Defaults (mode)
if isastring(mode,'train','both') && ( exist('lambda','var') || exist('Px','var') ),
    warning(['Inconsistent inputs, lambda and Px will be ignored in mode ''' mode ''' (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).']);
    clear('lambda');clear('Px');
end
% In 'test' mode, the model must be provided
if strcmp(mode,'test') && ~( exist('lambda','var') && exist('Px','var') ),
    builtin('error','Missing lambda or Px in mode ''test'' (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
end
% Convert from legacy format for backward compatibility with previous versions of the code (spikes)
if isdmatrix(spikes,'@3'),
    % List units, assign them an ID (number them from 1 to N), and associate these IDs with each spike
    % (IDs will be easier to manipulate than (group,cluster) pairs in subsequent computations)
    [units,~,i] = unique(spikes(:,2:end),'rows');
    nUnits = size(units,1); % number of rows (length would return 2 for a single unit)
    index = 1:nUnits;
    id = index(i)';
    spikes = [spikes(:,1) id];
    warning('Spikes were provided as Nx3 samples - this is now obsolete (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
    if ~strcmp(mode,'both'),
        builtin('error','Obsolete format can be used only when training and test are performed together (type ''help <a href="matlab:help ReconstructPosition">ReconstructPosition</a>'' for details).');
    end
else
    if strcmp(mode,'test'),
        nUnits = size(lambda,3);
    else
        nUnits = max(spikes(:,2));
    end
end

% TRAINING

if isastring(mode,'both','train'),

    % Split data (training vs test)
    if strcmp(mode,'train'),
        % Mode = 'train', use all data
        trainingPositions = positions;
        trainingSpikes = spikes;
    else
        % Mode = 'both', use info from 'training' parameter
        trainingPositions = Restrict(positions,training);
        trainingSpikes = Restrict(spikes,training);
    end

    % Compute average firing probability lambda for each unit (i.e. firing maps)
    for i = 1:nUnits,
        unit = trainingSpikes(:,2) == i;
        s = trainingSpikes(unit,1);
        map = Map(trainingPositions,s,'nbins',nBins,'smooth',5,'type',type);
        lambda(:,:,i) = map.z;
    end

    % Compute occupancy probability P(x) (i.e. normalized occupancy map)
    % (the occupancy is the same for all units, so the last map is used)
    Px = map.time;
    Px = Px ./ sum(Px(:));

end

% TEST

if strcmp(mode,'train'),
    % Training only: return empty reconstruction results
    stats.positions = [];
    stats.estimations = [];
    stats.spikes = [];
    stats.errors = [];
    stats.average = [];
    stats.windows = [];
    stats.phases = [];
    return
end

% Split data (training vs test)
if strcmp(mode,'test'),
    % Mode = 'test', use all data
    testPositions = positions;
    testSpikes = spikes;
else
    % Mode = 'both', use info from 'training' parameter
    testPositions = positions(~InIntervals(positions,training),:);
    testSpikes = spikes(~InIntervals(spikes,training),:);
end

% Determine time windows (using unwrapped phases if necessary)
if ~isempty(phases),
    if strcmp(mode,'test'),
        % Mode = 'test', there is no training interval: use all phases
        testPhases = phases;
    else
        testPhases = phases(~InIntervals(phases,training),:);
    end
    if ~isempty(testPositions),
        drop = testPhases(:,1) < testPositions(1,1);
        testPhases(drop,:) = [];
    end
    startPhase = ceil(testPhases(1,2)/(2*pi))*2*pi;
    stopPhase = floor(testPhases(end,2)/(2*pi))*2*pi;
    windows = (startPhase:window:stopPhase)';
    stats.phases = windows;
    windows = Interpolate(testPhases(:,[2 1]),windows);
    windows = [windows(1:end-1,2) windows(2:end,2)];
else
    stats.phases = [];
    if ~isempty(testPositions),
        windows = (testPositions(1,1):window:testPositions(end,1))';
    else
        windows = (testSpikes(1,1):window:testSpikes(end,1))';
    end
    windows = [windows(1:end-1) windows(2:end)];
end
nWindows = size(windows,1);

stats.estimations = nan(nBinsY,nBinsX,nWindows);
stats.spikes = zeros(nUnits,nWindows);
% Loop over data windows
for i = 1:nWindows,

    % Get spikes for this window
    s = Restrict(testSpikes,windows(i,:));

    if isempty(s),
        % No spikes: set uniform probability
        stats.estimations(:,:,i) = ones(nBinsY,nBinsX,1)/(nBinsX*nBinsY);
        continue;
    end

    % Population spike count vector
    stats.spikes(:,i) = Accumulate(s(:,2),1,nUnits);
    % To avoid 'for' loops, prepare for vector computation:
    % assign a spike count to each position and unit (3D array)
    n = reshape(repmat(stats.spikes(:,i),1,nBinsX*nBinsY)',nBinsY,nBinsX,nUnits);

    % For each cell i, compute P(ni|x) using a Poisson model. The direct formula is:
    %      Pnix = (dt*lambda).^n./factorial(n).*exp(-dt*lambda);
    % However, large values of (dt*lambda).^n can create overflow errors, so instead we compute
    % the log and then take the exponential (fix by Ralitsa Todorova)
    dt = windows(i,2) - windows(i,1);
    Pnix = exp(n.*log(dt*lambda)-logfactorial(n)-dt*lambda);
    % Compute P(n|x) assuming independent probabilities across units (hmm...)
    % i.e. P(n|x) = product over i of P(ni|x)
    Pnx = prod(Pnix,3);

    % Compute P(n) = sum over x of P(n|x)*P(x)
    Pn = sum(sum(Pnx.*Px));

    % Compute P(x|n) = P(n|x)*P(x)/P(n)
    Pxn = Pnx .* Px / Pn;

    % Store result
    stats.estimations(:,:,i) = Pxn;

end
stats.estimations = squeeze(stats.estimations);
stats.windows = windows;

% Estimation error

stats.errors = [];
stats.average = [];
if isempty(testPositions),
    % No positions (e.g. reconstruction during sleep): errors cannot be computed
    stats.positions = [];
elseif nDimensions == 1,
    % Bin test positions and compute distance to center
    stats.positions = Interpolate(testPositions,windows(:,1));
    stats.positions(:,2) = Bin(stats.positions(:,2),[0 1],nBinsX);
    dx = (round(nBinsX/2)-stats.positions(:,2))';
    % Shift estimated position by the real distance to center
    stats.errors = CircularShift(stats.estimations(:,1:length(dx)),dx);
    % Average over one or more cycles
    if ~isempty(phases),
        k = round(2*pi/window); % number of phase windows per cycle (rounded to guard against floating-point error)
        n = floor(size(stats.errors,2)/k)*k;
        stats.average = reshape(stats.errors(:,1:n),nBinsX,k,[]);
        stats.average = nanmean(stats.average,3);
    end
else
    warning('Computation of estimation error not yet implemented for 2D environments');
end
end

function data = logfactorial(data)

% We compute log(n!) as the sum of logs, i.e. log(n!) = sum log(i) for i=1:n
% First determine the largest n in the array
m = max(data(:));
% Create a look-up vector of sum log(i) for each i up to the largest n
sums = [0 cumsum(log(1:m))];
% Look-up the value for each item in the array
data(:) = sums(data+1);
end
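
The log-exp trick used for the Poisson term can be checked in isolation. The sketch below uses plain MATLAB and made-up values chosen to force an overflow. gammaln is used here as an independent reference for log(n!); it is not what the function itself calls.

```matlab
lambda = 40;  % firing rate in Hz (made-up value)
dt = 5;       % window length in s (deliberately long, to force large numbers)
n = 200;      % spike count in the window

% Direct formula: both 200^200 and 200! overflow to Inf, so the result is NaN
direct = (dt*lambda)^n / factorial(n) * exp(-dt*lambda);

% Log-space formula: log(n!) = gammaln(n+1)
logp = n*log(dt*lambda) - gammaln(n+1) - dt*lambda;
stable = exp(logp);

% Same log-factorial obtained as a cumulative sum of logs (look-up approach)
sums = [0 cumsum(log(1:n))];
logfact = sums(n+1);
```

Here stable is a finite probability of a few percent, while direct is unusable, and logfact agrees with gammaln(n+1) up to rounding error.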

Generated on Fri 16-Mar-2018 13:00:20 by m2html © 2005