gmsol_decode/value/
anchor.rs

1use std::marker::PhantomData;
2
3use crate::Visitor;
4
/// Zero-sized [`Visitor`](crate::Visitor) that turns raw account bytes into a
/// [`ZeroCopy`](anchor_lang::ZeroCopy) value of type `T`.
pub struct ZeroCopyVisitor<T>(std::marker::PhantomData<T>);
7
8impl<T> Default for ZeroCopyVisitor<T> {
9    fn default() -> Self {
10        Self(Default::default())
11    }
12}
13
14impl<T> Visitor for ZeroCopyVisitor<T>
15where
16    T: anchor_lang::ZeroCopy,
17{
18    type Value = T;
19
20    fn visit_bytes(self, data: &[u8]) -> Result<Self::Value, crate::DecodeError> {
21        use anchor_lang::prelude::{Error, ErrorCode};
22
23        let disc = T::discriminator();
24        if data.len() < disc.len() {
25            return Err(Error::from(ErrorCode::AccountDiscriminatorNotFound).into());
26        }
27        let given_disc = &data[..8];
28        if disc != given_disc {
29            return Err(Error::from(ErrorCode::AccountDiscriminatorMismatch).into());
30        }
31        let end = std::mem::size_of::<T>() + 8;
32        if data.len() < end {
33            return Err(Error::from(ErrorCode::AccountDidNotDeserialize).into());
34        }
35        let data_without_discriminator = data[8..end].to_vec();
36        Ok(*bytemuck::try_from_bytes(&data_without_discriminator)
37            .map_err(|err| crate::DecodeError::custom(format!("bytemuck: {err}")))?)
38    }
39}
40
/// Implement [`Decode`](crate::Decode) for [`ZeroCopy`](anchor_lang::ZeroCopy).
///
/// The generated impl decodes through
/// [`ZeroCopyVisitor`](crate::value::ZeroCopyVisitor). Every path in the
/// expansion is fully qualified. This lets the macro be invoked in modules that
/// shadow `Result` or `Default`, such as modules that glob-import
/// `anchor_lang::prelude::*`, whose `Result` alias takes a single type parameter.
///
/// Usage: `impl_decode_for_zero_copy!(MyAccount);` (a trailing comma is accepted).
#[macro_export]
macro_rules! impl_decode_for_zero_copy {
    ($decoded:ty $(,)?) => {
        impl $crate::Decode for $decoded {
            fn decode<D: $crate::Decoder>(
                decoder: D,
            ) -> ::core::result::Result<Self, $crate::DecodeError> {
                decoder.decode_bytes(
                    <$crate::value::ZeroCopyVisitor<$decoded> as ::core::default::Default>::default(),
                )
            }
        }
    };
}
52
/// Zero-sized [`Visitor`](crate::Visitor) that turns raw account bytes into any
/// [`AccountDeserialize`](anchor_lang::AccountDeserialize) type `T`.
pub struct AccountDeserializeVisitor<T>(std::marker::PhantomData<T>);
55
56impl<T> Default for AccountDeserializeVisitor<T> {
57    fn default() -> Self {
58        Self(Default::default())
59    }
60}
61
62impl<T> Visitor for AccountDeserializeVisitor<T>
63where
64    T: anchor_lang::AccountDeserialize,
65{
66    type Value = T;
67
68    fn visit_bytes(self, mut data: &[u8]) -> Result<Self::Value, crate::DecodeError> {
69        Ok(T::try_deserialize(&mut data)?)
70    }
71}
72
/// Implement [`Decode`](crate::Decode) for [`AccountDeserialize`](anchor_lang::AccountDeserialize).
///
/// The generated impl decodes through
/// [`AccountDeserializeVisitor`](crate::value::AccountDeserializeVisitor). Every
/// path in the expansion is fully qualified. This lets the macro be invoked in
/// modules that shadow `Result` or `Default`, such as modules that glob-import
/// `anchor_lang::prelude::*`, whose `Result` alias takes a single type parameter.
///
/// Usage: `impl_decode_for_account_deserialize!(MyAccount);` (a trailing comma is accepted).
#[macro_export]
macro_rules! impl_decode_for_account_deserialize {
    ($decoded:ty $(,)?) => {
        impl $crate::Decode for $decoded {
            fn decode<D: $crate::Decoder>(
                decoder: D,
            ) -> ::core::result::Result<Self, $crate::DecodeError> {
                decoder.decode_bytes(
                    <$crate::value::AccountDeserializeVisitor<$decoded> as ::core::default::Default>::default(),
                )
            }
        }
    };
}
85
/// Zero-sized [`Visitor`](crate::Visitor) that turns CPI event instruction data
/// into a CPI [`Event`](anchor_lang::Event) of type `T`.
pub struct CPIEventVisitor<T>(std::marker::PhantomData<T>);
88
89impl<T> Default for CPIEventVisitor<T> {
90    fn default() -> Self {
91        Self(Default::default())
92    }
93}
94
95impl<T> Visitor for CPIEventVisitor<T>
96where
97    T: anchor_lang::Event,
98{
99    type Value = T;
100
101    fn visit_bytes(self, data: &[u8]) -> Result<Self::Value, crate::DecodeError> {
102        use anchor_lang::{
103            event::EVENT_IX_TAG_LE,
104            prelude::{Error, ErrorCode},
105        };
106
107        // Valdiate the ix tag.
108        if data.len() < EVENT_IX_TAG_LE.len() {
109            return Err(Error::from(ErrorCode::InstructionDidNotDeserialize).into());
110        }
111        let given_tag = &data[..8];
112        if given_tag != EVENT_IX_TAG_LE {
113            return Err(crate::DecodeError::custom("not an anchor event ix"));
114        }
115
116        let data = &data[8..];
117
118        // Validate the discriminator.
119        let disc = T::discriminator();
120        if data.len() < disc.len() {
121            return Err(Error::from(ErrorCode::InstructionDidNotDeserialize).into());
122        }
123        let given_disc = &data[..8];
124        if disc != given_disc {
125            return Err(Error::from(ErrorCode::InstructionDidNotDeserialize).into());
126        }
127
128        // Deserialize.
129        Ok(T::try_from_slice(&data[8..]).map_err(anchor_lang::prelude::Error::from)?)
130    }
131}
132
/// Implement [`Decode`](crate::Decode) for CPI events.
///
/// The generated impl decodes through
/// [`CPIEventVisitor`](crate::value::CPIEventVisitor). Every path in the expansion
/// is fully qualified. This lets the macro be invoked in modules that shadow
/// `Result` or `Default`, such as modules that glob-import
/// `anchor_lang::prelude::*`, whose `Result` alias takes a single type parameter.
///
/// Usage: `impl_decode_for_cpi_event!(MyEvent);` (a trailing comma is accepted).
#[macro_export]
macro_rules! impl_decode_for_cpi_event {
    ($decoded:ty $(,)?) => {
        impl $crate::Decode for $decoded {
            fn decode<D: $crate::Decoder>(
                decoder: D,
            ) -> ::core::result::Result<Self, $crate::DecodeError> {
                decoder.decode_bytes(
                    <$crate::value::CPIEventVisitor<$decoded> as ::core::default::Default>::default(),
                )
            }
        }
    };
}