bn_mp_root_u32.c 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. #include "tommath_private.h"
  2. #ifdef BN_MP_ROOT_U32_C
  3. /* LibTomMath, multiple-precision integer library -- Tom St Denis */
  4. /* SPDX-License-Identifier: Unlicense */
  5. /* find the n'th root of an integer
  6. *
  7. * Result found such that (c)**b <= a and (c+1)**b > a
  8. *
  9. * This algorithm uses Newton's approximation
  10. * x[i+1] = x[i] - f(x[i])/f'(x[i])
  11. * which will find the root in log(N) time where
  12. * each step involves a fair bit.
  13. */
  14. mp_err mp_root_u32(const mp_int *a, uint32_t b, mp_int *c)
  15. {
  16. mp_int t1, t2, t3, a_;
  17. mp_ord cmp;
  18. int ilog2;
  19. mp_err err;
  20. /* input must be positive if b is even */
  21. if (((b & 1u) == 0u) && (a->sign == MP_NEG)) {
  22. return MP_VAL;
  23. }
  24. if ((err = mp_init_multi(&t1, &t2, &t3, NULL)) != MP_OKAY) {
  25. return err;
  26. }
  27. /* if a is negative fudge the sign but keep track */
  28. a_ = *a;
  29. a_.sign = MP_ZPOS;
  30. /* Compute seed: 2^(log_2(n)/b + 2)*/
  31. ilog2 = mp_count_bits(a);
  32. /*
  33. If "b" is larger than INT_MAX it is also larger than
  34. log_2(n) because the bit-length of the "n" is measured
  35. with an int and hence the root is always < 2 (two).
  36. */
  37. if (b > (uint32_t)(INT_MAX/2)) {
  38. mp_set(c, 1uL);
  39. c->sign = a->sign;
  40. err = MP_OKAY;
  41. goto LBL_ERR;
  42. }
  43. /* "b" is smaller than INT_MAX, we can cast safely */
  44. if (ilog2 < (int)b) {
  45. mp_set(c, 1uL);
  46. c->sign = a->sign;
  47. err = MP_OKAY;
  48. goto LBL_ERR;
  49. }
  50. ilog2 = ilog2 / ((int)b);
  51. if (ilog2 == 0) {
  52. mp_set(c, 1uL);
  53. c->sign = a->sign;
  54. err = MP_OKAY;
  55. goto LBL_ERR;
  56. }
  57. /* Start value must be larger than root */
  58. ilog2 += 2;
  59. if ((err = mp_2expt(&t2,ilog2)) != MP_OKAY) goto LBL_ERR;
  60. do {
  61. /* t1 = t2 */
  62. if ((err = mp_copy(&t2, &t1)) != MP_OKAY) goto LBL_ERR;
  63. /* t2 = t1 - ((t1**b - a) / (b * t1**(b-1))) */
  64. /* t3 = t1**(b-1) */
  65. if ((err = mp_expt_u32(&t1, b - 1u, &t3)) != MP_OKAY) goto LBL_ERR;
  66. /* numerator */
  67. /* t2 = t1**b */
  68. if ((err = mp_mul(&t3, &t1, &t2)) != MP_OKAY) goto LBL_ERR;
  69. /* t2 = t1**b - a */
  70. if ((err = mp_sub(&t2, &a_, &t2)) != MP_OKAY) goto LBL_ERR;
  71. /* denominator */
  72. /* t3 = t1**(b-1) * b */
  73. if ((err = mp_mul_d(&t3, b, &t3)) != MP_OKAY) goto LBL_ERR;
  74. /* t3 = (t1**b - a)/(b * t1**(b-1)) */
  75. if ((err = mp_div(&t2, &t3, &t3, NULL)) != MP_OKAY) goto LBL_ERR;
  76. if ((err = mp_sub(&t1, &t3, &t2)) != MP_OKAY) goto LBL_ERR;
  77. /*
  78. Number of rounds is at most log_2(root). If it is more it
  79. got stuck, so break out of the loop and do the rest manually.
  80. */
  81. if (ilog2-- == 0) {
  82. break;
  83. }
  84. } while (mp_cmp(&t1, &t2) != MP_EQ);
  85. /* result can be off by a few so check */
  86. /* Loop beneath can overshoot by one if found root is smaller than actual root */
  87. for (;;) {
  88. if ((err = mp_expt_u32(&t1, b, &t2)) != MP_OKAY) goto LBL_ERR;
  89. cmp = mp_cmp(&t2, &a_);
  90. if (cmp == MP_EQ) {
  91. err = MP_OKAY;
  92. goto LBL_ERR;
  93. }
  94. if (cmp == MP_LT) {
  95. if ((err = mp_add_d(&t1, 1uL, &t1)) != MP_OKAY) goto LBL_ERR;
  96. } else {
  97. break;
  98. }
  99. }
  100. /* correct overshoot from above or from recurrence */
  101. for (;;) {
  102. if ((err = mp_expt_u32(&t1, b, &t2)) != MP_OKAY) goto LBL_ERR;
  103. if (mp_cmp(&t2, &a_) == MP_GT) {
  104. if ((err = mp_sub_d(&t1, 1uL, &t1)) != MP_OKAY) goto LBL_ERR;
  105. } else {
  106. break;
  107. }
  108. }
  109. /* set the result */
  110. mp_exch(&t1, c);
  111. /* set the sign of the result */
  112. c->sign = a->sign;
  113. err = MP_OKAY;
  114. LBL_ERR:
  115. mp_clear_multi(&t1, &t2, &t3, NULL);
  116. return err;
  117. }
  118. #endif