Skip to content

Commit

Permalink
uint256: optimize Mul, squared (#152)
Browse files Browse the repository at this point in the history
Using four local `uint64` variables enables register allocation instead of memory allocation, which improves the performance of `Mul`, `squared` and `Exp` by about 40%.
  • Loading branch information
AaronChen0 authored Mar 25, 2024
1 parent 97405b6 commit c9fc0ce
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions uint256.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,26 +372,26 @@ func umul(x, y *Int) [8]uint64 {
// Mul sets z to the product x*y and returns z.
//
// Only four result limbs are produced; the top limb is built with plain
// uint64 multiplication and addition, which wrap, so nothing is carried out
// past res3.
//
// NOTE(review): this listing is a rendered diff hunk with the +/- markers
// stripped. Each removed line is immediately followed by its replacement
// (the `res[i]` form next to the `resN` form, and two declarations of
// `carry`), so it does not compile as shown. Per the commit description, the
// change swaps the `res Int` temporary for four local uint64 variables.
//
// NOTE(review): umulHop and umulStep are defined outside this hunk;
// presumably each returns a (hi, lo) pair for a 64x64 multiply-accumulate.
// Confirm against their definitions.
func (z *Int) Mul(x, y *Int) *Int {
var (
res Int
carry uint64
res1, res2, res3 uint64
carry uint64
res0, res1, res2, res3 uint64
)

// Partial products against y[0]. bits.Mul64 returns the 128-bit product as
// (hi, lo); hi is held in carry and passed into the next limb's step.
carry, res[0] = bits.Mul64(x[0], y[0])
carry, res0 = bits.Mul64(x[0], y[0])
carry, res1 = umulHop(carry, x[1], y[0])
carry, res2 = umulHop(carry, x[2], y[0])
// Top limb: only the wrapped low 64 bits of x[3]*y[0] plus the carry are kept.
res3 = x[3]*y[0] + carry

// Partial products against y[1], accumulated onto res1, res2 and res3.
carry, res[1] = umulHop(res1, x[0], y[1])
carry, res1 = umulHop(res1, x[0], y[1])
carry, res2 = umulStep(res2, x[1], y[1], carry)
res3 = res3 + x[2]*y[1] + carry

// Partial products against y[2], accumulated onto res2 and res3.
carry, res[2] = umulHop(res2, x[0], y[2])
carry, res2 = umulHop(res2, x[0], y[2])
res3 = res3 + x[1]*y[2] + carry

// y[3] only contributes to the top limb, via x[0].
res[3] = res3 + x[0]*y[3]
res3 = res3 + x[0]*y[3]

// z is written only after every read of x and y above, so the locals hold
// the complete result before any limb of z changes.
return z.Set(&res)
z[0], z[1], z[2], z[3] = res0, res1, res2, res3
return z
}

// MulOverflow sets z to the product x*y, and returns z and whether overflow occurred
Expand All @@ -403,23 +403,22 @@ func (z *Int) MulOverflow(x, y *Int) (*Int, bool) {

// squared overwrites z in place with four limbs computed from products of
// z's own limbs (presumably z*z truncated to four limbs; confirm against
// callers). It returns nothing.
//
// NOTE(review): this listing is a rendered diff hunk with the +/- markers
// stripped. Each removed line is immediately followed by its replacement
// (the `res[i]` form next to the `resN` form), so it does not compile as
// shown. Per the commit description, the change swaps the `res Int`
// temporary for four local uint64 variables.
//
// NOTE(review): umulHop and umulStep are defined outside this hunk;
// presumably each returns a (hi, lo) pair for a 64x64 multiply-accumulate.
// Confirm against their definitions.
func (z *Int) squared() {
var (
res Int
carry0, carry1, carry2 uint64
res1, res2 uint64
res0, res1, res2, res3 uint64
)

// First pass: z[0] times z[0], z[1] and z[2]. bits.Mul64 returns (hi, lo);
// hi is held in carry0 and passed into the next limb's step.
carry0, res[0] = bits.Mul64(z[0], z[0])
carry0, res0 = bits.Mul64(z[0], z[0])
carry0, res1 = umulHop(carry0, z[0], z[1])
carry0, res2 = umulHop(carry0, z[0], z[2])

// Second pass: z[0]*z[1] is added onto res1 again, then z[1]*z[1] onto res2.
carry1, res[1] = umulHop(res1, z[0], z[1])
carry1, res1 = umulHop(res1, z[0], z[1])
carry1, res2 = umulStep(res2, z[1], z[1], carry1)

// Third pass: z[0]*z[2] is added onto res2 again.
carry2, res[2] = umulHop(res2, z[0], z[2])
carry2, res2 = umulHop(res2, z[0], z[2])

// Top limb: the cross terms z[0]*z[3] and z[1]*z[2] are doubled, then the
// three outstanding carries are added. This is plain uint64 arithmetic,
// which wraps, so nothing is carried out past res3.
res[3] = 2*(z[0]*z[3]+z[1]*z[2]) + carry0 + carry1 + carry2
res3 = 2*(z[0]*z[3]+z[1]*z[2]) + carry0 + carry1 + carry2

// z is written only after every read of its limbs above.
z.Set(&res)
z[0], z[1], z[2], z[3] = res0, res1, res2, res3
}

// isBitSet returns true if the n-th bit is set, where n = 0 is the LSB.
Expand Down

0 comments on commit c9fc0ce

Please sign in to comment.