1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use crate::{Algorithm, OTPMethod};
use anyhow::Result;
use percent_encoding::percent_decode_str;
use std::str::FromStr;

/// A parsed `otpauth://` URI (Key Uri Format).
///
/// Built by the [`FromStr`] implementation in this file and turned back into a URI string
/// by the `String` conversion implemented further down.
///
/// NOTE(review): the derived `Debug` prints `secret` in plain text — avoid logging values
/// of this type.
#[derive(Debug, Clone)]
pub struct OTPUri {
    /// Hash algorithm from the `algorithm` query parameter; `Algorithm::default()` when the
    /// parameter is absent or not recognised.
    pub algorithm: Algorithm,
    /// Account name from the URI path (the part after a `Provider:` prefix, if any),
    /// percent-decoded.
    pub label: String,
    /// Shared secret from the mandatory `secret` query parameter (parsing fails without it).
    pub secret: String,
    /// Issuer: the `issuer` query parameter if present, otherwise the label's `Provider:`
    /// prefix, otherwise the literal `"Default"`.
    pub issuer: String,
    /// OTP type, parsed from the URI host via `OTPMethod::from_str`.
    pub method: OTPMethod,
    /// Code length from the `digits` query parameter; `None` if absent or not a valid `u32`.
    pub digits: Option<u32>,
    /// Time step from the `period` query parameter; `None` if absent or not a valid `u32`.
    /// When `None`, serialisation of non-HOTP URIs falls back to `crate::TOTP_DEFAULT_PERIOD`.
    pub period: Option<u32>,
    /// Counter from the `counter` query parameter; `None` if absent or not a valid `u32`.
    /// When `None`, serialisation of HOTP URIs falls back to `crate::HOTP_DEFAULT_COUNTER`.
    /// NOTE(review): RFC 4226 counters are 8 bytes; a `u32` cannot hold large counters —
    /// confirm this limit is acceptable.
    pub counter: Option<u32>,
}

impl FromStr for OTPUri {
    type Err = anyhow::Error;

    /// Parses a Key Uri Format string such as
    /// `otpauth://totp/Provider:alice@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Provider`.
    ///
    /// * The URI host selects the [`OTPMethod`].
    /// * The path is the label: either `Account` or `Provider:Account`. The separator may be a
    ///   literal `:` or percent-encoded `%3A`, and leading spaces before the account name are
    ///   ignored. Only the first colon separates provider from account.
    /// * Query parameters `secret`, `issuer`, `algorithm`, `digits`, `period` and `counter` are
    ///   recognised. Unknown parameters are ignored. Numeric or algorithm values that fail to
    ///   parse are treated as absent.
    /// * An `issuer` parameter takes precedence over the label prefix. When neither is present
    ///   the issuer is `"Default"`.
    ///
    /// # Errors
    ///
    /// Fails if the input is not a valid URL, the scheme is not `otpauth`, the OTP type (host)
    /// is missing or rejected by `OTPMethod::from_str`, the label is not valid UTF-8 after
    /// percent-decoding, or no `secret` parameter is present.
    fn from_str(uri: &str) -> Result<Self, Self::Err> {
        let url = url::Url::parse(uri)?;

        if url.scheme() != "otpauth" {
            anyhow::bail!("Invalid OTP uri format, expected otpauth");
        }

        // A URI like `otpauth:totp/...` (no `//`) parses without a host. That is user input,
        // so report it instead of panicking.
        let host = url
            .host_str()
            .ok_or_else(|| anyhow::anyhow!("Invalid OTP uri format, missing OTP type"))?;
        let method = OTPMethod::from_str(host)?;

        // `Url::path` is still percent-encoded. Decode the whole label before splitting so an
        // encoded separator (`%3A`) is recognised just like a literal `:`.
        let decoded_label =
            percent_decode_str(url.path().trim_start_matches('/')).decode_utf8()?;
        let (mut provider_name, label) = match decoded_label.split_once(':') {
            // The spec allows optional spaces between the separator and the account name.
            Some((provider, account)) => (
                Some(provider.to_string()),
                account.trim_start_matches(' ').to_string(),
            ),
            None => (None, decoded_label.to_string()),
        };

        let mut period = None;
        let mut counter = None;
        let mut digits = None;
        let mut algorithm = None;
        let mut secret = None;

        for (key, value) in url.query_pairs() {
            match &*key {
                "period" => period = value.parse::<u32>().ok(),
                "digits" => digits = value.parse::<u32>().ok(),
                "counter" => counter = value.parse::<u32>().ok(),
                // `query_pairs` has already percent-decoded the value. Decoding it again
                // would corrupt issuers that legitimately contain a `%`.
                "issuer" => provider_name = Some(value.into_owned()),
                "algorithm" => algorithm = Algorithm::from_str(&value).ok(),
                "secret" => secret = Some(value.into_owned()),
                _ => (),
            }
        }

        let secret = secret.ok_or_else(|| anyhow::anyhow!("OTP uri must contain a secret"))?;
        let issuer = provider_name.unwrap_or_else(|| "Default".to_string());

        Ok(Self {
            method,
            label,
            secret,
            issuer,
            algorithm: algorithm.unwrap_or_default(),
            digits,
            period,
            counter,
        })
    }
}

impl Into<String> for OTPUri {
    fn into(self) -> String {
        let mut otp_uri = format!(
            "otpauth://{}/{}?secret={}&issuer={}&algorithm={}",
            self.method.to_string(),
            self.label,
            self.secret,
            self.issuer,
            self.algorithm.to_string(),
        );
        if let Some(digits) = self.digits {
            otp_uri.push_str(&format!("&digits={}", digits));
        }
        if self.method == OTPMethod::HOTP {
            otp_uri.push_str(&format!(
                "&counter={}",
                self.counter.unwrap_or(crate::HOTP_DEFAULT_COUNTER)
            ));
        } else {
            otp_uri.push_str(&format!(
                "&period={}",
                self.period.unwrap_or(crate::TOTP_DEFAULT_PERIOD)
            ));
        }
        otp_uri
    }
}