libtorch c++ 自定義數據類型並使用

上述幾節主要介紹瞭如何利用MNIST數據集搭建多層神經網絡並完成模型的訓練,用到的數據都是torch::data::dataset自帶的數據集,這節介紹如何根據實際情況創建自己的數據集。

(1)自定義類型的設計方法

實際上,自定義數據類型很簡單,只需要繼承torch::data::datasets::Dataset<self, SingleExample>,同時重寫get(size_t index)以獲取指定元素和樣本總數size()即可。

Dataset類繼承的定義在base.h中,它繼承自BatchDataset,它支持隨機方式獲取元素,也支持批量的方式獲取元素。僅需要重寫兩個函數:

第一個是get(size_t index),用來獲取指定編號的樣本

 /// Returns the example at the given index.
  virtual ExampleType get(size_t index) = 0;

第二個是size(),用來獲取總的樣本數

/// Returns the size of the dataset, or an empty optional if it is unsized.
  virtual optional<size_t> size() const = 0;

對於返回的元素Example是模板類型,它是個數據項data和標籤項target的組合,而data和target都是張量Tensor,數據項data用於給傳遞給神經網絡正向傳播,數據標籤項target用於和正向傳播的結果output一起結算損失大小。

換句話說,自定義的類,只要繼承了Dataset,同時重寫了get函數,其他部分可以任意設計。這裏介紹兩種可能的情況,

(1)比如自定義類型中存儲了數據項坐在的文件夾和標籤項所在的文件夾,兩個文件夾中文件數相同,比如都是10000個,文件中存儲了地質模型和地震正演模擬結果,具有一一對應關係,自定義類型按照列表方式分別存儲兩個文件件中文件名,通過索引的方式可以獲取文件名,然後get(size_t index)可以根據指定的文件名讀取文件內容並返回張量格式的數據。

(2)比如自定義類型中存儲了文件,文件中存儲了所有的樣本數據,比如列存儲方式的測井數據,具有m行,n列數據,每行數據中前n-1列爲數據項,第n列爲標籤項,讀取內容後以兩個列表形式存儲數據,通過指定索引號,可以通過get(size_t index)方式獲取張量格式的有n-1個數據組成的數據項和1個數據組成的標籤項。

(3)再比如,自定義類型中存儲了一個地震數據體,同時一個文本文件存儲了其他方式解釋的沉積相或產量等,分別是沉積相類型或產量及其空間座標xyz,這些數據按行存儲,通過指定行號,的get方法可以從地震數據體重獲取點周圍的地震數據組成數據項和沉積相、產量自身作爲標籤項。

總之,自定義類型的格式各種各樣,不一一列舉,

(2)自定義類型的通用格式

假設自定義類型名稱爲MyDataset,構造函數中分別制定兩種數據所在的文件位置,形成兩個文件名組成的字符串列表,或者直接傳遞兩個文件名字符串列表給構造函數。

如果單個樣本數據量不大,可以在自定義的類型中設置兩個張量作爲數據項和標籤項,其中的數據項爲states_, 標籤項爲labels_, 兩者都是Tensor類型,把該批數據都存儲在兩個張量中。

如果單個樣本數據量很大,內存不足,數據類型自身不存儲整個批次的數據,只存儲每個樣本的文件名稱列表,然後只在get(size_t index)中創建單個樣本的數據,並通過make_data_loader的方式疊置成整個批次的數據。

下面的代碼展示後者,即在數據集內部只存儲文件名,實際在data_loader 時再真正讀取數據,其中read_source和read_target函數的具體內容不在展示。

class CustomDataset :public torch::data::Dataset<CustomDataset> {
private:
	//declare 2 vectors for sources and targets
	//std::vector<torch::Tensor> sources, targets;
	std::vector<std::string>sourceFiles, targetFiles;
	int rows;
	int cols;
public:
	//constructor
	CustomDataset(std::vector<std::string> sources_list, std::vector<std::string>targets_list, int rows, int cols) {
		if (sources_list.size() != targets_list.size()) {
			std::cout << "sources_list size must be equal as target_list size" << std::endl;
			return ;
		}
		//sources = process_sources(sources_list, rows, cols);
		//targets = process_targets(targets_list, rows, cols);
		this->sourceFiles = sources_list;
		this->targetFiles = targets_list;
		this->rows = rows;
		this->cols = cols;

	};

	//override get() function to return tensor at location index
	torch::data::Example<>get(size_t index)override {
		/*torch::Tensor sample_source = sources.at(index);
		torch::Tensor sample_target = targets.at(index);*/
		torch::Tensor sample_source = read_source(sourceFiles[index], rows, cols);
		torch::Tensor sample_target = read_target(targetFiles[index], rows, cols);
		std::cout << index << std::endl;
		std::cout << sample_source.max() << std::endl;
		std::cout << sample_target.max() << std::endl;
		/*torch::Tensor sample_source = read_source(sourceFiles[index]);
		torch::Tensor sample_target = read_target(targetFiles[index]);*/
		return { sample_source.clone(), sample_target.clone() };
	};

	//return the length of the data
	torch::optional<size_t>size()const override {
		//return targets.size();
		return sourceFiles.size();
	};
};

下面是主函數中對自定義數據的調用和打印輸出以作測試

int main()
{

	std::vector<std::string> home_root, sourceList, targetList;
	home_root.push_back("floder1");
	home_root.push_back("floader2");

	for (int i = 0; i < 5000; i++) {
		std::string dataFile = home_root[0] + "\\case_" + std::to_string(i) + ".txt";
		std::string targetFile = home_root[1] + "\\case_" + std::to_string(i) + ".txt";
		if (std::filesystem::exists(dataFile) && std::filesystem::exists(targetFile)) {
			sourceList.push_back(dataFile);
			targetList.push_back(targetFile);			
		}
		else{
			std::cout << dataFile << " exist status:" << std::filesystem::exists(dataFile) << std::endl;
			std::cout << targetFile << " exist status:" << std::filesystem::exists(targetFile) << std::endl;
			continue;
		}		
	}
	int rows = 1000;
	int cols = 1000;
	auto dataset = CustomDataset(sourceList, targetList, rows, cols).map(torch::data::transforms::Stack<>());
	int batchSize = 10;
	auto dataLoader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(dataset), batchSize);

	for (auto& batch : *dataLoader) {

		auto data = batch.data;
		auto target = batch.target;
		std::cout << data.sizes() << std::endl;
		std::cout << data.max() << std::endl;
		std::cout << data << std::endl;
	}

	return EXIT_SUCCESS;
}

打印結果如下:

......... 
-0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000
  0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000  0.0000 -0.0000 -0.0000
 -0.0264 -0.0292 -0.0313 -0.0323 -0.0318 -0.0296 -0.0254 -0.0190 -0.0106
  0.0074  0.0024 -0.0021 -0.0057 -0.0086 -0.0105 -0.0116 -0.0120 -0.0118
  0.0000  0.0000  0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000
 -0.0000 -0.0000  0.0000 -0.0000 -0.0000 -0.0000  0.0000  0.0000 -0.0000
  0.0000  0.0000  0.0000  0.0000 -0.0000  0.0000  0.0000  0.0000  0.0000
 -0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000  0.0000  0.0000  0.0000
 -0.0000  0.0000 -0.0000  0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000
  0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000 -0.0000  0.0000 -0.0000
 -0.0000  0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000  0.0000  0.0000
  0.0000 -0.0000  0.0000  0.0000  0.0000 -0.0000  0.0000  0.0000  0.0000
 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000
  0.0000 -0.0000  0.0000 -0.0000  0.0000 -0.0000 -0.0000  0.0000 -0.0000
  0.0150  0.0152  0.0141  0.0118  0.0086  0.0047  0.0005 -0.0036 -0.0072
  0.0367  0.0215  0.0039 -0.0155 -0.0355 -0.0551 -0.0731 -0.0883 -0.0998
 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000
  0.0000 -0.0000  0.0000  0.0000  0.0000  0.0000 -0.0000  0.0000 -0.0000
 -0.0000 -0.0000 -0.0000 -0.0000  0.0000  0.0000 -0.0000  0.0000 -0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000 -0.0000  0.0000
 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000  0.0000  0.0000 -0.0000 -0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.0000  0.0000 -0.0000  0.0000  0.0000 -0.0000  0.0000  0.0000  0.0000
 -0.0000 -0.0000  0.0000  0.0000  0.0000  0.0000  0.0000 -0.0000  0.0000
....

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章