基於表達式模版(expression template)的惰性求值(lazy evaluation)

代碼碰到惰性求值的模版編程技術,動手實踐下。

  • 目標:設計一個支持任意四則表達式運算的向量類(Vec)

定義一個類Vec

template<typename T>
struct Vec {    
    Vec() : len(0), dptr(0), owned(false) {}
    Vec(T* dptr, int len) : len(len), dptr(dptr), owned(false) {}
    Vec(int len) : len(len), owned(true) {
        dptr = new T[len];
    }

    ~Vec() {
        if (owned) delete[] dptr;   
        dptr = 0;
    }

    inline int length() {
        return len;
    }

    inline T operator[] (int i) const {
        return dptr[i];
    }

    inline T & operator[] (int i) {
        return dptr[i];
    }

    inline Vec & operator= (const Vec<T>& src) {        
#pragma omp parallel for
        for (int i = 0; i < len; i++) {
            dptr[i] = src[i];
        }
        return *this;
    }

private:

    int len;
    T* dptr;
    bool owned;
};

支持任意長度四則運算,如

Vec<double> a, b, c, d, e, f;
a = b + c - d / e * f;
  • first try
    重定義運算符+、-、*、/。這是C++的一個常規解決方案
template<T>
Vec<T> & operator+ (Vec<T> &lhs, Vec<T> &rhs) {
    // assert(lhs.size() == rhs.size());
    Vec<T> rs(lhs.size());
    for (int i = 0; i < lhs.size(); i++) {
        rs[i] = lhs[i] + rhs[i];
    }
    return rs;
}

這一方案的問題在於,運算過程中需要分配臨時空間。此外,還存在多次函數調用。

  • 更好的方案——表達式模版
    這裏, 我們使用表達式模版實現運算的惰性求值。不僅不需要額外的空間,也減少函數調用開銷。

  • 表達式模版
    表達式(等號右邊部分)可以用一個表達式樹抽象的表示。其中,葉子節點(終結符)是我們的向量,它也是一種特殊的表達式。樹的根節點是運算符,左右子樹是子表達式。

  • 實現
    – 首先,我們定義表達式。它是所有可能表達式的父類。

template<typename RealType>
struct Exp {
    inline const RealType& self() const {
        return *static_cast<const RealType*>(this);
    }
};

實質上,它只是一個wrapper。它的作用是,當我們需要將一個對象做爲表達式傳遞是時,它將其他封裝。在傳遞之後,通過self()函數再得到原來的對象。

例如,我們如下定義Vec:

template<T>
struct Vec: Exp<Vec<T>> {
    ... ..
}

對比常規定義:

template<T>
struct Vec {
    ... ...
}

Vec的完整定義如下:

template<typename T>
struct Vec : public Exp < Vec<T> > {
    typedef T value_type;

    int len;
    T* dptr;
    bool owned;

    Vec(T* dptr, int len) : len(len), dptr(dptr), owned(false) {}
    Vec(int len) : len(len), owned(true) {
        dptr = new T[len];
    }

    ~Vec() {
        if (owned) delete[] dptr;   
        dptr = 0;
    }

    inline T operator[] (int i) const {
        return dptr[i];
    }

    template<typename EType>
    inline Vec & operator= (const Exp<EType>& src_) {
        const EType &src = src_.self();
#pragma omp parallel for
        for (int i = 0; i < len; i++) {
            dptr[i] = src[i];
        }
        return *this;
    }
};

唯一需要解釋是賦值操作

template<typename EType>
    inline Vec & operator= (const Exp<EType>& src_) {
        const EType &src = src_.self();
#pragma omp parallel for
        for (int i = 0; i < len; i++) {
            dptr[i] = src[i];
        }

        return *this;
    }

Vec接受一個表達式,表達式必須提供operator[]函數,返回相應的值。正是由於[]的定義,使得惰性求值成爲可能。

以上,我們已經有了葉子節點(Vec)。要構造表達式樹,我們要定義每個中間節點和根節點。它們本質上是二元操作。

template<typename Op, typename TLhs, typename TRhs>
struct BinaryOpExp : Exp < BinaryOpExp<Op, TLhs, TRhs> > {
    const TLhs &lhs;
    const TRhs &rhs;
    typedef typename ReturnType<TLhs, TRhs>::value_type value_type;

    BinaryOpExp(const TLhs &lhs, const TRhs &rhs) : lhs(lhs), rhs(rhs) {}

    inline value_type operator[] (int i) const {
        return Op::eval(lhs[i], rhs[i]);
    }
};

其中,ReturnType 只是一個簡單的功能模版。

template<typename TLhs, typename TRhs>
struct ReturnType {
    typedef typename TLhs::value_type value_type;
};

作爲表達式,BinaryOpExp 重載了我們需要的 operator[]。

最後要做的是,重載+號等運算符

template<typename T>
struct add {
    inline static T eval(const T& lhs, const T& rhs) {
        return lhs + rhs;
    }
};

template<typename TLhs, typename TRhs>
inline BinaryOpExp<add<typename ReturnType<TLhs, TRhs>::value_type>, TLhs, TRhs>
operator+ (const Exp<TLhs> &lhs, const Exp<TRhs> &rhs) {
    return BinaryOpExp<detail::add<typename ReturnType<TLhs, TRhs>::value_type>, TLhs, TRhs>(lhs.self(), rhs.self());
}

一個簡單的測試:

int main() {
    const int n = 3;
    double sa[n] = { 1, 2, 3 };
    double sb[n] = { 2, 3, 4 };
    double sc[n] = { 3, 4, 5 };
    double sd[n] = { 4, 5, 6 };
    double se[n] = { 5, 6, 7 };
    double sf[n] = { 6, 7, 8 };

    Vec<double> A(sa, n), B(sb, n), C(sc, n), D(sd, n), E(se, n), F(sf, n);

    // run expression, this expression is longer:)
    A = B + C - D * E / F;
    for (int i = 0; i < n; ++i) {
        printf("%d:%f == %f + %f - %f * %f / %f == %f\n", i,
            A[i], B[i], C[i], D[i], E[i], F[i], B[i] + C[i] - D[i] * E[i] / F[i]);
    }

    return 0;
}

輸出結果:

0:1.666667 == 2.000000 + 3.000000 - 4.000000 * 5.000000 / 6.000000 == 1.666667
1:2.714286 == 3.000000 + 4.000000 - 5.000000 * 6.000000 / 7.000000 == 2.714286
2:3.750000 == 4.000000 + 5.000000 - 6.000000 * 7.000000 / 8.000000 == 3.750000

除基本的+、-、*、\之外,我們還可以自定義二元運算符。

template<typename Op, typename TLhs, typename TRhs>
inline BinaryOpExp<Op, TLhs, TRhs> F(const Exp<TLhs> &lhs, const Exp<TRhs> &rhs) {
    return BinaryOpExp<Op, TLhs, TRhs>(lhs.self(), rhs.self());
}

類似的,我們可以定義一元操作運算

template<typename Op, typename T>
struct UnaryOpExp : Exp < UnaryOpExp<Op, T> > {
    const T &arg;

    typedef typename T::value_type value_type;

    UnaryOpExp(const T &arg) : arg(arg) {}

    inline value_type operator[] (int i) const {
        return Op::eval(arg[i]);
    }
};

template<typename Op, typename T>
inline UnaryOpExp<Op, T> F(const Exp<T> &arg) {
    return UnaryOpExp<Op, T>(arg.self());
}

我們重載sin函數

template<typename T>
struct sinOp {
    inline static T eval(const T& arg) {
        return std::sin(arg);
    }
};

template<typename T>
UnaryOpExp<detail::sinOp<typename T::value_type>, T> sin(const Exp<T> &arg) {
    return UnaryOpExp<detail::sinOp<typename T::value_type>, T>(arg.self());
}

一個簡單的測試:

int main() {
    const int n = 3;
    double sa[n] = { 1, 2, 3 };
    double sb[n] = { 2, 3, 4 };
    double sc[n] = { 3, 4, 5 };

    Vec<double> A(sa, n), B(sb, n), C(sc, n);

    A = sin(B) + sin(C);
    for (int i = 0; i < n; ++i) {
        printf("%d:%f == sin(%f) + sin(%f) == %f\n", i, A[i], B[i], C[i], sin(B[i]) + sin(C[i]));
    }   
    return 0;
}

輸出結果如下:

0:1.050417 == sin(2.000000) + sin(3.000000) == 1.050417
1:-0.615682 == sin(3.000000) + sin(4.000000) == -0.615682
2:-1.715727 == sin(4.000000) + sin(5.000000) == -1.715727
發佈了31 篇原創文章 · 獲贊 162 · 訪問量 24萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章