1use core::cmp::Ordering;
4use core::marker::PhantomData;
5
6use super::parse::{Parse, ParseError};
7use crate::length::{LengthField, NoLengthField, OptionalLengthField};
8use crate::parse::ByteReader;
9use crate::serialize::{ByteWriter, Serialize, SerializeError};
10
/// A length-prefixed array holding a variable number of `T` elements, at most `MAX_ELEMENTS`.
///
/// The elements are not stored decoded: the array keeps a [`ByteReader`] over their serialized
/// bytes and decodes them lazily through [`DynamicLengthArray::iter`]. `L` is the type of the
/// length field; its value is the number of element *bytes* (parsing uses it as the size of a
/// sub-reader, and serialization back-fills it with the byte count written).
#[derive(Debug)]
pub struct DynamicLengthArray<'a, T, L, const MAX_ELEMENTS: usize> {
    /// Reader over the serialized element bytes; the length field is not included.
    reader: ByteReader<'a>,
    /// Ties the element type `T` and length-field type `L` to the struct without storing them.
    _marker: PhantomData<(T, L)>,
}
24
25impl<'a, T, L, const MAX_ELEMENTS: usize> DynamicLengthArray<'a, T, L, MAX_ELEMENTS> {
26 pub fn create<'b, I>(
28 mut elements: I,
29 buffer: &'a mut [u8],
30 ) -> Result<DynamicLengthArray<'a, T, L, MAX_ELEMENTS>, SerializeError>
31 where
32 T: Serialize + 'b,
33 I: Iterator<Item = &'b T>,
34 {
35 let mut byte_writer = ByteWriter::new(buffer);
36
37 let used_bytes = byte_writer.write_counted(move |byte_writer| {
38 let mut element_count = 0;
39
40 for element in elements.by_ref() {
41 element.serialize_partial(byte_writer)?;
42 element_count += 1;
43 }
44
45 match element_count <= MAX_ELEMENTS {
46 true => Ok(()),
47 false => Err(SerializeError::DynamicTypeOverflow),
48 }
49 })?;
50
51 let reader = ByteReader::new(&buffer[..used_bytes]);
52
53 Ok(DynamicLengthArray {
54 reader,
55 _marker: PhantomData,
56 })
57 }
58
59 pub fn iter(&self) -> DynamicLengthArrayIterator<'a, T, MAX_ELEMENTS> {
61 DynamicLengthArrayIterator {
62 reader: self.reader.clone(),
63 element_count: 0,
64 _marker: PhantomData,
65 }
66 }
67}
68
69impl<T, L, const MAX_ELEMENTS: usize> Clone for DynamicLengthArray<'_, T, L, MAX_ELEMENTS> {
70 fn clone(&self) -> Self {
71 Self {
72 reader: self.reader.clone(),
73 _marker: PhantomData,
74 }
75 }
76}
77
78impl<'a, T, L, const MAX_ELEMENTS: usize> PartialEq for DynamicLengthArray<'a, T, L, MAX_ELEMENTS>
79where
80 T: Parse<'a> + PartialEq,
81{
82 fn eq(&self, other: &Self) -> bool {
83 self.iter().eq(other.iter())
84 }
85}
86
87impl<'a, T, L, const MAX_ELEMENTS: usize> Eq for DynamicLengthArray<'a, T, L, MAX_ELEMENTS> where
88 T: Parse<'a> + Eq
89{
90}
91
92impl<'a, T, L, const MAX_ELEMENTS: usize> Parse<'a> for DynamicLengthArray<'a, T, L, MAX_ELEMENTS>
93where
94 T: Parse<'a>,
95 L: LengthField,
96{
97 fn parse_partial(reader: &mut ByteReader<'a>) -> Result<Self, ParseError> {
98 let length = L::get_length(reader)?;
99 let reader = reader.sub_reader(length)?;
100
101 {
103 let mut element_reader = reader.clone();
104 let mut element_count = 0;
105
106 while !element_reader.is_empty() && element_count < MAX_ELEMENTS {
108 let _ = T::parse_partial(&mut element_reader)?;
109 element_count += 1;
110 }
111 }
112
113 Ok(Self {
114 reader,
115 _marker: PhantomData,
116 })
117 }
118}
119
120impl<T, L, const MAX_ELEMENTS: usize> Serialize for DynamicLengthArray<'_, T, L, MAX_ELEMENTS>
121where
122 L: LengthField + Serialize,
123{
124 fn required_length(&self) -> usize {
125 core::mem::size_of::<L>() + self.reader.len()
126 }
127
128 fn serialize_partial(&self, byte_writer: &mut ByteWriter) -> Result<(), SerializeError> {
129 let reserved_length = byte_writer.reserve_length()?;
130
131 let length = byte_writer
132 .write_counted(|byte_writer| byte_writer.write_slice(self.reader.remaining_slice()))?;
133
134 byte_writer.write_length(reserved_length, &L::from_length(length)?)
135 }
136}
137
/// Iterator over the elements of a [`DynamicLengthArray`], decoding one `T` per step.
///
/// Created by [`DynamicLengthArray::iter`]. It stops once the underlying bytes are exhausted
/// or after `MAX_ELEMENTS` elements have been yielded, whichever comes first.
#[derive(Debug)]
pub struct DynamicLengthArrayIterator<'a, T, const MAX_ELEMENTS: usize> {
    /// Cursor over the element bytes that have not been decoded yet.
    reader: ByteReader<'a>,
    /// Number of elements yielded so far.
    element_count: usize,
    /// Records the element type `T` without storing a value of it.
    _marker: PhantomData<T>,
}
145
146impl<'a, T, const MAX_ELEMENTS: usize> Iterator for DynamicLengthArrayIterator<'a, T, MAX_ELEMENTS>
147where
148 T: Parse<'a>,
149{
150 type Item = T;
151
152 fn next(&mut self) -> Option<Self::Item> {
153 if self.reader.is_empty() || self.element_count >= MAX_ELEMENTS {
154 return None;
155 }
156
157 self.element_count += 1;
158
159 Some(T::parse_partial(&mut self.reader).unwrap())
160 }
161}
162
/// An array of exactly `ELEMENT_COUNT` elements of type `T`, with an optional count prefix.
///
/// The elements are kept as serialized bytes behind a [`ByteReader`] and are decoded lazily
/// by [`FixedLengthArray::iter`]. `L` selects the prefix:
/// - With `NoLengthField`, no prefix is written.
/// - Otherwise the prefix holds the element count (`ELEMENT_COUNT`), not a byte length.
#[derive(Debug)]
pub struct FixedLengthArray<'a, T, L, const ELEMENT_COUNT: usize> {
    /// Reader positioned at the serialized element bytes (any count field is not included).
    reader: ByteReader<'a>,
    /// Ties the element type `T` and count-field type `L` to the struct without storing them.
    _marker: PhantomData<(T, L)>,
}
174
175impl<'a, T, L, const ELEMENT_COUNT: usize> FixedLengthArray<'a, T, L, ELEMENT_COUNT> {
176 pub fn create<'b, I>(
178 mut elements: I,
179 buffer: &'a mut [u8],
180 ) -> Result<FixedLengthArray<'a, T, L, ELEMENT_COUNT>, SerializeError>
181 where
182 T: Serialize + 'b,
183 I: Iterator<Item = &'b T>,
184 {
185 let mut byte_writer = ByteWriter::new(buffer);
186
187 let used_bytes = byte_writer.write_counted(move |byte_writer| {
188 let mut element_count = 0;
189
190 for element in elements.by_ref() {
191 element.serialize_partial(byte_writer)?;
192 element_count += 1;
193 }
194
195 match element_count.cmp(&ELEMENT_COUNT) {
196 Ordering::Less => Err(SerializeError::DynamicTypeUnderflow),
197 Ordering::Equal => Ok(()),
198 Ordering::Greater => Err(SerializeError::DynamicTypeOverflow),
199 }
200 })?;
201
202 let reader = ByteReader::new(&buffer[..used_bytes]);
203
204 Ok(FixedLengthArray {
205 reader,
206 _marker: PhantomData,
207 })
208 }
209
210 pub fn iter(&self) -> FixedLengthArrayIterator<'a, T, ELEMENT_COUNT> {
212 FixedLengthArrayIterator {
213 reader: self.reader.clone(),
214 element_count: 0,
215 _marker: PhantomData,
216 }
217 }
218}
219
220impl<T, L, const ELEMENT_COUNT: usize> Clone for FixedLengthArray<'_, T, L, ELEMENT_COUNT> {
221 fn clone(&self) -> Self {
222 Self {
223 reader: self.reader.clone(),
224 _marker: PhantomData,
225 }
226 }
227}
228
229impl<'a, T, L, const ELEMENT_COUNT: usize> PartialEq for FixedLengthArray<'a, T, L, ELEMENT_COUNT>
230where
231 T: Parse<'a> + PartialEq,
232{
233 fn eq(&self, other: &Self) -> bool {
234 self.iter().eq(other.iter())
235 }
236}
237
238impl<'a, T, L, const ELEMENT_COUNT: usize> Eq for FixedLengthArray<'a, T, L, ELEMENT_COUNT> where
239 T: Parse<'a> + Eq
240{
241}
242
243impl<'a, T, L, const ELEMENT_COUNT: usize> Parse<'a> for FixedLengthArray<'a, T, L, ELEMENT_COUNT>
244where
245 T: Parse<'a>,
246 L: OptionalLengthField,
247{
248 fn parse_partial(reader: &mut ByteReader<'a>) -> Result<Self, ParseError> {
249 let optional_length = L::try_get_length(reader)?;
250 let array_reader = reader.clone();
251
252 {
254 if optional_length.is_some_and(|length| length < ELEMENT_COUNT) {
258 return Err(ParseError::MalformedMessage {
259 failed_at: core::any::type_name::<Self>(),
260 });
261 }
262
263 if optional_length.is_some_and(|length| length > ELEMENT_COUNT) {
265 return Err(ParseError::MalformedMessage {
269 failed_at: core::any::type_name::<Self>(),
270 });
271 }
272
273 for _ in 0..ELEMENT_COUNT {
274 let _ = T::parse_partial(reader)?;
275 }
276 }
277
278 Ok(Self {
279 reader: array_reader,
280 _marker: PhantomData,
281 })
282 }
283}
284
285impl<T, L, const ELEMENT_COUNT: usize> Serialize for FixedLengthArray<'_, T, L, ELEMENT_COUNT>
286where
287 L: LengthField + Serialize,
288{
289 fn required_length(&self) -> usize {
290 core::mem::size_of::<L>() + self.reader.len()
291 }
292
293 fn serialize_partial(&self, byte_writer: &mut ByteWriter) -> Result<(), SerializeError> {
294 L::from_length(ELEMENT_COUNT)?.serialize_partial(byte_writer)?;
295 byte_writer.write_slice(self.reader.remaining_slice())
296 }
297}
298
299impl<T, const ELEMENT_COUNT: usize> Serialize
300 for FixedLengthArray<'_, T, NoLengthField, ELEMENT_COUNT>
301{
302 fn required_length(&self) -> usize {
303 self.reader.len()
304 }
305
306 fn serialize_partial(&self, byte_writer: &mut ByteWriter) -> Result<(), SerializeError> {
307 byte_writer.write_slice(self.reader.remaining_slice())
308 }
309}
310
/// Iterator over the elements of a [`FixedLengthArray`], decoding one `T` per step.
///
/// Created by [`FixedLengthArray::iter`]. It yields exactly `ELEMENT_COUNT` elements and does
/// not look at whether any bytes remain afterwards.
#[derive(Debug)]
pub struct FixedLengthArrayIterator<'a, T, const ELEMENT_COUNT: usize> {
    /// Cursor over the element bytes that have not been decoded yet.
    reader: ByteReader<'a>,
    /// Number of elements yielded so far.
    element_count: usize,
    /// Records the element type `T` without storing a value of it.
    _marker: PhantomData<T>,
}
318
319impl<'a, T, const ELEMENT_COUNT: usize> Iterator for FixedLengthArrayIterator<'a, T, ELEMENT_COUNT>
320where
321 T: Parse<'a>,
322{
323 type Item = T;
324
325 fn next(&mut self) -> Option<Self::Item> {
326 if self.element_count >= ELEMENT_COUNT {
327 return None;
328 }
329
330 self.element_count += 1;
331
332 Some(T::parse_partial(&mut self.reader).unwrap())
333 }
334}
335
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod dynamic_length_array {
    //! Tests for `DynamicLengthArray`: creation limits, wire round trips, parse limits, and the
    //! hand-written `Clone` / `PartialEq` impls.

    use crate::array::DynamicLengthArray;
    use crate::parse::{ByteReader, Parse, ParseError, ParseExt};
    use crate::serialize::{Serialize, SerializeError, SerializeExt};

    /// Elements handed to `create` are yielded back unchanged by `iter`.
    #[test]
    fn create_valid() {
        const EXPECTED_ELEMENTS: [u32; 2] = [10, 30];

        let mut scratch = [0; 64];
        let array =
            DynamicLengthArray::<'_, u32, u16, 5>::create(EXPECTED_ELEMENTS.iter(), &mut scratch)
                .unwrap();

        let decoded = array.iter();
        assert!(decoded.eq(EXPECTED_ELEMENTS));
    }

    /// Supplying more than `MAX_ELEMENTS` elements is rejected as an overflow.
    #[test]
    fn create_too_many_elements() {
        const TEST_ELEMENTS: [u32; 3] = [10, 30, 50];

        let mut scratch = [0; 64];
        let outcome =
            DynamicLengthArray::<'_, u32, u16, 2>::create(TEST_ELEMENTS.iter(), &mut scratch);

        assert_eq!(outcome, Err(SerializeError::DynamicTypeOverflow));
    }

    /// Two `u32`s need 8 bytes, so a 7-byte buffer cannot hold them.
    #[test]
    fn create_buffer_too_small() {
        const TEST_ELEMENTS: [u32; 2] = [10, 30];

        let mut scratch = [0; 7];
        let outcome =
            DynamicLengthArray::<'_, u32, u16, 5>::create(TEST_ELEMENTS.iter(), &mut scratch);

        assert_eq!(outcome, Err(SerializeError::BufferTooSmall));
    }

    /// A created array serializes to the expected bytes and parses back to itself.
    #[test]
    fn conversion() {
        const TEST_ELEMENTS: [u32; 2] = [10, 30];
        const EXPECTED_BYTES: &[u8] = &[
            0, 8, // length field: 8 bytes of elements
            0, 0, 0, 10, // first element
            0, 0, 0, 30, // second element
        ];

        let mut scratch = [0; 64];
        let array =
            DynamicLengthArray::<'_, u32, u16, 5>::create(TEST_ELEMENTS.iter(), &mut scratch)
                .unwrap();

        test_round_trip!(DynamicLengthArray::<'_, u32, u16, 5>, array, EXPECTED_BYTES);
    }

    /// Only the first `MAX_ELEMENTS` elements are exposed; the surplus one is ignored.
    #[test]
    fn parse_too_many_elements() {
        const TEST_DATA: &[u8] = &[
            0, 8, // length field: 8 bytes of elements
            0, 0, 0, 23, // first element
            0, 0, 0, 34, // second element, beyond MAX_ELEMENTS
        ];
        const EXPECTED_ELEMENTS: [u32; 1] = [23];

        let parsed = DynamicLengthArray::<'_, u32, u16, 1>::parse(TEST_DATA).unwrap();

        assert!(parsed.iter().eq(EXPECTED_ELEMENTS));
    }

    /// A failure while parsing an element is propagated out of `parse`.
    #[test]
    fn parse_element_fails() {
        #[derive(Debug)]
        struct Fails;

        impl<'a> Parse<'a> for Fails {
            fn parse_partial(_: &mut ByteReader<'a>) -> Result<Self, ParseError> {
                Err(ParseError::MalformedMessage { failed_at: "Fails" })
            }
        }

        const TEST_DATA: &[u8] = &[
            0, 1, // length field: 1 byte of elements
            0, 0, 0, 23,
        ];

        let outcome = DynamicLengthArray::<'_, Fails, u16, 2>::parse(TEST_DATA);

        assert!(matches!(outcome, Err(ParseError::MalformedMessage { .. })));
    }

    /// 256 bytes of elements cannot be described by a `u8` length field.
    #[test]
    fn serialize_length_overflow() {
        const TEST_ELEMENTS: [u8; 256] = [0; 256];

        let mut scratch = [0; 512];
        let array =
            DynamicLengthArray::<'_, u8, u8, 512>::create(TEST_ELEMENTS.iter(), &mut scratch)
                .unwrap();

        let mut output = [0; 512];
        assert_eq!(
            array.serialize(&mut output),
            Err(SerializeError::LengthOverflow)
        );
    }

    /// `clone` works even though the element type is not `Clone`.
    #[test]
    fn clone_without_clone_bound() {
        #[derive(Debug, PartialEq, Parse, Serialize)]
        struct NotClone;

        let mut scratch = [0; 512];
        let elements = [NotClone, NotClone];
        let array =
            DynamicLengthArray::<'_, NotClone, u16, 2>::create(elements.iter(), &mut scratch)
                .unwrap();

        let copy = array.clone();
        assert_eq!(array, copy);
    }

    /// Bytes after the length-bounded region do not influence equality.
    #[test]
    fn eq() {
        const TEST_DATA_1: &[u8] = &[
            0, 4, // length field: 4 bytes of elements
            0, 0, 0, 23, // the only element
        ];
        const TEST_DATA_2: &[u8] = &[
            0, 4, // length field: 4 bytes of elements
            0, 0, 0, 23, // the only element
            1, 2, 3, 4, // trailing bytes outside the array
        ];

        let mut first_reader = ByteReader::new(TEST_DATA_1);
        let first =
            DynamicLengthArray::<'_, u32, u16, 2>::parse_partial(&mut first_reader).unwrap();

        let mut second_reader = ByteReader::new(TEST_DATA_2);
        let second =
            DynamicLengthArray::<'_, u32, u16, 2>::parse_partial(&mut second_reader).unwrap();

        assert_eq!(first, second);
    }
}
487
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod fixed_length_array {
    //! Tests for `FixedLengthArray`: exact-count creation, round trips with and without a
    //! count field, count-field validation, and the hand-written `Clone` / `PartialEq` impls.

    use crate::array::FixedLengthArray;
    use crate::length::NoLengthField;
    use crate::parse::{ByteReader, Parse, ParseError, ParseExt};
    use crate::serialize::{Serialize, SerializeError, SerializeExt};

    /// Exactly `ELEMENT_COUNT` elements are accepted and yielded back by `iter`.
    #[test]
    fn create_valid() {
        const EXPECTED_ELEMENTS: [u32; 2] = [10, 30];

        let mut scratch = [0; 64];
        let array =
            FixedLengthArray::<'_, u32, u16, 2>::create(EXPECTED_ELEMENTS.iter(), &mut scratch)
                .unwrap();

        let decoded = array.iter();
        assert!(decoded.eq(EXPECTED_ELEMENTS));
    }

    /// More than `ELEMENT_COUNT` elements is reported as an overflow.
    #[test]
    fn create_too_many_elements() {
        const TEST_ELEMENTS: [u32; 3] = [10, 30, 50];

        let mut scratch = [0; 64];
        let outcome =
            FixedLengthArray::<'_, u32, u16, 2>::create(TEST_ELEMENTS.iter(), &mut scratch);

        assert_eq!(outcome, Err(SerializeError::DynamicTypeOverflow));
    }

    /// Fewer than `ELEMENT_COUNT` elements is reported as an underflow.
    #[test]
    fn create_too_few_elements() {
        const TEST_ELEMENTS: [u32; 1] = [10];

        let mut scratch = [0; 64];
        let outcome =
            FixedLengthArray::<'_, u32, u16, 2>::create(TEST_ELEMENTS.iter(), &mut scratch);

        assert_eq!(outcome, Err(SerializeError::DynamicTypeUnderflow));
    }

    /// Two `u32`s need 8 bytes, so a 7-byte buffer cannot hold them.
    #[test]
    fn create_buffer_too_small() {
        const TEST_ELEMENTS: [u32; 2] = [10, 30];

        let mut scratch = [0; 7];
        let outcome =
            FixedLengthArray::<'_, u32, u16, 2>::create(TEST_ELEMENTS.iter(), &mut scratch);

        assert_eq!(outcome, Err(SerializeError::BufferTooSmall));
    }

    /// With `NoLengthField` the wire format is just the elements.
    #[test]
    fn conversion_without_length() {
        const TEST_ELEMENTS: [u32; 2] = [10, 30];
        const EXPECTED_BYTES: &[u8] = &[
            0, 0, 0, 10, // first element
            0, 0, 0, 30, // second element
        ];

        let mut scratch = [0; 64];
        let array = FixedLengthArray::<'_, u32, NoLengthField, 2>::create(
            TEST_ELEMENTS.iter(),
            &mut scratch,
        )
        .unwrap();

        test_round_trip!(
            FixedLengthArray::<'_, u32, NoLengthField, 2>,
            array,
            EXPECTED_BYTES
        );
    }

    /// With a `u16` count field the element count precedes the elements.
    #[test]
    fn conversion_with_length() {
        const TEST_ELEMENTS: [u32; 2] = [10, 30];
        const EXPECTED_BYTES: &[u8] = &[
            0, 2, // count field: 2 elements
            0, 0, 0, 10, // first element
            0, 0, 0, 30, // second element
        ];

        let mut scratch = [0; 64];
        let array =
            FixedLengthArray::<'_, u32, u16, 2>::create(TEST_ELEMENTS.iter(), &mut scratch)
                .unwrap();

        test_round_trip!(FixedLengthArray::<'_, u32, u16, 2>, array, EXPECTED_BYTES);
    }

    /// A count field smaller than `ELEMENT_COUNT` is malformed.
    #[test]
    fn parse_too_few_elements() {
        const TEST_DATA: &[u8] = &[
            0, 1, // count field: 1 element, but 2 are required
            0, 0, 0, 23,
        ];

        let outcome = FixedLengthArray::<'_, u32, u16, 2>::parse(TEST_DATA);

        assert!(matches!(outcome, Err(ParseError::MalformedMessage { .. })));
    }

    /// A count field larger than `ELEMENT_COUNT` is malformed.
    #[test]
    fn parse_too_many_elements() {
        const TEST_DATA: &[u8] = &[
            0, 3, // count field: 3 elements, but exactly 2 are allowed
            0, 0, 0, 23, // first element
            0, 0, 0, 34, // second element
            0, 0, 0, 45, // third element
        ];

        let outcome = FixedLengthArray::<'_, u32, u16, 2>::parse(TEST_DATA);

        assert!(matches!(outcome, Err(ParseError::MalformedMessage { .. })));
    }

    /// An element count of 256 cannot be written into a `u8` count field.
    #[test]
    fn serialize_length_overflow() {
        const TEST_ELEMENTS: [u8; 256] = [0; 256];

        let mut scratch = [0; 512];
        let array =
            FixedLengthArray::<'_, u8, u8, 256>::create(TEST_ELEMENTS.iter(), &mut scratch)
                .unwrap();

        let mut output = [0; 512];
        assert_eq!(
            array.serialize(&mut output),
            Err(SerializeError::LengthOverflow)
        );
    }

    /// `clone` works even though the element type is not `Clone`.
    #[test]
    fn clone_without_clone_bound() {
        #[derive(Debug, PartialEq, Parse, Serialize)]
        struct NotClone;

        let mut scratch = [0; 512];
        let elements = [NotClone, NotClone];
        let array = FixedLengthArray::<'_, NotClone, NoLengthField, 2>::create(
            elements.iter(),
            &mut scratch,
        )
        .unwrap();

        let copy = array.clone();
        assert_eq!(array, copy);
    }

    /// Bytes after the last element do not influence equality.
    #[test]
    fn eq() {
        const TEST_DATA_1: &[u8] = &[
            0, 1, // count field: 1 element
            0, 0, 0, 23, // the only element
        ];
        const TEST_DATA_2: &[u8] = &[
            0, 1, // count field: 1 element
            0, 0, 0, 23, // the only element
            1, 2, 3, 4, // trailing bytes after the array
        ];

        let mut first_reader = ByteReader::new(TEST_DATA_1);
        let first = FixedLengthArray::<'_, u32, u16, 1>::parse_partial(&mut first_reader).unwrap();

        let mut second_reader = ByteReader::new(TEST_DATA_2);
        let second =
            FixedLengthArray::<'_, u32, u16, 1>::parse_partial(&mut second_reader).unwrap();

        assert_eq!(first, second);
    }
}