背景
前段時間,一位好友發給我如下的文件:
每個CSV文件中的數據由三個屬性組成,第一個屬性爲ID,第二個屬性爲X座標,第三個屬性爲Y座標。由於是二維數據,可以繪製出每個文件的散點圖,把這些散點圖連接起來,構成如下的GIF圖像。類似於視頻中的幀,每一幀對應一個CSV文件。
要做的事情是什麼呢?
在每一幀數據中,快速找到以每個數據點爲圓心,r = 0.004
爲半徑的範圍內包含的其它數據點。
技術分析
剛拿到這個問題的第一個反應就是利用暴力的方法,循環計算每個點與其它點的距離,把小於r
的記錄下來就可以了。可是這樣做會耗費大量的計算時間。
如何改進呢,首先想到的是把計算歐氏距離的根號去掉,只計算歐氏距離的平方,這樣可以避免開根號,從而提升運算效率。於是寫了“暴力求解算法”的代碼,發現執行效率很低,執行 10次 的平均時間約爲 2754毫秒。
static float Distance(PointF p1, PointF p2)
{
float d = (p1.X - p2.X)*(p1.X - p2.X) + (p1.Y - p2.Y)*(p1.Y - p2.Y);
return d;
}
for (int i = 0; i < lst.Count; i++)
{
PointF p1 = lst[i].Original;
for (int j = i + 1; j < lst.Count; j++)
{
PointF p2 = lst[j].Original;
if (Distance(p1, p2) <= r * r)
{
lst[i].LstWithin.Add(p2);
lst[j].LstWithin.Add(p1);
}
}
}
怎麼辦呢?
我們知道平面上的點屬於R2
空間,是基於e1=(1,0),e2=(0,1)
這組基的,爲了方便比較我們可以做一個座標變換,讓這些數據基於a1=(r,0),a2=(0,r)
這組基。這樣給定一個點,我們就能很快判斷出是否超出單位圓的範圍。由於做基變換的過渡矩陣是可逆的,新的座標和初始座標可以相互線性表示,這樣就把原問題簡化了。於是寫了“利用座標變換進行計算”的代碼,發現效率的提升並不明顯,僅僅縮短了一半左右的執行時間,執行 10次 的平均時間約爲 1296毫秒。
for (int i = 0; i < lst.Count; i++)
{
PointF p1 = lst[i].AfterConversion;
for (int j = i + 1; j < lst.Count; j++)
{
PointF p2 = lst[j].AfterConversion;
if (p1.X + 1 < p2.X || p1.X - 1 > p2.X || p1.Y + 1 < p2.Y || p1.Y - 1 > p2.Y)
continue;
if (Distance(p1, p2) <= 1)
{
lst[i].LstWithin.Add(p2);
lst[j].LstWithin.Add(p1);
}
}
}
怎麼辦呢?
我們發現如果對數據的某一個維度進行排序,這樣可以減少計算距離的次數,於是我們先對數據中的x
維度進行排序,然後再來計算距離。於是寫了“利用座標變換和排序進行計算”的代碼,發現效率提升了90%
以上,執行 10次 的平均時間約爲 235毫秒。
for (int i = 0; i < lst.Count; i++)
{
PointF p1 = lst[i].AfterConversion;
for (int j = i + 1; j < lst.Count; j++)
{
PointF p2 = lst[j].AfterConversion;
if (p1.X + 1 < p2.X)
break;
if (Distance(p1, p2) <= 1)
{
lst[i].LstWithin.Add(p2);
lst[j].LstWithin.Add(p1);
}
}
}
這裏還有一個插曲,我曾經想利用KD Tree
的方式來建立每一幀數據的空間索引結構,這是加速K 鄰近
算法的方法。由於每幀數據都是變化的,雖然查找的速度很快,但構建這個Tree
的過程需要耗費大量的時間,於是放棄了這種方法。
代碼實現
記錄時間的代碼如下:
public class TimeRecord
{
[DllImport("Kernel32.dll")]
private static extern bool QueryPerformanceCounter(out long lpFrequency);
[DllImport("Kernel32.dll")]
private static extern bool QueryPerformanceFrequency(out long lpFrequency);
private long _startTime, _stopTime;
private readonly long _freq;
/// <summary>
/// 創建一個TimeRecord類的新實例
/// </summary>
public TimeRecord()
{
_startTime = 0;
_stopTime = 0;
if (QueryPerformanceFrequency(out _freq) == false)
{
throw new Win32Exception();
}
}
/// <summary>
/// 開始計時
/// </summary>
public void Start()
{
Thread.Sleep(0);
QueryPerformanceCounter(out _startTime);
}
/// <summary>
/// 停止計時
/// </summary>
public void Stop()
{
QueryPerformanceCounter(out _stopTime);
}
/// <summary>
/// 獲取從Start()->Stop()中間的精確時間間隔
/// 單位:毫秒
/// 計數次數/計數頻率
/// </summary>
public double DurationMs
{
get
{
return (double)(_stopTime - _startTime) * 1000 / _freq;
}
}
}
構造數據存儲結構的代碼如下:
public class Data
{
/// <summary>
/// 編號
/// </summary>
public int Id;
/// <summary>
/// 原先點
/// </summary>
public PointF Original;
/// <summary>
/// 轉換後的點
/// </summary>
public PointF AfterConversion;
/// <summary>
/// 範圍內的點
/// </summary>
public List<PointF> LstWithin = new List<PointF>();
}
暴力求解算法的代碼如下:
static List<Data> LoadData(string[] fileNames)
{
List<Data> result = new List<Data>();
foreach (string file in fileNames)
{
string[] strs = File.ReadAllLines(file);
foreach (string str in strs)
{
string[] temp = str.Split(new char[] {','});
Data data = new Data
{
Id = int.Parse(temp[0]),
Original = new PointF(float.Parse(temp[1]), float.Parse(temp[2]))
};
result.Add(data);
}
}
return result;
}
static float Distance(PointF p1, PointF p2)
{
float d = (p1.X - p2.X)*(p1.X - p2.X) + (p1.Y - p2.Y)*(p1.Y - p2.Y);
return d;
}
static double Test01(string[] files)
{
TimeRecord tr = new TimeRecord();
tr.Start();
float r = 0.004f;
List<Data> lst = LoadData(files);
for (int i = 0; i < lst.Count; i++)
{
PointF p1 = lst[i].Original;
for (int j = i + 1; j < lst.Count; j++)
{
PointF p2 = lst[j].Original;
if (Distance(p1, p2) <= r * r)
{
lst[i].LstWithin.Add(p2);
lst[j].LstWithin.Add(p1);
}
}
}
tr.Stop();
return tr.DurationMs;
}
利用座標變換進行計算的代碼如下:
static List<Data> LoadData(string[] fileNames,float d)
{
List<Data> result = new List<Data>();
foreach (string file in fileNames)
{
string[] strs = File.ReadAllLines(file);
foreach (string str in strs)
{
string[] temp = str.Split(new char[] { ',' });
Data data = new Data();
data.Id = int.Parse(temp[0]);
float x = float.Parse(temp[1]);
float y = float.Parse(temp[2]);
data.Original = new PointF(x, y);
data.AfterConversion = new PointF(x*d, y*d);
result.Add(data);
}
}
return result;
}
static double Test02(string[] files)
{
TimeRecord tr = new TimeRecord();
tr.Start();
List<Data> lst = LoadData(files, 250);
for (int i = 0; i < lst.Count; i++)
{
PointF p1 = lst[i].AfterConversion;
for (int j = i + 1; j < lst.Count; j++)
{
PointF p2 = lst[j].AfterConversion;
if (p1.X + 1 < p2.X || p1.X - 1 > p2.X || p1.Y + 1 < p2.Y || p1.Y - 1 > p2.Y)
continue;
if (Distance(p1, p2) <= 1)
{
lst[i].LstWithin.Add(p2);
lst[j].LstWithin.Add(p1);
}
}
}
tr.Stop();
return tr.DurationMs;
}
利用座標變換和排序進行計算的代碼如下:
static double Test03(string[] files)
{
TimeRecord tr = new TimeRecord();
tr.Start();
List<Data> lst = LoadData(files, 250);
lst = lst.OrderBy(a => a.AfterConversion.X).ToList();
for (int i = 0; i < lst.Count; i++)
{
PointF p1 = lst[i].AfterConversion;
for (int j = i + 1; j < lst.Count; j++)
{
PointF p2 = lst[j].AfterConversion;
if (p1.X + 1 < p2.X)
break;
if (Distance(p1, p2) <= 1)
{
lst[i].LstWithin.Add(p2);
lst[j].LstWithin.Add(p1);
}
}
}
tr.Stop();
return tr.DurationMs;
}
測試代碼如下:
delegate double Test(string[] files);
static double Run(Test test)
{
string[] files = new string[]
{
"40000weizhi.csv",
"41000weizhi.csv",
"41500weizhi.csv",
"42000weizhi.csv"
};
double[] r = new double[10];
for (int i = 0; i < 10; i++)
{
r[i] = test(files);
}
return r.Average();
}
static void Main(string[] args)
{
double r1 = Run(Test01);
double r2 = Run(Test02);
double r3 = Run(Test03);
Console.WriteLine("方法1耗費時間:{0}", r1);
Console.WriteLine("方法2耗費時間:{0}", r2);
Console.WriteLine("方法3耗費時間:{0}", r3);
}
測試結果如下:
總結
這篇圖文記錄了我嘗試提升算法執行效率的過程,不知道能否達到這位朋友的要求,畢竟比起他們現在的算法,縮短了90%
以上的時間。今天就這樣吧!See You!大家如果有什麼好的方法,歡迎給我留言啊,我們一起把這個問題解決掉。
往期活動
LSGO軟件技術團隊會定期開展提升編程技能的刻意練習活動,希望大家能夠參與進來一起刻意練習,一起學習進步!