三次樣條曲線 python實現

reference:https://blog.csdn.net/deramer1/article/details/79034201

import numpy as np
import matplotlib.pyplot as plt
from pylab import mpl

"""
三次樣條實現:
函數的自變量x:3, 4.5, 7, 9
函數的因變量y:2.5, 10, 2.5, 10.5
"""
x = [3, 4.5, 7, 9, 13]
y = [2.5, 10, 2.5, 10.5, 2]

"""
功能:完後對三次樣條函數求解方程參數的輸入
參數:要進行三次樣條曲線計算的自變量
返回值:方程的參數
"""


def calculateEquationParameters(x):

    '''
    代解未知數爲a1,b1,c1,d1,a2,b2,c2,d2,a3,b3,c3,d3

    :param x:
    :return: parameter
    '''
    # parameter爲二維數組,用來存放參數,sizeOfInterval是用來存放區間的個數
    parameter = []
    n = len(x)
    sizeOfInterval = n - 1  # 線段數
    i = 1
    # 首先輸入方程兩邊相鄰  節點  處函數值相等的方程爲2(n-2)=4個方程; n個點要減去兩個端點,然後乘以2
    while i <= n-2:
        # x[i]代入點i前一條線段的參數
        data = np.zeros(sizeOfInterval * 4)
        data[(i - 1) * 4] = x[i] * x[i] * x[i]
        data[(i - 1) * 4 + 1] = x[i] * x[i]
        data[(i - 1) * 4 + 2] = x[i]
        data[(i - 1) * 4 + 3] = 1
        # x[i]代入點i後一條線段的參數
        data1 = np.zeros(sizeOfInterval * 4)
        data1[i * 4] = x[i] * x[i] * x[i]
        data1[i * 4 + 1] = x[i] * x[i]
        data1[i * 4 + 2] = x[i]
        data1[i * 4 + 3] = 1
        parameter.append(data)
        parameter.append(data1)
        i += 1
    # 輸入   端點   處的函數值。爲2個方程
    data = np.zeros(sizeOfInterval * 4)
    # 第一條線段
    data[0] = x[0] * x[0] * x[0]
    data[1] = x[0] * x[0]
    data[2] = x[0]
    data[3] = 1
    parameter.append(data)
    data = np.zeros(sizeOfInterval * 4)
    # 最後一條線段
    data[-4] = x[-1] * x[-1] * x[-1]
    data[-3] = x[-1] * x[-1]
    data[-2] = x[-1]
    data[-1] = 1
    parameter.append(data)

    # 節點點函數一階導數值相等爲n-2=2個方程。
    i = 1
    while i <= n-2:
        data = np.zeros(sizeOfInterval * 4)
        data[(i - 1) * 4] = 3 * x[i] * x[i]
        data[(i - 1) * 4 + 1] = 2 * x[i]
        data[(i - 1) * 4 + 2] = 1
        data[i * 4] = -3 * x[i] * x[i]
        data[i * 4 + 1] = -2 * x[i]
        data[i * 4 + 2] = -1
        # temp = data[2:]
        # parameter.append(temp)
        parameter.append(data)
        i += 1

    # 節點函數二階導數值相等爲n-2=2個方程。且端點處的函數值的二階導數爲零,爲2個方程。
    i = 1
    while i <= n-2:
        data = np.zeros(sizeOfInterval * 4)
        data[(i - 1) * 4] = 6 * x[i]
        data[(i - 1) * 4 + 1] = 2
        data[i * 4] = -6 * x[i]
        data[i * 4 + 1] = -2
        # temp = data[2:]
        # parameter.append(temp)
        parameter.append(data)
        i += 1

    # 總共2(n-1)-2=10個方程
    parameter = np.array(parameter)
    return parameter[:, 2:]  # 去掉前兩個a1,b1


"""
功能:計算樣條函數的係數。
參數:parametes爲方程的係數,y爲要插值函數的因變量。
返回值:三次插值函數的係數。
"""


def solutionOfEquation(parametes, y):
    n = len(x)
    sizeOfInterval = n - 1
    result = np.zeros(sizeOfInterval * 4 - 2)
    i = 1
    # 節點處方程右邊
    while i < sizeOfInterval:  # result[0,1,2,3]
        result[(i - 1) * 2] = y[i]
        result[(i - 1) * 2 + 1] = y[i]
        i += 1
    # 起末端點處方程右邊
    result[(sizeOfInterval - 1) * 2] = y[0]
    result[(sizeOfInterval - 1) * 2 + 1] = y[-1]

    a = np.array(parametes)
    b = np.array(result)
    return np.linalg.solve(a, b)  # 解線性方程組


"""
功能:根據所給參數,計算三次函數的函數值:
參數:parameters爲二次函數的係數,x爲自變量
返回值:爲函數的因變量
"""


def calculate(paremeters, x):
    result = []
    for data_x in x:
        y = paremeters[0] * data_x * data_x * data_x + paremeters[1] * data_x * data_x + paremeters[2] * data_x + paremeters[3]
        result.append(y)
    return result


"""
功能:採點
參數:x
返回值:採取點和採取點的函數值
"""


def grasp_sample(x):
    n = len(x)
    i = 1
    # result 爲求解出來後的a1,b1,c1,d1,a2,b2,c2,d2,a3,b3,c3,d3
    result = [0, 0]
    temp = solutionOfEquation(calculateEquationParameters(x), y)
    result.extend(temp)
    samples_x = []
    samples_y = []
    # n-1段曲線
    while i < n:
        sample_x = np.arange(x[i-1], x[i], 0.01)
        sample_y = calculate(result[(i-1)*4:i*4], sample_x)
        samples_x.extend(sample_x)
        samples_y.extend(sample_y)
        i = i+1
    samples_x.append(x[n-1])
    samples_y.extend(calculate(result[-4:], [x[n-1]]))
    return [samples_x, samples_y]

"""
功能:將函數繪製成圖像
參數:data_x,data_y爲離散的點.new_data_x,new_data_y爲由拉格朗日插值函數計算的值。x爲函數的預測值。
返回值:空
"""


def Draw(data_x, data_y, new_data_x, new_data_y):
    plt.plot(new_data_x, new_data_y, label="擬合曲線", color="black")
    plt.scatter(data_x, data_y, label="離散數據", color="red")
    mpl.rcParams['font.sans-serif'] = ['SimHei']
    mpl.rcParams['axes.unicode_minus'] = False
    plt.title("三次樣條函數")
    plt.legend(loc="upper left")
    plt.show()


samples = grasp_sample(x)
# print(samples[0])
# print(samples[1])
Draw(x, y, samples[0], samples[1])

樣本點:

運行結果圖

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章