Skip to content

Commit ef54971

Browse files
authored
Add more re-implementations of the bolt protocol (#171)
* Move packstream mod into a more expected directory structure * Allow for differnt vis than pub on cenums * Add Commit protocol message * Add Discard protocol message * Add Goodbye protocol message * Add Hello protocol message * Add Reset protocol message * Add Rollback protocol message * Export new protocol messages * Silence dead_code warnings until feature is done
1 parent 99da90c commit ef54971

File tree

14 files changed

+776
-241
lines changed

14 files changed

+776
-241
lines changed

lib/src/bolt/mod.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#![allow(unused_imports)]
1+
#![allow(unused_imports, dead_code)]
22

33
use std::marker::PhantomData;
44

@@ -10,13 +10,15 @@ use serde::{
1010

1111
mod detail;
1212
mod packstream;
13+
mod request;
1314
mod summary;
1415

15-
pub use detail::Detail;
1616
#[cfg(debug_assertions)]
1717
pub use packstream::debug::Dbg;
18-
use packstream::ser::AsMap;
19-
pub use packstream::{de, ser};
18+
#[cfg(test)]
19+
pub use packstream::value::{bolt, BoltBytesBuilder};
20+
pub use packstream::{de, from_bytes, ser, to_bytes};
21+
pub use request::{Commit, Discard, Goodbye, Hello, Reset, Rollback, WrapExtra};
2022
pub use summary::{Failure, Streaming, StreamingSummary, Success, Summary};
2123

2224
pub(crate) trait Message: Serialize {
@@ -93,3 +95,15 @@ impl<'de, R: Deserialize<'de>, S: Deserialize<'de>> Deserialize<'de> for Respons
9395
)
9496
}
9597
}
98+
99+
impl<R: std::fmt::Debug, S: std::fmt::Debug> Response<R, S> {
100+
pub fn into_error(self, msg: &'static str) -> crate::errors::Error {
101+
match self {
102+
Response::Failure(f) => f.into_error(msg),
103+
otherwise => crate::Error::UnexpectedMessage(format!(
104+
"unexpected response for {}: {:?}",
105+
msg, otherwise
106+
)),
107+
}
108+
}
109+
}

lib/src/bolt/de.rs renamed to lib/src/bolt/packstream/de.rs

Lines changed: 82 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ impl<'a: 'de, 'de> de::Deserializer<'de> for Deserializer<'a> {
1818

1919
forward_to_deserialize_any! {
2020
bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 str
21-
string newtype_struct ignored_any
22-
map unit_struct struct enum identifier
21+
string ignored_any map unit_struct struct enum identifier
2322
}
2423

2524
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
@@ -118,6 +117,21 @@ impl<'a: 'de, 'de> de::Deserializer<'de> for Deserializer<'a> {
118117
self.deserialize_bytes(visitor)
119118
}
120119

120+
fn deserialize_newtype_struct<V>(
121+
self,
122+
name: &'static str,
123+
visitor: V,
124+
) -> Result<V::Value, Self::Error>
125+
where
126+
V: Visitor<'de>,
127+
{
128+
if name == "__neo4rs::RawBytes" {
129+
self.parse_next_item(Visitation::RawBytes, visitor)
130+
} else {
131+
self.parse_next_item(Visitation::default(), visitor)
132+
}
133+
}
134+
121135
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
122136
where
123137
V: Visitor<'de>,
@@ -158,14 +172,21 @@ impl<'de> Deserializer<'de> {
158172
return Err(Error::Empty);
159173
}
160174

161-
if let Visitation::SeqAsTuple(2) = v {
162-
return if self.bytes[0] == 0x92 {
163-
self.bytes.advance(1);
164-
Self::parse_list(v, 2, self.bytes, visitor)
165-
} else {
166-
visitor.visit_seq(ItemsParser::new(2, self.bytes))
167-
};
168-
}
175+
match v {
176+
Visitation::SeqAsTuple(2) => {
177+
return if self.bytes[0] == 0x92 {
178+
self.bytes.advance(1);
179+
Self::parse_list(v, 2, self.bytes, visitor)
180+
} else {
181+
visitor.visit_seq(ItemsParser::new(2, self.bytes))
182+
};
183+
}
184+
Visitation::RawBytes => {
185+
let bytes = self.next_item_as_bytes()?;
186+
return visitor.visit_bytes(&bytes);
187+
}
188+
_ => (),
189+
};
169190

170191
Self::parse(v, self.bytes, visitor)
171192
}
@@ -175,6 +196,22 @@ impl<'de> Deserializer<'de> {
175196
.map(|_| ())
176197
}
177198

199+
fn next_item_as_bytes(self) -> Result<Bytes, Error> {
200+
let mut full_bytes = self.bytes.clone();
201+
202+
{
203+
let this = Deserializer { bytes: self.bytes };
204+
this.skip_next_item()?;
205+
}
206+
207+
let start = full_bytes.as_ptr();
208+
let end = self.bytes.as_ptr();
209+
210+
let len = unsafe { end.offset_from(start) };
211+
full_bytes.truncate(len.unsigned_abs());
212+
Ok(full_bytes)
213+
}
214+
178215
fn parse<V: Visitor<'de>>(
179216
v: Visitation,
180217
bytes: &'de mut Bytes,
@@ -227,12 +264,11 @@ impl<'de> Deserializer<'de> {
227264
bytes: &'de mut Bytes,
228265
visitor: V,
229266
) -> Result<V::Value, Error> {
230-
debug_assert!(bytes.len() >= len);
231-
232-
let bytes = bytes.split_to(len);
267+
let bytes = Self::take_slice(len, bytes);
233268
if v.visit_bytes_as_bytes() {
234-
let bytes: &'de [u8] = unsafe { std::mem::transmute(bytes.as_ref()) };
235-
visitor.visit_borrowed_bytes(bytes)
269+
visitor.visit_borrowed_bytes(unsafe {
270+
std::slice::from_raw_parts(bytes.as_ptr(), bytes.len())
271+
})
236272
} else {
237273
visitor.visit_seq(SeqDeserializer::new(bytes.into_iter()))
238274
}
@@ -243,17 +279,20 @@ impl<'de> Deserializer<'de> {
243279
bytes: &'de mut Bytes,
244280
visitor: V,
245281
) -> Result<V::Value, Error> {
246-
debug_assert!(bytes.len() >= len);
247-
248-
let bytes = bytes.split_to(len);
249-
let bytes: &'de [u8] = unsafe { std::mem::transmute(bytes.as_ref()) };
250-
251-
match std::str::from_utf8(bytes) {
252-
Ok(s) => visitor.visit_borrowed_str(s),
282+
let bytes = Self::take_slice(len, bytes);
283+
match std::str::from_utf8(&bytes) {
284+
Ok(s) => visitor.visit_borrowed_str(unsafe {
285+
std::str::from_utf8_unchecked(std::slice::from_raw_parts(s.as_ptr(), s.len()))
286+
}),
253287
Err(e) => Err(Error::InvalidUtf8(e)),
254288
}
255289
}
256290

291+
fn take_slice(len: usize, bytes: &mut Bytes) -> Bytes {
292+
debug_assert!(bytes.len() >= len);
293+
bytes.split_to(len)
294+
}
295+
257296
fn parse_list<V: Visitor<'de>>(
258297
v: Visitation,
259298
len: usize,
@@ -299,6 +338,7 @@ impl<'de> Deserializer<'de> {
299338
}
300339
}
301340

341+
#[derive(Debug)]
302342
struct ItemsParser<'a> {
303343
len: usize,
304344
excess: usize,
@@ -339,6 +379,10 @@ impl<'a, 'de> SeqAccess<'de> for ItemsParser<'a> {
339379
let bytes = self.bytes.get();
340380
seed.deserialize(Deserializer { bytes }).map(Some)
341381
}
382+
383+
fn size_hint(&self) -> Option<usize> {
384+
Some(self.len)
385+
}
342386
}
343387

344388
impl<'a, 'de> MapAccess<'de> for ItemsParser<'a> {
@@ -364,94 +408,42 @@ impl<'a, 'de> MapAccess<'de> for ItemsParser<'a> {
364408
let bytes = self.bytes.get();
365409
seed.deserialize(Deserializer { bytes })
366410
}
411+
412+
fn size_hint(&self) -> Option<usize> {
413+
Some(self.len)
414+
}
367415
}
368416

369417
impl<'a, 'de> VariantAccess<'de> for ItemsParser<'a> {
370418
type Error = Error;
371419

372420
fn unit_variant(mut self) -> Result<(), Self::Error> {
373421
self.next_value()
374-
// if self.len != 0 {
375-
// return Err(Error::InvalidLength {
376-
// expected: 0,
377-
// actual: self.len,
378-
// });
379-
// }
380-
// Ok(())
381422
}
382423

383424
fn newtype_variant_seed<T>(mut self, seed: T) -> Result<T::Value, Self::Error>
384425
where
385426
T: DeserializeSeed<'de>,
386427
{
387428
self.next_value_seed(seed)
388-
// let bytes = self.bytes.get();
389-
// seed.deserialize(Deserializer { bytes })
390429
}
391430

392-
fn tuple_variant<V>(mut self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
431+
fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
393432
where
394433
V: Visitor<'de>,
395434
{
396-
struct TupleVariant<V> {
397-
len: usize,
398-
visitor: V,
399-
}
400-
401-
impl<'de, V> DeserializeSeed<'de> for TupleVariant<V>
402-
where
403-
V: Visitor<'de>,
404-
{
405-
type Value = V::Value;
406-
407-
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
408-
where
409-
D: serde::Deserializer<'de>,
410-
{
411-
deserializer.deserialize_tuple(self.len, self.visitor)
412-
}
413-
}
414-
415-
self.next_value_seed(TupleVariant { len, visitor })
416-
417-
// if len != self.len {
418-
// return Err(Error::InvalidLength {
419-
// expected: len,
420-
// actual: self.len,
421-
// });
422-
// }
423-
// visitor.visit_seq(self)
435+
visitor.visit_seq(self)
424436
}
425437

426438
fn struct_variant<V>(
427-
mut self,
439+
self,
428440
_fields: &'static [&'static str],
429441
visitor: V,
430442
) -> Result<V::Value, Self::Error>
431443
where
432444
V: Visitor<'de>,
433445
{
434-
struct StructVariant<V> {
435-
visitor: V,
436-
}
437-
438-
impl<'de, V> DeserializeSeed<'de> for StructVariant<V>
439-
where
440-
V: Visitor<'de>,
441-
{
442-
type Value = V::Value;
443-
444-
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
445-
where
446-
D: serde::Deserializer<'de>,
447-
{
448-
deserializer.deserialize_map(self.visitor)
449-
}
450-
}
451-
452-
self.next_value_seed(StructVariant { visitor })
453-
454-
// visitor.visit_map(self)
446+
visitor.visit_seq(self)
455447
}
456448
}
457449

@@ -479,6 +471,7 @@ enum Visitation {
479471
#[default]
480472
Default,
481473
BytesAsBytes,
474+
RawBytes,
482475
MapAsSeq,
483476
SeqAsTuple(usize),
484477
}
@@ -498,6 +491,13 @@ struct SharedBytes<'a> {
498491
_lifetime: PhantomData<&'a mut ()>,
499492
}
500493

494+
#[cfg(debug_assertions)]
495+
impl<'a> fmt::Debug for SharedBytes<'a> {
496+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
497+
crate::bolt::Dbg(unsafe { &*self.bytes }).fmt(f)
498+
}
499+
}
500+
501501
impl<'a> SharedBytes<'a> {
502502
fn new(bytes: &'a mut Bytes) -> Self {
503503
Self {

0 commit comments

Comments
 (0)