work
authorBen Pfaff <blp@cs.stanford.edu>
Wed, 13 Aug 2025 15:55:11 +0000 (08:55 -0700)
committerBen Pfaff <blp@cs.stanford.edu>
Wed, 13 Aug 2025 15:55:11 +0000 (08:55 -0700)
rust/pspp/src/main.rs
rust/pspp/src/output/driver.rs

index 2c7287e95d2264851a2b9cb6df2ad9362be8568a..dacee9fcee9ecfc354d5081f632f621e208716ae 100644 (file)
@@ -20,7 +20,7 @@ use encoding_rs::Encoding;
 use pspp::{
     crypto::EncryptedFile,
     output::{
-        driver::{Config, Driver, DriverType},
+        driver::{Config, Driver},
         Details, Item, Text,
     },
     sys::{
@@ -38,7 +38,6 @@ use std::{
     io::{stdout, BufReader, Write},
     path::{Path, PathBuf},
     rc::Rc,
-    str::{self, FromStr},
     sync::Arc,
 };
 use thiserror::Error as ThisError;
@@ -255,7 +254,7 @@ struct Show {
     mode: Mode,
 
     /// Output format.
-    #[arg(long, value_parser = ShowFormat::from_str)]
+    #[arg(long)]
     format: Option<ShowFormat>,
 
     /// The encoding to use.
@@ -321,62 +320,72 @@ impl Show {
         let format = if let Some(format) = self.format {
             format
         } else if let Some(output_file) = &self.output_file {
-            ShowFormat::from_str(
-                output_file
-                    .extension()
-                    .unwrap_or(OsStr::new(""))
-                    .to_str()
-                    .unwrap_or(""),
-            )
-            .map_err(|_| {
-                anyhow!(
-                    "{}: no default output format for file name",
-                    output_file.display()
-                )
-            })?
+            match output_file
+                .extension()
+                .unwrap_or(OsStr::new(""))
+                .to_str()
+                .unwrap_or("")
+            {
+                "json" => ShowFormat::Json,
+                "ndjson" => ShowFormat::Ndjson,
+                _ => ShowFormat::Output,
+            }
         } else {
             ShowFormat::Json
         };
 
-        let output = if let ShowFormat::Output(driver) = format {
-            let mut config = String::new();
-
-            #[derive(Serialize)]
-            struct DriverConfig {
-                driver: DriverType,
-            }
-            config.push_str(&toml::to_string_pretty(&DriverConfig { driver }).unwrap());
+        let output = match format {
+            ShowFormat::Output => {
+                let mut config = String::new();
 
-            if let Some(file) = &self.output_file {
-                #[derive(Serialize)]
-                struct File<'a> {
-                    file: &'a Path,
+                if let Some(file) = &self.output_file {
+                    #[derive(Serialize)]
+                    struct File<'a> {
+                        file: &'a Path,
+                    }
+                    let file = File {
+                        file: file.as_path(),
+                    };
+                    let toml_file = toml::to_string_pretty(&file).unwrap();
+                    config.push_str(&toml_file);
                 }
-                let file = File {
-                    file: file.as_path(),
-                };
-                let toml_file = toml::to_string_pretty(&file).unwrap();
-                config.push_str(&toml_file);
-
                 for option in &self.output_options {
                     writeln!(&mut config, "{option}").unwrap();
                 }
-            }
 
-            let config: Config = toml::from_str(&config)?;
-            Output::Driver {
-                mode: self.mode,
-                driver: Rc::new(RefCell::new(Box::new(<dyn Driver>::new(&config)?))),
+                let table: toml::Table = toml::from_str(&config)?;
+                if !table.contains_key("driver")
+                    && let Some(file) = &self.output_file
+                {
+                    let driver =
+                        <dyn Driver>::driver_type_from_filename(file).ok_or_else(|| {
+                            anyhow!("{}: no default output format for file name", file.display())
+                        })?;
+
+                    #[derive(Serialize)]
+                    struct DriverConfig {
+                        driver: &'static str,
+                    }
+                    config.insert_str(
+                        0,
+                        &toml::to_string_pretty(&DriverConfig { driver }).unwrap(),
+                    );
+                }
+
+                let config: Config = toml::from_str(&config)?;
+                Output::Driver {
+                    mode: self.mode,
+                    driver: Rc::new(RefCell::new(Box::new(<dyn Driver>::new(&config)?))),
+                }
             }
-        } else {
-            Output::Json {
+            ShowFormat::Json | ShowFormat::Ndjson => Output::Json {
                 pretty: format == ShowFormat::Json,
                 writer: if let Some(output_file) = &self.output_file {
                     Rc::new(RefCell::new(Box::new(File::create(output_file)?)))
                 } else {
                     Rc::new(RefCell::new(Box::new(stdout())))
                 },
-            }
+            },
         };
 
         let reader = File::open(&self.input_file)?;
@@ -512,7 +521,7 @@ impl Display for Mode {
     }
 }
 
-#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize)]
+#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, ValueEnum)]
 #[serde(rename_all = "snake_case")]
 enum ShowFormat {
     /// Pretty-printed JSON.
@@ -520,27 +529,7 @@ enum ShowFormat {
     Json,
     /// Newline-delimited JSON.
     Ndjson,
-    Output(DriverType),
-}
-
-#[derive(ThisError, Debug)]
-#[error("{0}: unknown format")]
-struct UnknownFormat(String);
-
-impl FromStr for ShowFormat {
-    type Err = UnknownFormat;
-
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        if s.eq_ignore_ascii_case("json") {
-            Ok(Self::Json)
-        } else if s.eq_ignore_ascii_case("ndjson") {
-            Ok(Self::Ndjson)
-        } else if let Ok(driver_type) = DriverType::from_str(s, true) {
-            Ok(Self::Output(driver_type))
-        } else {
-            Err(UnknownFormat(String::from(s)))
-        }
-    }
+    Output,
 }
 
 fn main() -> Result<()> {
index 9b272f6a05ceb5bdca383ca43ea19594b304131a..0d722e1f630882bb4acf4e71764cbacd647c918f 100644 (file)
@@ -14,7 +14,7 @@
 // You should have received a copy of the GNU General Public License along with
 // this program.  If not, see <http://www.gnu.org/licenses/>.
 
-use std::{borrow::Cow, sync::Arc};
+use std::{borrow::Cow, path::Path, sync::Arc};
 
 use clap::ValueEnum;
 use serde::{Deserialize, Serialize};
@@ -120,6 +120,17 @@ impl dyn Driver {
             Config::Spv(spv_config) => Ok(Box::new(SpvDriver::new(spv_config)?)),
         }
     }
+
+    pub fn driver_type_from_filename(file: impl AsRef<Path>) -> Option<&'static str> {
+        match file.as_ref().extension()?.to_str()? {
+            "txt" | "text" => Some("text"),
+            "pdf" => Some("pdf"),
+            "htm" | "html" => Some("html"),
+            "csv" => Some("csv"),
+            "spv" => Some("spv"),
+            _ => None,
+        }
+    }
 }
 
 #[cfg(test)]