DTW算法實現

#!/usr/bin/env python3
# -*- coding:UTF-8 -*-
##########################################################################
# File Name: DTW.py
# Author: stubborn vegeta
# Created Time: 2019年12月08日 星期日 00時06分14秒
##########################################################################
import numpy as np
import matplotlib.pyplot as plt
import random,sys

def listGenerate(step):
    """Return evenly spaced sample points 0, step, 2*step, ... below 2*pi.

    The interval is half-open, [0, 2*pi), as with ``np.arange``; the result is
    returned as a freshly copied NumPy array.
    """
    upper = 2 * np.pi
    samples = np.arange(0, upper, step)
    return np.array(samples)

def dtw(l0, N, l1, M):
    """Build the DTW local-distance and cumulative-cost matrices.

    Args:
        l0: 1-D sequence laid out along the columns (repeated ``N`` times).
        N: number of rows; expected to equal ``len(l1)``.
        l1: 1-D sequence laid out along the rows (repeated ``M`` times).
        M: number of columns; expected to equal ``len(l0)``.

    Returns:
        ``(D, distance)``. With matching sizes, ``distance[i, j]`` is
        ``abs(l0[j] - l1[i])``. ``D`` has shape ``(N, M)``, and ``D[i, j]`` is
        the smallest accumulated distance over any monotone path from
        ``(0, 0)`` to ``(i, j)`` using right, down and diagonal steps.

    Side effect: prints the distance matrix to stdout.
    """
    rows_of_l0 = np.tile(l0, (N, 1))
    cols_of_l1 = np.tile(l1, (M, 1)).T
    distance = abs(rows_of_l0 - cols_of_l1)
    print('distance\n', distance)

    D = np.ones((N, M))
    D[0, 0] = distance[0, 0]
    # Cells in the first column / first row have a single predecessor.
    for r in range(1, N):
        D[r, 0] = D[r - 1, 0] + distance[r, 0]
    for c in range(1, M):
        D[0, c] = D[0, c - 1] + distance[0, c]
    # Interior cells: cheapest of the left, diagonal and upper neighbours.
    for r in range(1, N):
        for c in range(1, M):
            best_prev = min(D[r, c - 1], D[r - 1, c - 1], D[r - 1, c])
            D[r, c] = distance[r, c] + best_prev
    return D, distance

def path(distance, D, N, M):
    """Return an optimal DTW warping path by backtracking through ``D``.

    The walk starts at the bottom-right cell ``(M-1, N-1)`` of the
    cumulative-cost matrix ``D`` built by ``dtw``. At each step it moves to
    whichever of the three predecessors allowed by the DTW recurrence
    (diagonal, up, left) has the smallest accumulated cost, and it stops at
    ``(0, 0)``. A greedy walk over the local ``distance`` values can miss the
    optimum. This backtrack always yields a path whose accumulated cost is
    ``D[M-1, N-1]``, up to float rounding.

    Args:
        distance: local-cost matrix of shape ``(M, N)``. The backtrack does
            not need it; it is kept so existing callers keep working.
        D: cumulative-cost matrix of shape ``(M, N)`` from ``dtw``.
        N: number of columns of ``D``.
        M: number of rows of ``D``.

    Returns:
        Integer array of shape ``(2, L)``. Row 0 holds row indices and row 1
        holds column indices, ordered from ``(0, 0)`` to ``(M-1, N-1)``.

    Raises:
        ValueError: if ``M`` or ``N`` is smaller than 1.
    """
    if M < 1 or N < 1:
        raise ValueError('path() needs a non-empty cost matrix')
    D = np.asarray(D)
    i, j = M - 1, N - 1
    steps = [(i, j)]
    while i > 0 or j > 0:
        if i == 0:              # first row: the only predecessor is to the left
            j -= 1
        elif j == 0:            # first column: the only predecessor is above
            i -= 1
        else:
            # min() returns the first minimum, so ties prefer the diagonal.
            _, i, j = min(
                (D[i - 1, j - 1], i - 1, j - 1),
                (D[i - 1, j], i - 1, j),
                (D[i, j - 1], i, j - 1),
                key=lambda cand: cand[0],
            )
        steps.append((i, j))
    steps.reverse()
    return np.array(steps).T

def draw(t1, x, t2, y, index, colormat=None):
    """Plot a DTW alignment, saving ``path.jpg`` and ``wrap.jpg``.

    The first figure shows the ``colormat`` grid with the warping path drawn
    in white. ``x`` is plotted against ``t1`` to its left, and ``y`` against
    ``t2`` below it. The second figure puts both series on one axis, with
    ``x`` lifted by 5, and joins every aligned pair with a dashed line.

    Args:
        t1, x: time stamps and values of the first series; ``index[1]``
            indexes into these.
        t2, y: time stamps and values of the second series; ``index[0]``
            indexes into these.
        index: array of shape ``(2, L)``. Column k is the aligned pair
            ``(index into t2/y, index into t1/x)``.
        colormat: values for the background ``pcolor`` grid. Defaults to an
            all-ones matrix of shape ``(len(t1) - 1, len(t2) - 1)``.
    """
    if colormat is None:
        # Build the grid locally instead of reading a module-level global,
        # so draw() also works when this module is imported.
        colormat = np.ones((len(t1) - 1, len(t2) - 1))

    plt.figure()                                                # start from a clean figure
    grid = plt.GridSpec(4, 4, wspace=0.5, hspace=0.5)           # split the figure into a 4x4 grid
    plt.subplot(grid[0:3, 1:4])
    plt.pcolor(colormat, edgecolors='black', linewidths=4)
    plt.plot(index[0], index[1], 'w-o')

    leftAX = plt.subplot(grid[0:3, 0], yticks=[])
    plt.plot(x, t1, 'b')
    leftAX.invert_xaxis()                                       # flip the x axis

    bottomAX = plt.subplot(grid[3, 1:4], xticks=[])
    plt.plot(t2, y, 'r')
    bottomAX.invert_yaxis()                                     # flip the y axis

    plt.savefig('path.jpg')
    plt.show()

    # Second figure: both series (x lifted by 5 to separate them) with a
    # dashed line joining each aligned pair of samples.
    plt.figure()
    plt.scatter(t1, x + 5, s=120)
    plt.scatter(t2, y, s=120)
    plt.plot(t1, x + 5, 'black')
    plt.plot(t2, y, 'r')
    for a, b in index.T:
        plt.plot([t2[a], t1[b]], [y[a], x[b] + 5], 'b--')
    plt.yticks([])
    plt.savefig('wrap.jpg')
    plt.show()


if __name__ == '__main__':
    # Two sine waves sampled with different step sizes (32 and 27 samples);
    # the second is perturbed by noise drawn from {-0.10, -0.09, ..., 0.10}.
    t0 = listGenerate(0.2)
    t1 = listGenerate(0.24)
    N, M = len(t0), len(t1)
    noise = np.array([random.randint(-10, 10) / 100 for _ in range(M)])
    y0 = np.sin(t0)                     # shape (N,) == (32,)
    y1 = np.sin(t1) + noise             # shape (M,) == (27,)
    # Rows of the matrices follow y1 (M of them) and columns follow y0
    # (N of them), so D and distance both have shape (M, N).
    D, distance = dtw(y0, M, y1, N)
    # Background grid read by draw(); its cell corners span 0..M-1 by 0..N-1,
    # matching the path indices.
    colormat = np.ones((N - 1, M - 1))
    index = path(distance, D, N, M)
    draw(t0, y0, t1, y1, index)

圖：距離網格與其上的 DTW 規整路徑（程式輸出的 path.jpg）
圖：兩條序列之間逐點對齊的連線（程式輸出的 wrap.jpg）

發表評論
所有評論
還沒有人評論，想成為第一個評論的人嗎？請在上方評論欄輸入並且點擊發布。
相關文章