# mxnet教程1 (MXNet tutorial 1)

import mxnet as mx
#%matplotlib inline
import os
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import tarfile

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Read data from memory
def test1():
    """Iterate over a random in-memory dataset with ``mx.io.NDArrayIter``.

    Creates 100 samples of 3 random features plus integer labels drawn from
    [0, 10), prints both array shapes, then walks the iterator in batches of
    30 and prints, for each batch, its type, the batch object itself, its
    ``pad`` value and a blank separator line. (``pad`` is presumably the
    number of filler samples in a short final batch — MXNet-defined, confirm
    against the NDArrayIter docs.)
    """
    features = np.random.rand(100, 3)
    targets = np.random.randint(0, 10, (100,))
    for arr in (features, targets):
        print(arr.shape)

    batches = mx.io.NDArrayIter(data=features, label=targets, batch_size=30)
    for b in batches:
        print(type(b))
        print(b)
        print(b.pad)
        print()
    
# Read data from CSV files
def test2():
    """Round-trip random data through CSV files and read it back with CSVIter.

    Writes 100 rows of 3 random features to ``data.csv`` and the matching
    integer labels (one per line) to ``label.csv``, both relative to the
    current working directory. It then iterates over the two files with
    ``mx.io.CSVIter`` in batches of 30, printing each batch's data, label and
    ``pad`` value.
    """
    data = np.random.rand(100, 3)
    label = np.random.randint(0, 10, (100,))
    # Save `data` and `label` to CSV files first, then read them back.
    np.savetxt('data.csv', data, delimiter=',')
    np.savetxt('label.csv', label, delimiter=',')

    # Pass label.csv as well so the labels written above are actually read
    # back. Each row of label.csv holds a single value, hence label_shape=(1,).
    data_iter = mx.io.CSVIter(data_csv='data.csv', data_shape=(3,),
                              label_csv='label.csv', label_shape=(1,),
                              batch_size=30)
    for batch in data_iter:
        print(batch.data)
        print(batch.label)
        print(batch.pad)
            
        
# Create a simple custom data iterator:
class SimpleIter(mx.io.DataIter):
    """Minimal ``mx.io.DataIter`` that synthesizes each batch from callables.

    Every call to ``next()`` calls each generator with the shape registered
    for its data/label name and wraps the result in ``mx.nd.array``. It
    yields ``num_batches`` batches, then raises ``StopIteration`` until
    ``reset()`` is called.

    Args:
        data_names: names of the data inputs (e.g. ``['data']``).
        data_shapes: one shape tuple per data name.
        data_gen: one callable per data name; it is called with that name's
            shape and must return something ``mx.nd.array`` accepts.
        label_names, label_shapes, label_gen: the same, for the labels.
        num_batches: number of batches produced per epoch.
    """

    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        super().__init__()
        # Materialize the (name, shape) pairs. In Python 3, zip() returns a
        # one-shot iterator that would be exhausted by its first consumer, so
        # provide_data/provide_label would come back empty and every batch
        # after the first would contain no arrays.
        self._provide_data = list(zip(data_names, data_shapes))
        self._provide_label = list(zip(label_names, label_shapes))
        self.num_batches = num_batches
        self.data_gen = data_gen
        self.label_gen = label_gen
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        """Rewind to the first batch so the iterator can be reused."""
        self.cur_batch = 0

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        """List of ``(name, shape)`` pairs describing the data inputs."""
        return self._provide_data

    @property
    def provide_label(self):
        """List of ``(name, shape)`` pairs describing the label inputs."""
        return self._provide_label

    def next(self):
        """Return the next generated ``DataBatch``.

        Raises:
            StopIteration: once ``num_batches`` batches have been produced.
        """
        if self.cur_batch >= self.num_batches:
            raise StopIteration
        self.cur_batch += 1
        data = [mx.nd.array(gen(shape))
                for (_, shape), gen in zip(self._provide_data, self.data_gen)]
        label = [mx.nd.array(gen(shape))
                 for (_, shape), gen in zip(self._provide_label, self.label_gen)]
        return mx.io.DataBatch(data, label)
        
def _default_context():
    """Return ``mx.gpu(0)`` when MXNet reports a GPU, otherwise ``mx.cpu()``."""
    try:
        gpu_count = mx.context.num_gpus()
    except (AttributeError, mx.base.MXNetError):
        # num_gpus() is absent on old MXNet releases, and the GPU query may
        # fail on builds/machines without a usable GPU; fall back to the CPU.
        gpu_count = 0
    return mx.gpu(0) if gpu_count > 0 else mx.cpu()


# Custom iterator: train a small MLP on data served by SimpleIter
def test3(ctx=None):
    """Train a two-layer MLP on random data produced by ``SimpleIter``.

    Prints the network's argument and output names, enables INFO-level
    logging, then fits the model for 5 epochs. The data are batches of 32
    samples with 100 uniform features in [-1, 1) and labels drawn from 10
    classes.

    Args:
        ctx: MXNet context to train on. When None (the default), GPU 0 is
            used if MXNet reports at least one GPU, otherwise the CPU. The
            original hard-coded ``mx.gpu(0)`` crashed on CPU-only machines.
    """
    num_classes = 10
    net = mx.sym.Variable('data')
    net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
    net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
    net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=num_classes)
    net = mx.sym.SoftmaxOutput(data=net, name='softmax')
    print(net.list_arguments())
    print(net.list_outputs())

    import logging
    logging.basicConfig(level=logging.INFO)

    n = 32  # samples per batch (leading dimension of every shape below)
    data_iter = SimpleIter(['data'], [(n, 100)],
                           [lambda s: np.random.uniform(-1, 1, s)],
                           ['softmax_label'], [(n,)],
                           [lambda s: np.random.randint(0, num_classes, s)])

    if ctx is None:
        ctx = _default_context()
    mod = mx.mod.Module(symbol=net,
                        context=ctx,
                        data_names=['data'],
                        label_names=['softmax_label'])
    mod.fit(data_iter, num_epoch=5)
    
  
def test4():
    """Demonstrate ``zip`` and ``zip(*...)`` (unzip) in Python 3.

    In Python 3, ``zip`` returns a one-shot iterator: it has no ``len()``,
    prints as ``<zip object ...>``, and is empty after one pass. Earlier,
    the loop over ``zip_ac`` exhausted it, so the unzip step below produced
    nothing. Materializing each result with ``list`` makes it measurable,
    printable and reusable.
    """
    a = [1, 2, 3]
    b = [4, 5, 6]
    c = [4, 5, 6, 7, 8]
    zip_ab = list(zip(a, b))
    zip_ac = list(zip(a, c))  # zip stops at the shortest input: 3 pairs
    print("len(zip_ab):", len(zip_ab))
    print("len(zip_ac):", len(zip_ac))
    print(zip_ab)
    print(zip_ac)
    for element in zip_ab:
        print(element)
    for element in zip_ac:
        print(element)

    # zip(*pairs) transposes the pairs back into per-position tuples.
    unzip_ac = list(zip(*zip_ac))
    print(unzip_ac)
    for element in unzip_ac:
        print(element)
      
# Run the demos only when executed as a script, so importing this module
# does not start training. Uncomment a line to run the corresponding demo.
if __name__ == "__main__":
    # test1()
    # test2()
    test3()
    # test4()

 

# 發表評論
# 所有評論
# 還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
# 相關文章