indicatif/
rayon.rs

1use rayon::iter::plumbing::{Consumer, Folder, Producer, ProducerCallback, UnindexedConsumer};
2use rayon::iter::{IndexedParallelIterator, ParallelIterator};
3
4use crate::{ProgressBar, ProgressBarIter};
5
6/// Wraps a Rayon parallel iterator.
7///
8/// See [`ProgressIterator`](trait.ProgressIterator.html) for method
9/// documentation.
10#[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
11pub trait ParallelProgressIterator
12where
13    Self: Sized + ParallelIterator,
14{
15    /// Wrap an iterator with a custom progress bar.
16    fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self>;
17
18    /// Wrap an iterator with an explicit element count.
19    fn progress_count(self, len: u64) -> ProgressBarIter<Self> {
20        self.progress_with(ProgressBar::new(len))
21    }
22
23    fn progress(self) -> ProgressBarIter<Self>
24    where
25        Self: IndexedParallelIterator,
26    {
27        let len = u64::try_from(self.len()).unwrap();
28        self.progress_count(len)
29    }
30
31    /// Wrap an iterator with a progress bar and style it.
32    fn progress_with_style(self, style: crate::ProgressStyle) -> ProgressBarIter<Self>
33    where
34        Self: IndexedParallelIterator,
35    {
36        let len = u64::try_from(self.len()).unwrap();
37        let bar = ProgressBar::new(len).with_style(style);
38        self.progress_with(bar)
39    }
40}
41
42impl<S: Send, T: ParallelIterator<Item = S>> ParallelProgressIterator for T {
43    fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self> {
44        ProgressBarIter { it: self, progress }
45    }
46}
47
48impl<S: Send, T: IndexedParallelIterator<Item = S>> IndexedParallelIterator for ProgressBarIter<T> {
49    fn len(&self) -> usize {
50        self.it.len()
51    }
52
53    fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> <C as Consumer<Self::Item>>::Result {
54        let consumer = ProgressConsumer::new(consumer, self.progress);
55        self.it.drive(consumer)
56    }
57
58    fn with_producer<CB: ProducerCallback<Self::Item>>(
59        self,
60        callback: CB,
61    ) -> <CB as ProducerCallback<Self::Item>>::Output {
62        return self.it.with_producer(Callback {
63            callback,
64            progress: self.progress,
65        });
66
67        struct Callback<CB> {
68            callback: CB,
69            progress: ProgressBar,
70        }
71
72        impl<T, CB: ProducerCallback<T>> ProducerCallback<T> for Callback<CB> {
73            type Output = CB::Output;
74
75            fn callback<P>(self, base: P) -> CB::Output
76            where
77                P: Producer<Item = T>,
78            {
79                let producer = ProgressProducer {
80                    base,
81                    progress: self.progress,
82                };
83                self.callback.callback(producer)
84            }
85        }
86    }
87}
88
89struct ProgressProducer<T> {
90    base: T,
91    progress: ProgressBar,
92}
93
94impl<T, P: Producer<Item = T>> Producer for ProgressProducer<P> {
95    type Item = T;
96    type IntoIter = ProgressBarIter<P::IntoIter>;
97
98    fn into_iter(self) -> Self::IntoIter {
99        ProgressBarIter {
100            it: self.base.into_iter(),
101            progress: self.progress,
102        }
103    }
104
105    fn min_len(&self) -> usize {
106        self.base.min_len()
107    }
108
109    fn max_len(&self) -> usize {
110        self.base.max_len()
111    }
112
113    fn split_at(self, index: usize) -> (Self, Self) {
114        let (left, right) = self.base.split_at(index);
115        (
116            ProgressProducer {
117                base: left,
118                progress: self.progress.clone(),
119            },
120            ProgressProducer {
121                base: right,
122                progress: self.progress,
123            },
124        )
125    }
126}
127
128struct ProgressConsumer<C> {
129    base: C,
130    progress: ProgressBar,
131}
132
133impl<C> ProgressConsumer<C> {
134    fn new(base: C, progress: ProgressBar) -> Self {
135        ProgressConsumer { base, progress }
136    }
137}
138
139impl<T, C: Consumer<T>> Consumer<T> for ProgressConsumer<C> {
140    type Folder = ProgressFolder<C::Folder>;
141    type Reducer = C::Reducer;
142    type Result = C::Result;
143
144    fn split_at(self, index: usize) -> (Self, Self, Self::Reducer) {
145        let (left, right, reducer) = self.base.split_at(index);
146        (
147            ProgressConsumer::new(left, self.progress.clone()),
148            ProgressConsumer::new(right, self.progress),
149            reducer,
150        )
151    }
152
153    fn into_folder(self) -> Self::Folder {
154        ProgressFolder {
155            base: self.base.into_folder(),
156            progress: self.progress,
157        }
158    }
159
160    fn full(&self) -> bool {
161        self.base.full()
162    }
163}
164
165impl<T, C: UnindexedConsumer<T>> UnindexedConsumer<T> for ProgressConsumer<C> {
166    fn split_off_left(&self) -> Self {
167        ProgressConsumer::new(self.base.split_off_left(), self.progress.clone())
168    }
169
170    fn to_reducer(&self) -> Self::Reducer {
171        self.base.to_reducer()
172    }
173}
174
175struct ProgressFolder<C> {
176    base: C,
177    progress: ProgressBar,
178}
179
180impl<T, C: Folder<T>> Folder<T> for ProgressFolder<C> {
181    type Result = C::Result;
182
183    fn consume(self, item: T) -> Self {
184        self.progress.inc(1);
185        ProgressFolder {
186            base: self.base.consume(item),
187            progress: self.progress,
188        }
189    }
190
191    fn complete(self) -> C::Result {
192        self.base.complete()
193    }
194
195    fn full(&self) -> bool {
196        self.base.full()
197    }
198}
199
200impl<S: Send, T: ParallelIterator<Item = S>> ParallelIterator for ProgressBarIter<T> {
201    type Item = S;
202
203    fn drive_unindexed<C: UnindexedConsumer<Self::Item>>(self, consumer: C) -> C::Result {
204        let consumer1 = ProgressConsumer::new(consumer, self.progress.clone());
205        self.it.drive_unindexed(consumer1)
206    }
207}
208
#[cfg(test)]
mod test {
    use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};

    use crate::{ParallelProgressIterator, ProgressBar, ProgressBarIter, ProgressStyle};

    #[test]
    fn it_can_wrap_a_parallel_iterator() {
        let v = vec![1, 2, 3];
        fn wrap<'a, T: ParallelIterator<Item = &'a i32>>(it: ProgressBarIter<T>) {
            assert_eq!(it.map(|x| x * 2).collect::<Vec<_>>(), vec![2, 4, 6]);
        }

        wrap(v.par_iter().progress_count(3));
        wrap({
            let pb = ProgressBar::new(v.len() as u64);
            v.par_iter().progress_with(pb)
        });

        wrap({
            let style = ProgressStyle::default_bar()
                .template("{wide_bar:.red} {percent}/100%")
                .unwrap();
            v.par_iter().progress_with_style(style)
        });

        // `progress()` takes the bar length from the indexed iterator itself.
        wrap(v.par_iter().progress());
    }

    #[test]
    fn it_supports_the_producer_path() {
        let v = vec![1, 2, 3];

        // `zip` drives both sides through `with_producer`, which exercises
        // the `ProgressProducer` wrapper rather than the consumer path.
        let zipped: Vec<i32> = v
            .par_iter()
            .progress()
            .zip(v.par_iter())
            .map(|(a, b)| a + b)
            .collect();
        assert_eq!(zipped, vec![2, 4, 6]);
    }

    #[test]
    fn it_advances_the_bar_once_per_item() {
        let v = vec![1, 2, 3];

        // `sum` goes through `drive_unindexed`, so each item passes through
        // `ProgressFolder::consume` exactly once.
        let pb = ProgressBar::new(3);
        let total: i32 = v.par_iter().progress_with(pb.clone()).sum();
        assert_eq!(total, 6);
        assert_eq!(pb.position(), 3);
    }
}