Skip to content

Commit 4e3cc87

Browse files
committed
Add Field::sum_of_products method
Closes #79.
1 parent a35b5eb commit 4e3cc87

File tree

3 files changed

+104
-0
lines changed

3 files changed

+104
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ and this library adheres to Rust's notion of
66
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88
## [Unreleased]
9+
### Added
10+
- `ff::Field::{sum_of_products, sum_of_products_iter}`
911

1012
## [0.12.1] - 2022-10-28
1113
### Fixed

src/lib.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,40 @@ pub trait Field:
124124

125125
res
126126
}
127+
128+
/// Returns `a.into_iter().zip(b).fold(Self::zero(), |acc, (a_i, b_i)| acc + a_i * b_i)`.
129+
///
130+
/// This computes the "dot product" or "inner product" `a ⋅ b`.
131+
///
132+
/// The provided implementation of this trait method uses the direct calculation given
133+
/// above. Implementations of `Field` should override this to use more efficient
134+
/// methods that take advantage of their internal representation, such as interleaving
135+
/// or sharing modular reductions.
136+
fn sum_of_products<const T: usize>(a: [Self; T], b: [Self; T]) -> Self {
137+
a.into_iter()
138+
.zip(b)
139+
.fold(Self::zero(), |acc, (a_i, b_i)| acc + a_i * b_i)
140+
}
141+
142+
/// Returns `pairs.into_iter().fold(Self::zero(), |acc, (a_i, b_i)| acc + (*a_i * b_i))`.
143+
///
144+
/// This computes the "dot product" or "inner product" `a ⋅ b` of two equal-length
145+
/// sequences of elements `a` and `b`, such that `pairs = a.iter().zip(b.iter())`.
146+
///
147+
/// This method is generally slower than [`Self::sum_of_products`] but allows for the
148+
/// number of pairs to be determined at runtime.
149+
///
150+
/// The provided implementation of this trait method uses the direct calculation given
151+
/// above. Implementations of `Field` should override this to use more efficient
152+
/// methods that take advantage of their internal representation, such as interleaving
153+
/// or sharing modular reductions.
154+
fn sum_of_products_iter<'a, I: IntoIterator<Item = (&'a Self, &'a Self)> + Clone>(
155+
pairs: I,
156+
) -> Self {
157+
pairs
158+
.into_iter()
159+
.fold(Self::zero(), |acc, (a_i, b_i)| acc + (*a_i * b_i))
160+
}
127161
}
128162

129163
/// This represents an element of a prime field.

tests/derive.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,74 @@ mod full_limbs {
3737
}
3838
}
3939

40+
#[test]
41+
fn sum_of_products() {
42+
use ff::{Field, PrimeField};
43+
44+
let one = Bls381K12Scalar::one();
45+
46+
// [1, 2, 3, 4]
47+
let values = {
48+
let mut iter = (0..4).scan(one, |acc, _| {
49+
let ret = *acc;
50+
*acc += &one;
51+
Some(ret)
52+
});
53+
[
54+
iter.next().unwrap(),
55+
iter.next().unwrap(),
56+
iter.next().unwrap(),
57+
iter.next().unwrap(),
58+
]
59+
};
60+
61+
// We'll pair each value with itself.
62+
let expected = Bls381K12Scalar::from_str_vartime("30").unwrap();
63+
64+
assert_eq!(Bls381K12Scalar::sum_of_products(values, values), expected,);
65+
}
66+
67+
#[test]
68+
fn sum_of_products_iter() {
69+
use ff::{Field, PrimeField};
70+
71+
let one = Bls381K12Scalar::one();
72+
73+
// [1, 2, 3, 4]
74+
let values: Vec<_> = (0..4)
75+
.scan(one, |acc, _| {
76+
let ret = *acc;
77+
*acc += &one;
78+
Some(ret)
79+
})
80+
.collect();
81+
82+
// We'll pair each value with itself.
83+
let expected = Bls381K12Scalar::from_str_vartime("30").unwrap();
84+
85+
// Check that we can produce the necessary input from two iterators.
86+
assert_eq!(
87+
// Directly produces (&v, &v)
88+
Bls381K12Scalar::sum_of_products_iter(values.iter().zip(values.iter())),
89+
expected,
90+
);
91+
92+
// Check that we can produce the necessary input from an iterator of values.
93+
assert_eq!(
94+
// Maps &v to (&v, &v)
95+
Bls381K12Scalar::sum_of_products_iter(values.iter().map(|v| (v, v))),
96+
expected,
97+
);
98+
99+
// Check that we can produce the necessary input from an iterator of tuples.
100+
let tuples: Vec<_> = values.into_iter().map(|v| (v, v)).collect();
101+
assert_eq!(
102+
// Maps &(a, b) to (&a, &b)
103+
Bls381K12Scalar::sum_of_products_iter(tuples.iter().map(|(a, b)| (a, b))),
104+
expected,
105+
);
106+
}
107+
40108
#[test]
41109
fn batch_inversion() {
42110
use ff::{BatchInverter, Field};

0 commit comments

Comments
 (0)