// gmsol_utils/token_config.rs

1use std::collections::BTreeSet;
2
3use anchor_lang::prelude::*;
4
5use crate::{
6    chunk_by::chunk_by,
7    fixed_str::{bytes_to_fixed_str, FixedStrError},
8    market::HasMarketMeta,
9    oracle::PriceProviderKind,
10    pubkey::DEFAULT_PUBKEY,
11    swap::HasSwapParams,
12};
13
// Public defaults used when building token configs.

/// Default heartbeat duration for price updates
/// (unit not shown here — presumably seconds; confirm with consumers).
pub const DEFAULT_HEARTBEAT_DURATION: u32 = 30;

/// Default price precision.
pub const DEFAULT_PRECISION: u8 = 4;

/// Default timestamp adjustment of a price feed.
pub const DEFAULT_TIMESTAMP_ADJUSTMENT: u32 = 0;

// Internal capacity limits.

/// Size, in bytes, of the fixed name buffer of a token config.
const MAX_NAME_LEN: usize = 32;
/// Number of price feed slots per token config.
const MAX_FEEDS: usize = 4;
/// Maximum number of token config flags.
const MAX_FLAGS: usize = 8;
26
27/// Token config error.
28#[derive(Debug, thiserror::Error)]
29pub enum TokenConfigError {
30    /// Not found.
31    #[error("not found")]
32    NotFound,
33    /// Invalid provider index.
34    #[error("invalid provider index")]
35    InvalidProviderIndex,
36    /// Fixed str error.
37    #[error(transparent)]
38    FixedStr(#[from] FixedStrError),
39    /// Exceed max length limit.
40    #[error("exceed max length limit")]
41    ExceedMaxLengthLimit,
42}
43
44pub(crate) type TokenConfigResult<T> = std::result::Result<T, TokenConfigError>;
45
/// Flags stored in a [`TokenConfig`].
///
/// Each variant converts to its `u8` discriminant via `IntoPrimitive`. There are
/// no explicit discriminants, so the values are the declaration positions
/// (`Initialized` = 0, `Enabled` = 1, `Synthetic` = 2).
/// NOTE(review): the container generated by `crate::flags!` presumably keys on
/// these values. Appending variants looks safe, but reordering likely is not;
/// confirm against the macro definition.
#[derive(num_enum::IntoPrimitive)]
#[repr(u8)]
#[non_exhaustive]
pub enum TokenConfigFlag {
    /// Marks the config as initialized.
    Initialized,
    /// The token is enabled (read by `TokenConfig::is_enabled`).
    Enabled,
    /// The token is a synthetic asset (read by `TokenConfig::is_synthetic`).
    Synthetic,
    // CHECK: Cannot have more than `MAX_FLAGS` flags.
}

// Generates the flag container for `TokenConfigFlag`, parameterized by
// `MAX_FLAGS` and `u8`. This is presumably the `TokenConfigFlagContainer` type
// used by `TokenConfig::flags`; confirm in the macro definition.
crate::flags!(TokenConfigFlag, MAX_FLAGS, u8);
61
/// Configuration of a single token: name, flags, decimals, price precision,
/// expected price provider, per-provider price feeds and heartbeat duration.
///
/// This is a `#[zero_copy]` type and its `INIT_SPACE` is `size_of::<Self>()`,
/// so field order and sizes define its layout.
/// NOTE(review): changing the layout presumably breaks already-stored data;
/// confirm before editing.
#[zero_copy]
#[derive(PartialEq, Eq)]
#[cfg_attr(feature = "debug", derive(derive_more::Debug))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TokenConfig {
    /// Name, as a fixed-size byte buffer decoded by `TokenConfig::name`
    /// (via `bytes_to_fixed_str`).
    pub name: [u8; MAX_NAME_LEN],
    /// Flags, accessed through `TokenConfig::flag` / `TokenConfig::set_flag`.
    pub flags: TokenConfigFlagContainer,
    /// Token decimals.
    pub token_decimals: u8,
    /// Price precision.
    pub precision: u8,
    /// Expected provider, stored as the `u8` discriminant of a
    /// `PriceProviderKind` and decoded by `TokenConfig::expected_provider`.
    pub expected_provider: u8,
    /// Price feeds, indexed by `PriceProviderKind` discriminant. A slot whose
    /// feed equals `DEFAULT_PUBKEY` is treated as unset by `get_feed_config`.
    pub feeds: [FeedConfig; MAX_FEEDS],
    /// Heartbeat duration (unit not shown here; presumably seconds, confirm).
    pub heartbeat_duration: u32,
    /// Reserved bytes, hidden from debug output.
    #[cfg_attr(feature = "debug", debug(skip))]
    reserved: [u8; 32],
}
84
#[cfg(feature = "display")]
impl std::fmt::Display for TokenConfig {
    /// Writes one `Label: value` line per field. Undecodable names and
    /// unknown expected providers are shown as `*unknown*`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Name: {}", self.name().unwrap_or("*unknown*"))?;
        writeln!(f, "Enabled: {}", self.is_enabled())?;
        writeln!(f, "Synthetic: {}", self.is_synthetic())?;
        writeln!(f, "Decimals: {}", self.token_decimals)?;
        writeln!(f, "Precision: {}", self.precision)?;
        writeln!(f, "Heartbeat: {}", self.heartbeat_duration)?;
        // Format the provider directly instead of going through an intermediate
        // `String`. The fallback no longer allocates on every call either.
        match self.expected_provider() {
            Ok(kind) => writeln!(f, "Expected Provider: {}", kind),
            Err(_) => writeln!(f, "Expected Provider: *unknown*"),
        }
    }
}
104
105impl TokenConfig {
106    /// Get the corresponding price feed config.
107    pub fn get_feed_config(&self, kind: &PriceProviderKind) -> TokenConfigResult<&FeedConfig> {
108        let index = *kind as usize;
109        let config = self.feeds.get(index).ok_or(TokenConfigError::NotFound)?;
110        if config.feed == DEFAULT_PUBKEY {
111            Err(TokenConfigError::NotFound)
112        } else {
113            Ok(config)
114        }
115    }
116
117    /// Set feed config.
118    pub fn set_feed_config(
119        &mut self,
120        kind: &PriceProviderKind,
121        new_config: FeedConfig,
122    ) -> TokenConfigResult<()> {
123        let index = *kind as usize;
124        let config = self
125            .feeds
126            .get_mut(index)
127            .ok_or(TokenConfigError::InvalidProviderIndex)?;
128        *config = new_config;
129        Ok(())
130    }
131
132    /// Get the corresponding price feed address.
133    pub fn get_feed(&self, kind: &PriceProviderKind) -> TokenConfigResult<Pubkey> {
134        Ok(self.get_feed_config(kind)?.feed)
135    }
136
137    /// Set expected provider.
138    pub fn set_expected_provider(&mut self, provider: PriceProviderKind) {
139        self.expected_provider = provider as u8;
140    }
141
142    /// Get expected price provider kind.
143    pub fn expected_provider(&self) -> TokenConfigResult<PriceProviderKind> {
144        let kind = PriceProviderKind::try_from(self.expected_provider)
145            .map_err(|_| TokenConfigError::InvalidProviderIndex)?;
146        Ok(kind)
147    }
148
149    /// Get price feed address for the expected provider.
150    pub fn get_expected_feed(&self) -> TokenConfigResult<Pubkey> {
151        self.get_feed(&self.expected_provider()?)
152    }
153
154    /// Set enabled.
155    pub fn set_enabled(&mut self, enable: bool) {
156        self.set_flag(TokenConfigFlag::Enabled, enable)
157    }
158
159    /// Set synthetic.
160    pub fn set_synthetic(&mut self, is_synthetic: bool) {
161        self.set_flag(TokenConfigFlag::Synthetic, is_synthetic)
162    }
163
164    /// Is enabled.
165    pub fn is_enabled(&self) -> bool {
166        self.flag(TokenConfigFlag::Enabled)
167    }
168
169    /// Is synthetic.
170    pub fn is_synthetic(&self) -> bool {
171        self.flag(TokenConfigFlag::Synthetic)
172    }
173
174    /// Returns whether the config is a valid pool token config.
175    pub fn is_valid_pool_token_config(&self) -> bool {
176        !self.is_synthetic()
177    }
178
179    /// Set flag
180    pub fn set_flag(&mut self, flag: TokenConfigFlag, value: bool) {
181        self.flags.set_flag(flag, value);
182    }
183
184    /// Get flag.
185    pub fn flag(&self, flag: TokenConfigFlag) -> bool {
186        self.flags.get_flag(flag)
187    }
188
189    /// Token decimals.
190    pub fn token_decimals(&self) -> u8 {
191        self.token_decimals
192    }
193
194    /// Price Precision.
195    pub fn precision(&self) -> u8 {
196        self.precision
197    }
198
199    /// Get timestamp adjustment.
200    pub fn timestamp_adjustment(
201        &self,
202        price_provider: &PriceProviderKind,
203    ) -> TokenConfigResult<u32> {
204        Ok(self.get_feed_config(price_provider)?.timestamp_adjustment)
205    }
206
207    /// Heartbeat duration.
208    pub fn heartbeat_duration(&self) -> u32 {
209        self.heartbeat_duration
210    }
211
212    /// Get token name.
213    pub fn name(&self) -> TokenConfigResult<&str> {
214        Ok(bytes_to_fixed_str(&self.name)?)
215    }
216}
217
218impl crate::InitSpace for TokenConfig {
219    const INIT_SPACE: usize = std::mem::size_of::<Self>();
220}
221
/// Price Feed Config.
///
/// One slot of `TokenConfig::feeds`: a feed address plus a timestamp adjustment.
/// It is `#[zero_copy]`, so field order and sizes define its layout.
#[zero_copy]
#[derive(PartialEq, Eq)]
#[cfg_attr(feature = "debug", derive(derive_more::Debug))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FeedConfig {
    /// Feed address. `DEFAULT_PUBKEY` marks the slot as unset (see
    /// `TokenConfig::get_feed_config`). With serde it is (de)serialized via
    /// its `Display`/`FromStr` string form.
    #[cfg_attr(
        feature = "serde",
        serde(with = "serde_with::As::<serde_with::DisplayFromStr>")
    )]
    feed: Pubkey,
    /// Timestamp adjustment for this feed (unit not shown here; confirm with consumers).
    timestamp_adjustment: u32,
    /// Reserved bytes: hidden from debug output and (de)serialized as bytes with serde.
    #[cfg_attr(feature = "debug", debug(skip))]
    #[cfg_attr(feature = "serde", serde(with = "serde_bytes"))]
    reserved: [u8; 28],
}
238
#[cfg(feature = "display")]
impl std::fmt::Display for FeedConfig {
    /// Renders as `feed = <address>, timestamp_adjustment = <value>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            feed,
            timestamp_adjustment,
            ..
        } = self;
        write!(
            f,
            "feed = {}, timestamp_adjustment = {}",
            feed, timestamp_adjustment
        )
    }
}
249
250impl FeedConfig {
251    /// Create a new feed config.
252    pub fn new(feed: Pubkey) -> Self {
253        Self {
254            feed,
255            timestamp_adjustment: DEFAULT_TIMESTAMP_ADJUSTMENT,
256            reserved: Default::default(),
257        }
258    }
259
260    /// Change the timestamp adjustment.
261    pub fn with_timestamp_adjustment(mut self, timestamp_adjustment: u32) -> Self {
262        self.timestamp_adjustment = timestamp_adjustment;
263        self
264    }
265
266    /// Get feed.
267    pub fn feed(&self) -> &Pubkey {
268        &self.feed
269    }
270
271    /// Get timestamp adjustment.
272    pub fn timestamp_adjustment(&self) -> u32 {
273        self.timestamp_adjustment
274    }
275}
276
/// Parameters describing the contents of a [`TokenConfig`] update.
///
/// They can be snapshotted from an existing config via `From<&TokenConfig>`.
/// `feeds` and `timestamp_adjustments` are parallel vectors indexed by
/// `PriceProviderKind` discriminant (see `update_price_feed`).
#[derive(AnchorSerialize, AnchorDeserialize, Clone)]
#[cfg_attr(feature = "debug", derive(Debug))]
pub struct UpdateTokenConfigParams {
    /// Heartbeat duration.
    pub heartbeat_duration: u32,
    /// Price precision.
    pub precision: u8,
    /// Feed addresses, one per price provider slot.
    pub feeds: Vec<Pubkey>,
    /// Timestamp adjustments, parallel to `feeds`.
    pub timestamp_adjustments: Vec<u32>,
    /// Expected price provider as its `u8` discriminant, if specified.
    pub expected_provider: Option<u8>,
}
291
292impl Default for UpdateTokenConfigParams {
293    fn default() -> Self {
294        Self {
295            heartbeat_duration: DEFAULT_HEARTBEAT_DURATION,
296            precision: DEFAULT_PRECISION,
297            feeds: vec![DEFAULT_PUBKEY; MAX_FEEDS],
298            timestamp_adjustments: vec![DEFAULT_TIMESTAMP_ADJUSTMENT; MAX_FEEDS],
299            expected_provider: None,
300        }
301    }
302}
303
304impl<'a> From<&'a TokenConfig> for UpdateTokenConfigParams {
305    fn from(config: &'a TokenConfig) -> Self {
306        let (feeds, timestamp_adjustments) = config
307            .feeds
308            .iter()
309            .map(|config| (config.feed, config.timestamp_adjustment))
310            .unzip();
311
312        Self {
313            heartbeat_duration: config.heartbeat_duration(),
314            precision: config.precision(),
315            feeds,
316            timestamp_adjustments,
317            expected_provider: Some(config.expected_provider),
318        }
319    }
320}
321
322impl UpdateTokenConfigParams {
323    /// Update the feed address for the given price provider.
324    /// Return error when the feed was not set before.
325    pub fn update_price_feed(
326        mut self,
327        kind: &PriceProviderKind,
328        new_feed: Pubkey,
329        new_timestamp_adjustment: Option<u32>,
330    ) -> TokenConfigResult<Self> {
331        let index = *kind as usize;
332        let feed = self
333            .feeds
334            .get_mut(index)
335            .ok_or(TokenConfigError::NotFound)?;
336        let timestamp_adjustment = self
337            .timestamp_adjustments
338            .get_mut(index)
339            .ok_or(TokenConfigError::NotFound)?;
340        *feed = new_feed;
341        if let Some(new_timestamp_adjustment) = new_timestamp_adjustment {
342            *timestamp_adjustment = new_timestamp_adjustment;
343        }
344        Ok(self)
345    }
346
347    /// Set heartbeat duration.
348    pub fn with_heartbeat_duration(mut self, duration: u32) -> Self {
349        self.heartbeat_duration = duration;
350        self
351    }
352
353    /// Set precision.
354    pub fn with_precision(mut self, precision: u8) -> Self {
355        self.precision = precision;
356        self
357    }
358
359    /// Set expected provider.
360    pub fn with_expected_provider(mut self, provider: PriceProviderKind) -> Self {
361        self.expected_provider = Some(provider as u8);
362        self
363    }
364}
365
366/// Read Token Map.
367pub trait TokenMapAccess {
368    /// Get the config of the given token.
369    fn get(&self, token: &Pubkey) -> Option<&TokenConfig>;
370
371    /// Get token configs for the given market.
372    ///
373    /// Returns the token configs for `index_token`, `long_token` and `short_token`.
374    fn token_configs_for_market(&self, market: &impl HasMarketMeta) -> Option<[&TokenConfig; 3]> {
375        let meta = market.market_meta();
376        let index_token = self.get(&meta.index_token_mint)?;
377        let long_token = self.get(&meta.long_token_mint)?;
378        let short_token = self.get(&meta.short_token_mint)?;
379        Some([index_token, long_token, short_token])
380    }
381
382    /// Sort tokens by provider. This sort is stable.
383    fn sort_tokens_by_provider(&self, tokens: &mut [Pubkey]) -> Result<()> {
384        // Check the existence of token configs.
385        for token in tokens.iter() {
386            require!(self.get(token).is_some(), ErrorCode::RequireViolated);
387        }
388        tokens.sort_by_cached_key(|token| self.get(token).unwrap().expected_provider);
389        Ok(())
390    }
391}
392
/// Tokens with feed.
///
/// The entries are grouped by provider: `providers[i]` covers the next
/// `nums[i]` consecutive entries of `tokens`/`feeds` (this is how
/// `TokensWithFeed::try_from_records` builds it).
#[derive(AnchorSerialize, AnchorDeserialize, Clone)]
#[cfg_attr(feature = "debug", derive(Debug))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TokensWithFeed {
    /// Tokens that require prices; must have the same length as `feeds`.
    pub tokens: Vec<Pubkey>,
    /// Feed address for each entry of `tokens`; must have the same length as `tokens`.
    pub feeds: Vec<Pubkey>,
    /// Distinct providers (as `u8` discriminants), one per group;
    /// must have the same length as `nums`.
    pub providers: Vec<u8>,
    /// Number of tokens belonging to each provider in `providers`.
    pub nums: Vec<u16>,
}
410
/// A record of token config: a token, the feed used to price it, and that
/// feed's provider.
#[derive(AnchorSerialize, AnchorDeserialize, Clone)]
#[cfg_attr(feature = "debug", derive(Debug))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TokenRecord {
    /// Token address.
    token: Pubkey,
    /// Price feed address for `token`.
    feed: Pubkey,
    /// Price provider, stored as its `u8` discriminant.
    provider: u8,
}
420
421impl TokenRecord {
422    /// Create a new [`TokenRecord`]
423    pub fn new(token: Pubkey, feed: Pubkey, provider: PriceProviderKind) -> Self {
424        Self {
425            token,
426            feed,
427            provider: provider as u8,
428        }
429    }
430
431    /// Create a new [`TokenRecord`] from token config,
432    /// using the expected provider and feed.
433    pub fn from_config(token: Pubkey, config: &TokenConfig) -> TokenConfigResult<Self> {
434        Ok(Self::new(
435            token,
436            config.get_expected_feed()?,
437            config.expected_provider()?,
438        ))
439    }
440}
441
442impl TokensWithFeed {
443    /// Create from token records.
444    /// # Panic
445    /// Panics if the number of tokens of the same provider exceeds `u16`.
446    pub fn try_from_records(mut records: Vec<TokenRecord>) -> TokenConfigResult<Self> {
447        records.sort_by_cached_key(|r| r.provider);
448        let mut chunks = chunk_by(&records, |a, b| a.provider == b.provider);
449        let capacity = chunks.size_hint().0;
450        let mut providers = Vec::with_capacity(capacity);
451        let mut nums = Vec::with_capacity(capacity);
452        chunks.try_for_each(|chunk| {
453            providers.push(chunk[0].provider);
454            nums.push(
455                u16::try_from(chunk.len()).map_err(|_| TokenConfigError::ExceedMaxLengthLimit)?,
456            );
457            TokenConfigResult::Ok(())
458        })?;
459        Ok(Self {
460            tokens: records.iter().map(|r| r.token).collect(),
461            feeds: records.iter().map(|r| r.feed).collect(),
462            providers,
463            nums,
464        })
465    }
466}
467
468/// Collect token records for the give tokens.
469pub fn token_records<A: TokenMapAccess>(
470    token_map: &A,
471    tokens: &BTreeSet<Pubkey>,
472) -> TokenConfigResult<Vec<TokenRecord>> {
473    tokens
474        .iter()
475        .map(|token| {
476            let config = token_map.get(token).ok_or(TokenConfigError::NotFound)?;
477            TokenRecord::from_config(*token, config)
478        })
479        .collect::<TokenConfigResult<Vec<_>>>()
480}
481
/// Tokens Collector.
///
/// Accumulates token addresses in a `Vec` that `insert_token` maintains in
/// ascending order (it relies on binary search).
pub struct TokensCollector {
    /// Collected tokens. Expected to stay sorted in ascending order, as
    /// required by `insert_token`'s binary search.
    tokens: Vec<Pubkey>,
}
486
487impl TokensCollector {
488    /// Create a new [`TokensCollector`].
489    pub fn new(action: Option<&impl HasSwapParams>, extra_capacity: usize) -> Self {
490        let mut tokens;
491        match action {
492            Some(action) => {
493                let swap = action.swap();
494                tokens = Vec::with_capacity(swap.num_tokens() + extra_capacity);
495                // The tokens in the swap params must be sorted.
496                tokens.extend_from_slice(swap.tokens());
497            }
498            None => tokens = Vec::with_capacity(extra_capacity),
499        }
500
501        Self { tokens }
502    }
503
504    /// Insert a new token.
505    pub fn insert_token(&mut self, token: &Pubkey) -> bool {
506        match self.tokens.binary_search(token) {
507            Ok(_) => false,
508            Err(idx) => {
509                self.tokens.insert(idx, *token);
510                true
511            }
512        }
513    }
514
515    /// Convert to a vec.
516    pub fn into_vec(mut self, token_map: &impl TokenMapAccess) -> TokenConfigResult<Vec<Pubkey>> {
517        token_map
518            .sort_tokens_by_provider(&mut self.tokens)
519            .map_err(|_| TokenConfigError::NotFound)?;
520        Ok(self.tokens)
521    }
522
523    /// Convert to [`TokensWithFeed`].
524    pub fn to_feeds(&self, token_map: &impl TokenMapAccess) -> TokenConfigResult<TokensWithFeed> {
525        let records = self
526            .tokens
527            .iter()
528            .map(|token| {
529                let config = token_map.get(token).ok_or(TokenConfigError::NotFound)?;
530                TokenRecord::from_config(*token, config)
531            })
532            .collect::<TokenConfigResult<Vec<_>>>()?;
533        TokensWithFeed::try_from_records(records)
534    }
535}
536
/// Max number of treasury token flags.
#[cfg(feature = "treasury")]
pub const MAX_TREASURY_TOKEN_FLAGS: usize = 8;

/// Treasury token flags.
///
/// Converts to `u8` via `IntoPrimitive`. There are no explicit discriminants,
/// so `AllowDeposit` = 0 and `AllowWithdrawal` = 1; keep the declaration order
/// stable. Converts to and from snake_case strings via strum
/// (e.g. `"allow_deposit"`).
#[cfg(feature = "treasury")]
#[derive(
    num_enum::IntoPrimitive, Clone, Copy, strum::EnumString, strum::Display, PartialEq, Eq,
)]
#[strum(serialize_all = "snake_case")]
#[cfg_attr(feature = "enum-iter", derive(strum::EnumIter))]
#[repr(u8)]
pub enum TokenFlag {
    /// Allow deposit.
    AllowDeposit,
    /// Allow withdrawal.
    AllowWithdrawal,
    // CHECK: cannot have more than `MAX_TREASURY_TOKEN_FLAGS` flags.
}