threefry_engine.hpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2015 Muhammad Junaid Muzammil <mjunaidmuzammil@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_RANDOM_THREEFRY_HPP
  11. #define BOOST_COMPUTE_RANDOM_THREEFRY_HPP
  12. #include <algorithm>
  13. #include <boost/compute/types.hpp>
  14. #include <boost/compute/buffer.hpp>
  15. #include <boost/compute/kernel.hpp>
  16. #include <boost/compute/context.hpp>
  17. #include <boost/compute/program.hpp>
  18. #include <boost/compute/command_queue.hpp>
  19. #include <boost/compute/algorithm/transform.hpp>
  20. #include <boost/compute/detail/iterator_range_size.hpp>
  21. #include <boost/compute/utility/program_cache.hpp>
  22. #include <boost/compute/container/vector.hpp>
  23. #include <boost/compute/iterator/discard_iterator.hpp>
  24. namespace boost {
  25. namespace compute {
  26. /// \class threefry_engine
  27. /// \brief Threefry pseudorandom number generator.
  28. template<class T = uint_>
  29. class threefry_engine
  30. {
  31. public:
  32. static const size_t threads = 1024;
  33. typedef T result_type;
  34. /// Creates a new threefry_engine and seeds it with \p value.
  35. explicit threefry_engine(command_queue &queue)
  36. : m_context(queue.get_context())
  37. {
  38. // setup program
  39. load_program();
  40. }
  41. /// Creates a new threefry_engine object as a copy of \p other.
  42. threefry_engine(const threefry_engine<T> &other)
  43. : m_context(other.m_context),
  44. m_program(other.m_program)
  45. {
  46. }
  47. /// Copies \p other to \c *this.
  48. threefry_engine<T>& operator=(const threefry_engine<T> &other)
  49. {
  50. if(this != &other){
  51. m_context = other.m_context;
  52. m_program = other.m_program;
  53. }
  54. return *this;
  55. }
  56. /// Destroys the threefry_engine object.
  57. ~threefry_engine()
  58. {
  59. }
  60. private:
  61. /// \internal_
  62. void load_program()
  63. {
  64. boost::shared_ptr<program_cache> cache =
  65. program_cache::get_global_cache(m_context);
  66. std::string cache_key =
  67. std::string("threefry_engine_32x2");
  68. // Copyright 2010-2012, D. E. Shaw Research.
  69. // All rights reserved.
  70. // Redistribution and use in source and binary forms, with or without
  71. // modification, are permitted provided that the following conditions are
  72. // met:
  73. // * Redistributions of source code must retain the above copyright
  74. // notice, this list of conditions, and the following disclaimer.
  75. // * Redistributions in binary form must reproduce the above copyright
  76. // notice, this list of conditions, and the following disclaimer in the
  77. // documentation and/or other materials provided with the distribution.
  78. // * Neither the name of D. E. Shaw Research nor the names of its
  79. // contributors may be used to endorse or promote products derived from
  80. // this software without specific prior written permission.
  81. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  82. // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  83. // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  84. // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  85. // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  86. // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  87. // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  88. // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  89. // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  90. // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  91. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  92. const char source[] =
  93. "#define THREEFRY2x32_DEFAULT_ROUNDS 20\n"
  94. "#define SKEIN_KS_PARITY_32 0x1BD11BDA\n"
  95. "enum r123_enum_threefry32x2 {\n"
  96. " R_32x2_0_0=13,\n"
  97. " R_32x2_1_0=15,\n"
  98. " R_32x2_2_0=26,\n"
  99. " R_32x2_3_0= 6,\n"
  100. " R_32x2_4_0=17,\n"
  101. " R_32x2_5_0=29,\n"
  102. " R_32x2_6_0=16,\n"
  103. " R_32x2_7_0=24\n"
  104. "};\n"
  105. "static uint RotL_32(uint x, uint N)\n"
  106. "{\n"
  107. " return (x << (N & 31)) | (x >> ((32-N) & 31));\n"
  108. "}\n"
  109. "struct r123array2x32 {\n"
  110. " uint v[2];\n"
  111. "};\n"
  112. "typedef struct r123array2x32 threefry2x32_ctr_t;\n"
  113. "typedef struct r123array2x32 threefry2x32_key_t;\n"
  114. "threefry2x32_ctr_t threefry2x32_R(unsigned int Nrounds, threefry2x32_ctr_t in, threefry2x32_key_t k)\n"
  115. "{\n"
  116. " threefry2x32_ctr_t X;\n"
  117. " uint ks[3];\n"
  118. " uint i; \n"
  119. " ks[2] = SKEIN_KS_PARITY_32;\n"
  120. " for (i=0;i < 2; i++) {\n"
  121. " ks[i] = k.v[i];\n"
  122. " X.v[i] = in.v[i];\n"
  123. " ks[2] ^= k.v[i];\n"
  124. " }\n"
  125. " X.v[0] += ks[0]; X.v[1] += ks[1];\n"
  126. " if(Nrounds>0){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_0_0); X.v[1] ^= X.v[0]; }\n"
  127. " if(Nrounds>1){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_1_0); X.v[1] ^= X.v[0]; }\n"
  128. " if(Nrounds>2){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_2_0); X.v[1] ^= X.v[0]; }\n"
  129. " if(Nrounds>3){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_3_0); X.v[1] ^= X.v[0]; }\n"
  130. " if(Nrounds>3){\n"
  131. " X.v[0] += ks[1]; X.v[1] += ks[2];\n"
  132. " X.v[1] += 1;\n"
  133. " }\n"
  134. " if(Nrounds>4){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_4_0); X.v[1] ^= X.v[0]; }\n"
  135. " if(Nrounds>5){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_5_0); X.v[1] ^= X.v[0]; }\n"
  136. " if(Nrounds>6){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_6_0); X.v[1] ^= X.v[0]; }\n"
  137. " if(Nrounds>7){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_7_0); X.v[1] ^= X.v[0]; }\n"
  138. " if(Nrounds>7){\n"
  139. " X.v[0] += ks[2]; X.v[1] += ks[0];\n"
  140. " X.v[1] += 2;\n"
  141. " }\n"
  142. " if(Nrounds>8){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_0_0); X.v[1] ^= X.v[0]; }\n"
  143. " if(Nrounds>9){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_1_0); X.v[1] ^= X.v[0]; }\n"
  144. " if(Nrounds>10){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_2_0); X.v[1] ^= X.v[0]; }\n"
  145. " if(Nrounds>11){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_3_0); X.v[1] ^= X.v[0]; }\n"
  146. " if(Nrounds>11){\n"
  147. " X.v[0] += ks[0]; X.v[1] += ks[1];\n"
  148. " X.v[1] += 3;\n"
  149. " }\n"
  150. " if(Nrounds>12){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_4_0); X.v[1] ^= X.v[0]; }\n"
  151. " if(Nrounds>13){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_5_0); X.v[1] ^= X.v[0]; }\n"
  152. " if(Nrounds>14){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_6_0); X.v[1] ^= X.v[0]; }\n"
  153. " if(Nrounds>15){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_7_0); X.v[1] ^= X.v[0]; }\n"
  154. " if(Nrounds>15){\n"
  155. " X.v[0] += ks[1]; X.v[1] += ks[2];\n"
  156. " X.v[1] += 4;\n"
  157. " }\n"
  158. " if(Nrounds>16){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_0_0); X.v[1] ^= X.v[0]; }\n"
  159. " if(Nrounds>17){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_1_0); X.v[1] ^= X.v[0]; }\n"
  160. " if(Nrounds>18){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_2_0); X.v[1] ^= X.v[0]; }\n"
  161. " if(Nrounds>19){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_3_0); X.v[1] ^= X.v[0]; }\n"
  162. " if(Nrounds>19){\n"
  163. " X.v[0] += ks[2]; X.v[1] += ks[0];\n"
  164. " X.v[1] += 5;\n"
  165. " }\n"
  166. " if(Nrounds>20){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_4_0); X.v[1] ^= X.v[0]; }\n"
  167. " if(Nrounds>21){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_5_0); X.v[1] ^= X.v[0]; }\n"
  168. " if(Nrounds>22){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_6_0); X.v[1] ^= X.v[0]; }\n"
  169. " if(Nrounds>23){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_7_0); X.v[1] ^= X.v[0]; }\n"
  170. " if(Nrounds>23){\n"
  171. " X.v[0] += ks[0]; X.v[1] += ks[1];\n"
  172. " X.v[1] += 6;\n"
  173. " }\n"
  174. " if(Nrounds>24){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_0_0); X.v[1] ^= X.v[0]; }\n"
  175. " if(Nrounds>25){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_1_0); X.v[1] ^= X.v[0]; }\n"
  176. " if(Nrounds>26){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_2_0); X.v[1] ^= X.v[0]; }\n"
  177. " if(Nrounds>27){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_3_0); X.v[1] ^= X.v[0]; }\n"
  178. " if(Nrounds>27){\n"
  179. " X.v[0] += ks[1]; X.v[1] += ks[2];\n"
  180. " X.v[1] += 7;\n"
  181. " }\n"
  182. " if(Nrounds>28){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_4_0); X.v[1] ^= X.v[0]; }\n"
  183. " if(Nrounds>29){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_5_0); X.v[1] ^= X.v[0]; }\n"
  184. " if(Nrounds>30){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_6_0); X.v[1] ^= X.v[0]; }\n"
  185. " if(Nrounds>31){ X.v[0] += X.v[1]; X.v[1] = RotL_32(X.v[1],R_32x2_7_0); X.v[1] ^= X.v[0]; }\n"
  186. " if(Nrounds>31){\n"
  187. " X.v[0] += ks[2]; X.v[1] += ks[0];\n"
  188. " X.v[1] += 8;\n"
  189. " }\n"
  190. " return X;\n"
  191. "}\n"
  192. "__kernel void generate_rng(__global uint *ctr, __global uint *key, const uint offset) {\n"
  193. " threefry2x32_ctr_t in;\n"
  194. " threefry2x32_key_t k;\n"
  195. " const uint i = get_global_id(0);\n"
  196. " in.v[0] = ctr[2 * (offset + i)];\n"
  197. " in.v[1] = ctr[2 * (offset + i) + 1];\n"
  198. " k.v[0] = key[2 * (offset + i)];\n"
  199. " k.v[1] = key[2 * (offset + i) + 1];\n"
  200. " in = threefry2x32_R(20, in, k);\n"
  201. " ctr[2 * (offset + i)] = in.v[0];\n"
  202. " ctr[2 * (offset + i) + 1] = in.v[1];\n"
  203. "}\n";
  204. m_program = cache->get_or_build(cache_key, std::string(), source, m_context);
  205. }
  206. public:
  207. /// Generates Threefry random numbers using both the counter and key values, and then stores
  208. /// them to the range [\p first_ctr, \p last_ctr).
  209. template<class OutputIterator>
  210. void generate(OutputIterator first_ctr, OutputIterator last_ctr, OutputIterator first_key, OutputIterator last_key, command_queue &queue) {
  211. const size_t size_ctr = detail::iterator_range_size(first_ctr, last_ctr);
  212. const size_t size_key = detail::iterator_range_size(first_key, last_key);
  213. if(!size_ctr || !size_key || (size_ctr != size_key)) {
  214. return;
  215. }
  216. kernel rng_kernel = m_program.create_kernel("generate_rng");
  217. rng_kernel.set_arg(0, first_ctr.get_buffer());
  218. rng_kernel.set_arg(1, first_key.get_buffer());
  219. size_t offset = 0;
  220. for(;;){
  221. size_t count = 0;
  222. size_t size = size_ctr/2;
  223. if(size > threads){
  224. count = (std::min)(static_cast<size_t>(threads), size - offset);
  225. }
  226. else {
  227. count = size;
  228. }
  229. rng_kernel.set_arg(2, static_cast<const uint_>(offset));
  230. queue.enqueue_1d_range_kernel(rng_kernel, 0, count, 0);
  231. offset += count;
  232. if(offset >= size){
  233. break;
  234. }
  235. }
  236. }
  237. template<class OutputIterator>
  238. void generate(OutputIterator first_ctr, OutputIterator last_ctr, command_queue &queue) {
  239. const size_t size_ctr = detail::iterator_range_size(first_ctr, last_ctr);
  240. if(!size_ctr) {
  241. return;
  242. }
  243. boost::compute::vector<uint_> vector_key(size_ctr, m_context);
  244. vector_key.assign(size_ctr, 0, queue);
  245. kernel rng_kernel = m_program.create_kernel("generate_rng");
  246. rng_kernel.set_arg(0, first_ctr.get_buffer());
  247. rng_kernel.set_arg(1, vector_key);
  248. size_t offset = 0;
  249. for(;;){
  250. size_t count = 0;
  251. size_t size = size_ctr/2;
  252. if(size > threads){
  253. count = (std::min)(static_cast<size_t>(threads), size - offset);
  254. }
  255. else {
  256. count = size;
  257. }
  258. rng_kernel.set_arg(2, static_cast<const uint_>(offset));
  259. queue.enqueue_1d_range_kernel(rng_kernel, 0, count, 0);
  260. offset += count;
  261. if(offset >= size){
  262. break;
  263. }
  264. }
  265. }
  266. private:
  267. context m_context;
  268. program m_program;
  269. };
  270. } // end compute namespace
  271. } // end boost namespace
  272. #endif // BOOST_COMPUTE_RANDOM_THREEFRY_HPP