Handwritten digit recognition with ANNs

digits_ann .py :

import cv2
import cPickle
import numpy as np
import gzip


def load_data():
    mnist = gzip.open('mnist.pkl.gz', 'rb')
    training_data, classification_data, test_data = cPickle.load(mnist)
    mnist.close()
    return training_data, classification_data, test_data


def wrap_data():
    tr_d, va_d, te_d = load_data()
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = zip(training_inputs, training_results)
    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = zip(validation_inputs, va_d[1])
    test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = zip(test_inputs, te_d[1])
    return training_data, validation_data, test_data


def vectorized_result(j):
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e


def create_ANN(hidden=20):
    ann = cv2.ml.ANN_MLP_create()
    ann.setLayerSizes(np.array([784, hidden, 10]))
    ann.setTrainMethod(cv2.ml.ANN_MLP_RPROP)
    ann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)
    ann.setTermCriteria(( cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 20, 1))
    return ann


def train(ann, samples=10000, epochs=1):
    tr, val, test = wrap_data()
    for x in xrange(epochs):
        counter = 0
        for img in tr:
            if(counter > samples):
                break
            if(counter % 1000 == 0):
                print "Epoch %d: Trained %d/%d" % (x, counter, samples)
            counter += 1
            data, digit = img
            ann.train(np.array([data.ravel()], dtype=np.float32),
            cv2.ml.ROW_SAMPLE, np.array([digit.ravel()], dtype=np.float32))
            print "Epoch %d complete" % x
    return ann, test


def test(ann, test_data):
    sample = np.array(test_data[0][0].ravel(), dtype=np.float32).reshape(28,
    28)
    cv2.imshow("sample", sample)
    cv2.waitKey()
    print ann.predict(np.array([test_data[0][0].ravel()], dtype=np.float32))


def predict(ann, sample):
    resized = sample.copy()
    rows, cols = resized.shape
    if (rows != 28 or cols != 28) and rows * cols > 0:
        resized = cv2.resize(resized, (28, 28), interpolation=cv2.INTER_CUBIC)
    return ann.predict(np.array([resized.ravel()], dtype=np.float32))

main.py :

import cv2
import numpy as np
import digits_ann as ANN


def inside(r1, r2):
    x1,y1,w1,h1 = r1
    x2,y2,w2,h2 = r2
    if (x1 > x2) and (y1 > y2) and (x1+w1 < x2+w2) and (y1+h1 < y2 + h2):
        return True
    else:
        return False


def wrap_digit(rect):
    x, y, w, h = rect
    padding = 5
    hcenter = x + w/2
    vcenter = y + h/2
    if (h > w):
        w = h
        x = hcenter - (w/2)
    else:
        h = w
        y = vcenter - (h/2)
    return (x-padding, y-padding, w+padding, h+padding)
ann, test_data = ANN.train(ANN.create_ANN(56), 20000)
font = cv2.FONT_HERSHEY_SIMPLEX
path = "./c.png"
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
bw = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
bw = cv2.GaussianBlur(bw, (7, 7), 0)
ret, thbw = cv2.threshold(bw, 127, 255, cv2.THRESH_BINARY_INV)
thbw = cv2.erode(thbw, np.ones((2,2), np.uint8), iterations=2)
image, cntrs, hier = cv2.findContours(thbw.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
rectangles = []
for c in cntrs:
    r = x,y,w,h = cv2.boundingRect(c)
    a = cv2.contourArea(c)
    b = (img.shape[0]-3) * (img.shape[1] - 3)
    is_inside = False
    for q in rectangles:
        if inside(r, q):
            is_inside = True
            break
    if not is_inside:
        if not a == b:
            rectangles.append(r)
for r in rectangles:
    x,y,w,h = wrap_digit(r)
    cv2.rectangle(img, (x,y), (x+w, y+h), (0, 255, 0), 2)
    roi = thbw[y:y+h, x:x+w]
    try:
        digit_class = int(ANN.predict(ann, roi.copy())[0])
    except:
        continue
    cv2.putText(img, "%d" % digit_class, (x, y-1), font, 1, (0, 255, 0))
cv2.imshow("thbw", thbw)
cv2.imshow("contours", img)
cv2.imwrite("results.jpg", img)
cv2.waitKey()



results:


compare pictures'  results:




Conclusion :

This comparison demonstrated previously reveals its insufficient .


Reference to 《learning opencv3 computer vision with Python

" mnist.pkl.gz  " downloaded from  http://www.cnblogs.com/xueliangliu/archive/2013/04/03/2997437.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章