-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Francois Petitjean
committed
Jun 10, 2014
1 parent
e39354e
commit 2366998
Showing
2 changed files
with
350 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
/******************************************************************************* | ||
* Copyright (C) 2013 Francois PETITJEAN | ||
* | ||
* This program is free software: you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation, version 3 of the License. | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
******************************************************************************/ | ||
|
||
import java.util.ArrayList; | ||
|
||
/** | ||
* This toy class show the use of DBA. | ||
* @author Francois Petitjean | ||
*/ | ||
public class DBA { | ||
static final long serialVersionUID = 1L; | ||
|
||
private final static int NIL = -1; | ||
private final static int DIAGONAL = 0; | ||
private final static int LEFT = 1; | ||
private final static int UP = 2; | ||
|
||
/** | ||
* This attribute is used in order to initialize only once the matrixes | ||
*/ | ||
private final static int MAX_SEQ_LENGTH = 2000; | ||
|
||
/** | ||
* store the cost of the alignment | ||
*/ | ||
private static double[][] costMatrix = new double[MAX_SEQ_LENGTH][MAX_SEQ_LENGTH]; | ||
|
||
/** | ||
* store the warping path | ||
*/ | ||
private static int[][] pathMatrix = new int[MAX_SEQ_LENGTH][MAX_SEQ_LENGTH]; | ||
|
||
/** | ||
* Store the length of the optimal path in each cell | ||
*/ | ||
private static int[][] optimalPathLength = new int[MAX_SEQ_LENGTH][MAX_SEQ_LENGTH]; | ||
|
||
|
||
/** | ||
* Dtw Barycenter Averaging (DBA) | ||
* @param C average sequence to update | ||
* @param sequences set of sequences to average | ||
*/ | ||
public static void DBA(double[] C, double[][] sequences) { | ||
|
||
final ArrayList<Double>[] tupleAssociation = new ArrayList[C.length]; | ||
for (int i = 0; i < tupleAssociation.length; i++) { | ||
tupleAssociation[i] = new ArrayList<Double>(sequences.length); | ||
} | ||
int nbTuplesAverageSeq, i, j, indiceRes; | ||
double res = 0.0; | ||
int centerLength = C.length; | ||
int seqLength; | ||
|
||
for (double[] T : sequences) { | ||
seqLength = T.length; | ||
|
||
costMatrix[0][0] = distanceTo(C[0], T[0]); | ||
pathMatrix[0][0] = DBA.NIL; | ||
optimalPathLength[0][0] = 0; | ||
|
||
for (i = 1; i < centerLength; i++) { | ||
costMatrix[i][0] = costMatrix[i - 1][0] + distanceTo(C[i], T[0]); | ||
pathMatrix[i][0] = DBA.UP; | ||
optimalPathLength[i][0] = i; | ||
} | ||
for (j = 1; j < seqLength; j++) { | ||
costMatrix[0][j] = costMatrix[0][j - 1] + distanceTo(T[j], C[0]); | ||
pathMatrix[0][j] = DBA.LEFT; | ||
optimalPathLength[0][j] = j; | ||
} | ||
|
||
for (i = 1; i < centerLength; i++) { | ||
for (j = 1; j < seqLength; j++) { | ||
indiceRes = DBA.ArgMin3(costMatrix[i - 1][j - 1], costMatrix[i][j - 1], costMatrix[i - 1][j]); | ||
pathMatrix[i][j] = indiceRes; | ||
switch (indiceRes) { | ||
case DIAGONAL: | ||
res = costMatrix[i - 1][j - 1]; | ||
optimalPathLength[i][j] = optimalPathLength[i - 1][j - 1] + 1; | ||
break; | ||
case LEFT: | ||
res = costMatrix[i][j - 1]; | ||
optimalPathLength[i][j] = optimalPathLength[i][j - 1] + 1; | ||
break; | ||
case UP: | ||
res = costMatrix[i - 1][j]; | ||
optimalPathLength[i][j] = optimalPathLength[i - 1][j] + 1; | ||
break; | ||
} | ||
costMatrix[i][j] = res + distanceTo(C[i], T[j]); | ||
} | ||
} | ||
|
||
nbTuplesAverageSeq = optimalPathLength[centerLength - 1][seqLength - 1] + 1; | ||
|
||
i = centerLength - 1; | ||
j = seqLength - 1; | ||
|
||
for (int t = nbTuplesAverageSeq - 1; t >= 0; t--) { | ||
tupleAssociation[i].add(T[j]); | ||
switch (pathMatrix[i][j]) { | ||
case DIAGONAL: | ||
i = i - 1; | ||
j = j - 1; | ||
break; | ||
case LEFT: | ||
j = j - 1; | ||
break; | ||
case UP: | ||
i = i - 1; | ||
break; | ||
} | ||
|
||
} | ||
|
||
} | ||
|
||
for (int t = 0; t < centerLength; t++) { | ||
|
||
C[t] = barycenter((tupleAssociation[t].toArray())); | ||
} | ||
|
||
} | ||
|
||
|
||
public static double Min3(final double a, final double b, final double c) { | ||
if (a < b) { | ||
if (a < c) { | ||
return a; | ||
} else { | ||
return c; | ||
} | ||
} else { | ||
if (b < c) { | ||
return b; | ||
} else { | ||
return c; | ||
} | ||
} | ||
} | ||
|
||
public static int ArgMin3(final double a, final double b, final double c) { | ||
if (a < b) { | ||
if (a < c) { | ||
return 0; | ||
} else { | ||
return 2; | ||
} | ||
} else { | ||
if (b < c) { | ||
return 1; | ||
} else { | ||
return 2; | ||
} | ||
} | ||
} | ||
|
||
public static double distanceTo(double a, double b) { | ||
return (a - b) * (a - b); | ||
} | ||
|
||
|
||
public static double barycenter(final Object... tab) { | ||
if (tab.length < 1) { | ||
throw new RuntimeException("empty double tab"); | ||
} | ||
double sum = 0.0; | ||
sum = 0.0; | ||
for (Object o : tab) { | ||
sum += ((Double) o); | ||
} | ||
return sum / tab.length; | ||
} | ||
|
||
public static void main(String [] args){ | ||
double [][]sequences = new double[100][]; | ||
for(int i=0;i<sequences.length;i++){ | ||
sequences[i] = new double[20]; | ||
for(int j=0;j<sequences[i].length;j++){ | ||
sequences[i][j] = Math.cos(Math.random()*j/20.0*Math.PI) ; | ||
} | ||
} | ||
double [] averageSequence = new double[20]; | ||
int choice = (int) Math.random()*100; | ||
for(int j=0;j<averageSequence.length;j++){ | ||
averageSequence[j] = sequences[choice][j] ; | ||
} | ||
|
||
System.out.print("["); | ||
for(int j=0;j<averageSequence.length;j++){ | ||
System.out.print(averageSequence[j]+" "); | ||
} | ||
System.out.println("]"); | ||
|
||
DBA(averageSequence, sequences); | ||
|
||
System.out.print("["); | ||
for(int j=0;j<averageSequence.length;j++){ | ||
System.out.print(averageSequence[j]+" "); | ||
} | ||
System.out.println("]"); | ||
|
||
DBA(averageSequence, sequences); | ||
|
||
System.out.print("["); | ||
for(int j=0;j<averageSequence.length;j++){ | ||
System.out.print(averageSequence[j]+" "); | ||
} | ||
System.out.println("]"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/******************************************************************************* | ||
* Copyright (C) 2013 Francois PETITJEAN, Ioannis PAPARRIZOS | ||
* | ||
* This program is free software: you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation, version 3 of the License. | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details.well | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
******************************************************************************/ | ||
|
||
function average = DBA(sequences) | ||
index=randi(size(sequences,1),1); | ||
average=sequences(index,:); | ||
for i=1:15 | ||
average=DBA_one_iteration(average,sequences); | ||
end | ||
end | ||
|
||
|
||
function average = DBA_one_iteration(averageS,sequences) | ||
|
||
tupleAssociation = cell (1, size(averageS,2)); | ||
for t=1:size(averageS,2) | ||
tupleAssociation{t}=[]; | ||
end | ||
|
||
costMatrix = []; | ||
pathMatrix = []; | ||
|
||
for k=1:size(sequences,1) | ||
sequence = sequences(k,:); | ||
costMatrix(1,1) = distanceTo(averageS(1),sequence(1)); | ||
pathMatrix(1,1) = -1; | ||
for i=2:size(averageS,2) | ||
costMatrix(i,1) = costMatrix(i-1,1) + distanceTo(averageS(i),sequence(1)); | ||
pathMatrix(i,1) = 2; | ||
end | ||
|
||
for j=2:size(sequence,2) | ||
costMatrix(1,j) = costMatrix(1,j-1) + distanceTo(sequence(j),averageS(1)); | ||
pathMatrix(1,j) = 1; | ||
end | ||
|
||
for i=2:size(averageS,2) | ||
for j=2:size(sequence,2) | ||
indiceRes = ArgMin3(costMatrix(i-1,j-1),costMatrix(i,j-1),costMatrix(i-1,j)); | ||
pathMatrix(i,j)=indiceRes; | ||
|
||
if indiceRes==0 | ||
res = costMatrix(i-1,j-1); | ||
elseif indiceRes==1 | ||
res = costMatrix(i,j-1); | ||
elseif indiceRes==2 | ||
res = costMatrix(i-1,j); | ||
end | ||
|
||
costMatrix(i,j) = res + distanceTo(averageS(i),sequence(j)); | ||
|
||
end | ||
end | ||
|
||
i=size(averageS,2); | ||
j=size(sequence,2); | ||
|
||
while(true) | ||
tupleAssociation{i}(end+1) = sequence(j); | ||
if pathMatrix(i,j)==0 | ||
i=i-1; | ||
j=j-1; | ||
elseif pathMatrix(i,j)==1 | ||
j=j-1; | ||
elseif pathMatrix(i,j)==2 | ||
i=i-1; | ||
else | ||
break; | ||
end | ||
end | ||
|
||
end | ||
|
||
for t=1:size(averageS,2) | ||
averageS(t) = mean(tupleAssociation{t}); | ||
end | ||
|
||
average = averageS; | ||
|
||
end | ||
|
||
function value = ArgMin3(a,b,c) | ||
|
||
if (a<b) | ||
if (a<c) | ||
value=0; | ||
return; | ||
else | ||
value=2; | ||
return; | ||
end | ||
else | ||
if (b<c) | ||
value=1; | ||
return; | ||
else | ||
value=2; | ||
return; | ||
end | ||
end | ||
|
||
end | ||
|
||
|
||
function dist = distanceTo(a,b) | ||
dist=(a-b)*(a-b); | ||
end | ||
|
||
sequences=rand(100,20); | ||
mean=DBA(sequences) | ||
plot(mean); | ||
|