经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 程序设计 » NumPy » 查看文章
Numpy计算近邻表时间对比
来源:cnblogs  作者:DECHIN  时间:2024/1/10 8:57:06  对本文有异议

技术背景

所谓的近邻表求解,就是给定N个原子的体系,找出满足cutoff要求的每一对原子。在前面的几篇博客中,我们分别介绍过CUDA近邻表计算JAX-MD关于格点法求解近邻表的实现。虽然我们从理论上可以知道,用格点法求解近邻表,在复杂度上肯定是要优于传统的算法。本文主要从Python代码的实现上来具体测试一下二者的速度差异,这里使用的硬件还是CPU。

算法解析

若一对原子A和B满足下述条件,则称A、B为一对近邻原子:

\[|\textbf{r}_A-\textbf{r}_B|\leq cutoff \]

传统的求解方法,就是把所有原子间距都计算一遍,然后对每个原子的近邻原子进行排序,最终按照给定的cutoff截断值确定相关的近邻原子。在Python中的实现,因为有numpy这样的强力工具,我们在计算原子两两间距时,只需要对一组维度为(N,D)的原子坐标进行扩维,分别变成(1,N,D)和(N,1,D)大小的原子坐标。然后将二者相减,计算过程中会自动广播(Broadcast)成(N,N,D)和(N,N,D)的两个数组进行计算。对得到的结果做一个Norm,就可以得到维度为(N,N)的两两间距矩阵。该算法的计算复杂度为O(N^2)

相对高效的一种求解方案是将原子坐标所在的空间划分成众多的小区域,通常我们设定这些小区域为边长等于cutoff的小正方体。这种设定有一个好处是,我们可以确定每一个正方体的近邻原子,一定在最靠近其周边的26个小正方体区域内。这样一来,我们就不需要去计算全局的两两间距,只需要计算单个小正方体内(假定有M个原子)的两两间距(M,M),以及单个正方体与周边正方体内原子的配对间距(M,26M)。之所以这样分开计算,是为了减少原子跟自身间距的这一项重复计算。那么对于整个空间的原子,就需要计算(N,27M)这么多次的原子间距,是一个复杂度为O(NlogN)的算法。

Numpy代码实现

这里我们基于Python中的numpy框架来实现这两个不同的计算近邻表的算法。其实当我们使用numpy来进行计算的时候,应当尽可能的避免循环体的使用。但是这里仅演示两种算法的差异性,因此在实现格点法的时候偷了点懒,用了两个for循环,感兴趣的童鞋可以自行优化。

  1. import time
  2. from itertools import chain
  3. from operator import itemgetter
  4. import numpy as np
  5. # 在格点法中,为了避免重复计算,我们可以仅计算一半的近邻格点中的原子间距
  6. NEIGHBOUR_GRID = np.array([
  7. [-1, 1, 0],
  8. [-1, -1, 1],
  9. [-1, 0, 1],
  10. [-1, 1, 1],
  11. [ 0, -1, 1],
  12. [ 0, 0, 1],
  13. [ 0, 1, 0],
  14. [ 0, 1, 1],
  15. [ 1, -1, 1],
  16. [ 1, 0, 0],
  17. [ 1, 0, 1],
  18. [ 1, 1, 0],
  19. [ 1, 1, 1]], np.int32)
  20. # 原始的两两间距计算方法,需要排序
  21. def get_neighbours_by_dist(crd, cutoff):
  22. large_dis = np.tril(np.ones((crd.shape[0], crd.shape[0])) * 999)
  23. # (N, N)
  24. dis = np.linalg.norm(crd[None] - crd[:, None], axis=-1) + large_dis
  25. # (N, M)
  26. neigh = np.argsort(dis, axis=-1)
  27. # (N, M)
  28. cut = np.take_along_axis(dis, neigh, axis=1)
  29. # (2, P)
  30. pairs = np.where(cut <= cutoff)
  31. # (P, )
  32. pairs_id0 = pairs[0]
  33. pairs_id1 = neigh[pairs]
  34. # (P, 2)
  35. sort_args = np.argsort(pairs_id0)
  36. return np.hstack((pairs_id0[..., None], pairs_id1[..., None]))[sort_args]
  37. # 格点法计算近邻表,先分格点,然后分两个模块计算单格点内原子间距,和中心格点-周边格点内的原子间距
  38. def get_neighbours_by_grid(crd, cutoff):
  39. # (D, )
  40. min_xyz = np.min(crd, axis=0)
  41. max_xyz = np.max(crd, axis=0)
  42. space = max_xyz - min_xyz
  43. grids = np.ceil(space / cutoff).astype(np.int32)
  44. num_grids = np.product(grids)
  45. buffer = (grids * cutoff - space) / 2
  46. start_crd = min_xyz - buffer
  47. # (N, D)
  48. grid_id = ((crd - start_crd) // cutoff).astype(np.int32)
  49. grid_coe = np.array([1, grids[0], grids[1]], np.int32)
  50. # (N, )
  51. grid_id_1d = np.sum(grid_id * grid_coe, axis=-1).astype(np.int32)
  52. # (N, 2)
  53. grid_id_dict = np.ndenumerate(grid_id_1d)
  54. # (G, *)
  55. grid_dict = dict.fromkeys(range(num_grids), ())
  56. for index, value in grid_id_dict:
  57. grid_dict[value] += index
  58. neighbour_grid = (NEIGHBOUR_GRID * grid_coe).sum(axis=-1).astype(np.int32)
  59. neighbour_pairs = []
  60. for i in range(num_grids):
  61. if grid_dict[i]:
  62. keeps = np.where((neighbour_grid + i < num_grids) & (neighbour_grid + i >= 0))[0]
  63. neighbour_grid_keep = neighbour_grid[keeps] + i
  64. grid_atoms = np.array(list(grid_dict[i]), np.int32)
  65. try:
  66. grid_neighbours = np.array(list(chain(*itemgetter(*neighbour_grid_keep)(grid_dict))), np.int32)
  67. except TypeError:
  68. if neighbour_grid_keep.size == 0:
  69. grid_neighbours = np.array([], np.int32)
  70. else:
  71. grid_neighbours = np.array(list(itemgetter(*neighbour_grid_keep)(grid_dict)), np.int32)
  72. grid_crds = crd[grid_atoms]
  73. grid_neighbour_crds = crd[grid_neighbours]
  74. large_dis = np.tril(np.ones((grid_crds.shape[0], grid_crds.shape[0])) * 999)
  75. # 单格点内部原子间距
  76. grid_dis = np.linalg.norm(grid_crds[None] - grid_crds[:, None], axis=-1) + large_dis
  77. grid_pairs = np.argsort(grid_dis, axis=-1)
  78. grid_cut = np.take_along_axis(grid_dis, grid_pairs, axis=-1)
  79. pairs = np.where(grid_cut <= cutoff)
  80. pairs_id0 = grid_atoms[pairs[0]]
  81. pairs_id1 = grid_atoms[grid_pairs[pairs]]
  82. neighbour_pairs.extend(list(np.hstack((pairs_id0[..., None], pairs_id1[..., None]))))
  83. # 中心格点-周边格点内原子间距
  84. grid_dis = np.linalg.norm(grid_crds[:, None] - grid_neighbour_crds[None], axis=-1)
  85. grid_pairs = np.argsort(grid_dis, axis=-1)
  86. grid_cut = np.take_along_axis(grid_dis, grid_pairs, axis=-1)
  87. pairs = np.where(grid_cut <= cutoff)
  88. pairs_id0 = grid_atoms[pairs[0]]
  89. pairs_id1 = grid_neighbours[grid_pairs[pairs]]
  90. neighbour_pairs.extend(list(np.hstack((pairs_id0[..., None], pairs_id1[..., None]))))
  91. neighbour_pairs = np.sort(np.array(neighbour_pairs), axis=-1)
  92. sort_args = np.argsort(neighbour_pairs[:, 0])
  93. return neighbour_pairs[sort_args]
  94. # 时间测算函数
  95. def benchmark(N, cutoff=0.3, D=3):
  96. crd = np.random.random((N, D)).astype(np.float32) * np.array([3., 4., 5.], np.float32)
  97. # Solution 1
  98. time0 = time.time()
  99. neighbours_1 = get_neighbours_by_dist(crd, cutoff)
  100. time1 = time.time()
  101. record_1 = time1 - time0
  102. # Solution 2
  103. time0 = time.time()
  104. neighbours_2 = get_neighbours_by_grid(crd, cutoff)
  105. time1 = time.time()
  106. record_2 = time1 - time0
  107. for pair in neighbours_1:
  108. if (np.isin(neighbours_2, pair).sum(axis=-1) < 2).all():
  109. print (pair)
  110. assert neighbours_1.shape == neighbours_2.shape
  111. return record_1, record_2
  112. # 绘图主函数
  113. if __name__ == '__main__':
  114. import matplotlib.pyplot as plt
  115. sizes = range(1000, 10000, 1000)
  116. time_dis = []
  117. time_grid = []
  118. for size in sizes:
  119. print (size)
  120. times = benchmark(size)
  121. time_dis.append(times[0])
  122. time_grid.append(times[1])
  123. plt.figure()
  124. plt.title('Neighbour List Calculation Time')
  125. plt.plot(sizes, time_dis, color='black', label='Full Connect')
  126. plt.plot(sizes, time_grid, color='blue', label='Cell List')
  127. plt.xlabel('Size')
  128. plt.ylabel('Time/s')
  129. plt.legend()
  130. plt.grid()
  131. plt.show()

上述代码的运行结果如下图所示:

其实因为格点法中使用了for循环的问题,函数效率并不高。因此在体系非常小的场景下(比如只有几十个原子的体系),本文用到的格点法代码效率并不如计算所有的原子两两间距。但是毕竟格点法的复杂度较低,因此在运行过程中随着体系的增长,格点法的优势也越来越大。

近邻表计算与分子动力学模拟

在分子动力学模拟中计算长程相互作用时,会经常使用到近邻表。如果要在GPU上实现格点近邻算法,有可能会遇到这样的一些问题:

  1. GPU更加擅长处理静态Shape的张量,因此往往会使用一个最大近邻数,对每一个原子的近邻原子标号进行限制,一般不允许满足cutoff的近邻原子数超过最大近邻数,否则这个cutoff就失去意义了。而如果单个原子的近邻原子数量低于最大近邻数,这时候就会用一个没有意义的数对剩下分配好的张量空间进行填充(Padding),这样一来会带来很多不必要的计算。
  2. 在运行分子动力学模拟的过程中,体系原子的坐标在不断的变化,近邻表也会随之变化,而此时的最大近邻数有可能无法存储完整的cutoff内的原子。

总结概要

本文介绍了在Python的numpy框架下计算近邻表的两种不同算法的原理以及复杂度,另有分别对应的两种代码实现。在实际使用中,我们更偏向于第二种算法的使用。因为对于第一种算法来说,哪怕是一个10000个原子的小体系,如果要计算两两间距,也会变成10000*10000这么大的一个张量的运算。可想而知,这样计算的效率肯定是比较低下的。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/cell-list.html

作者ID:DechinPhy

更多原著文章:https://www.cnblogs.com/dechinphy/

请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

原文链接:https://www.cnblogs.com/dechinphy/p/17954648/cell-list

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号