Skip to main content

nano_io/
samples.rs

1//! Sample-table parsing and per-process histogram production.
2//!
3//! This layer closes the production loop between an analysis spec and the
4//! multi-process datacard emitter: each sample is interpreted, MC samples are
5//! scaled by `xsec*lumi/sumw`, samples sharing a process are summed, and data
6//! samples are accumulated into `data_obs`.
7
8use std::collections::{BTreeMap, BTreeSet};
9use std::fs;
10use std::path::{Path, PathBuf};
11
12use nano_analysis::{Fb, FbInv, Hist1D, HistSet1D, Pb, PbInv};
13use nano_core::Event;
14use nano_spec::interpret::{
15    interpret_and_fill, interpret_and_fill_systematic, InterpretedHistograms,
16};
17use nano_spec::ResolvedPlan;
18
19use crate::datacard::{DatacardOutput, MultiProcessChannel, MultiProcessDatacard, Process};
20use crate::{events, Result, RootError};
21
/// Name of the unvaried (nominal) variation inside a histogram set.
///
/// Datacard shapes are read from this variation, and per-sample `selected`
/// counts only count fills of this variation when shape corrections are on.
const NOMINAL_VARIATION: &str = "Nominal";
23
24/// A parsed sample table with one integrated luminosity and many samples.
25#[derive(Debug, Clone, PartialEq)]
26pub struct SampleTable {
27    lumi: IntegratedLuminosity,
28    samples: Vec<Sample>,
29}
30
31impl SampleTable {
32    /// Parse and validate a TOML sample table.
33    pub fn from_toml_str(input: &str) -> Result<Self> {
34        let raw: RawSampleTable = toml::from_str(input)
35            .map_err(|error| RootError::parse(format!("failed to parse sample TOML: {error}")))?;
36        sample_table_from_raw(raw)
37    }
38
39    /// Load a TOML sample table from disk.
40    pub fn from_path(path: impl AsRef<Path>) -> Result<Self> {
41        let path = path.as_ref();
42        let input = fs::read_to_string(path)?;
43        Self::from_toml_str(&input)
44    }
45
46    pub fn lumi(&self) -> IntegratedLuminosity {
47        self.lumi
48    }
49
50    pub fn samples(&self) -> &[Sample] {
51        &self.samples
52    }
53
54    fn signal_processes(&self) -> BTreeSet<String> {
55        self.samples
56            .iter()
57            .filter(|sample| matches!(sample.kind, SampleKind::Mc { signal: true, .. }))
58            .map(|sample| sample.process.clone())
59            .collect()
60    }
61}
62
63/// One validated sample row.
64#[derive(Debug, Clone, PartialEq)]
65pub struct Sample {
66    process: String,
67    files: Vec<PathBuf>,
68    kind: SampleKind,
69}
70
71impl Sample {
72    pub fn process(&self) -> &str {
73        &self.process
74    }
75
76    pub fn files(&self) -> &[PathBuf] {
77        &self.files
78    }
79
80    pub fn is_data(&self) -> bool {
81        matches!(self.kind, SampleKind::Data)
82    }
83
84    pub fn is_signal(&self) -> bool {
85        matches!(self.kind, SampleKind::Mc { signal: true, .. })
86    }
87
88    /// Per-event normalization for this sample.
89    ///
90    /// MC uses the same formula as `higgs4l_stack_opendata.rs`:
91    /// `luminosity_pb * xsec_pb / sumw`. Data samples return `1.0`.
92    pub fn normalization_factor(&self, lumi: IntegratedLuminosity) -> Result<f64> {
93        match self.kind {
94            SampleKind::Data => Ok(1.0),
95            SampleKind::Mc { xsec, sumw, .. } => {
96                mc_normalization_factor_pb(xsec.to_pb(), lumi.to_pb_inv(), sumw)
97            }
98        }
99    }
100}
101
/// Whether a sample row is data or normalized Monte Carlo.
#[derive(Debug, Clone, Copy, PartialEq)]
enum SampleKind {
    /// Data rows are summed unscaled into `data_obs`.
    Data,
    /// MC rows carry the inputs of the `xsec * lumi / sumw` normalization.
    Mc {
        /// `signal = true` in the table; the process becomes the datacard signal.
        signal: bool,
        /// Cross-section in the unit written in the table (pb or fb).
        xsec: CrossSection,
        /// `sumw` from the table (presumably the sum of generator weights);
        /// validated to be finite and > 0 when the row is parsed.
        sumw: f64,
    },
}
111
112/// A cross-section parsed from the sample table.
113#[derive(Debug, Clone, Copy, PartialEq)]
114pub enum CrossSection {
115    Fb(Fb),
116    Pb(Pb),
117}
118
119impl CrossSection {
120    pub fn to_pb(self) -> Pb {
121        match self {
122            Self::Fb(value) => value.to_pb(),
123            Self::Pb(value) => value,
124        }
125    }
126
127    pub fn to_fb(self) -> Fb {
128        match self {
129            Self::Fb(value) => value,
130            Self::Pb(value) => value.to_fb(),
131        }
132    }
133}
134
135/// An integrated luminosity parsed from the sample table.
136#[derive(Debug, Clone, Copy, PartialEq)]
137pub enum IntegratedLuminosity {
138    FbInv(FbInv),
139    PbInv(PbInv),
140}
141
142impl IntegratedLuminosity {
143    pub fn to_pb_inv(self) -> PbInv {
144        match self {
145            Self::FbInv(value) => value.to_pb_inv(),
146            Self::PbInv(value) => value,
147        }
148    }
149
150    pub fn to_fb_inv(self) -> FbInv {
151        match self {
152            Self::FbInv(value) => value,
153            Self::PbInv(value) => value.to_fb_inv(),
154        }
155    }
156}
157
158/// MC normalization in pb/pb^-1 units: `xsec * lumi / sumw`.
159pub fn mc_normalization_factor_pb(xsec: Pb, lumi: PbInv, sumw: f64) -> Result<f64> {
160    validate_sumw(sumw, "normalization")?;
161    if !xsec.0.is_finite() || xsec.0 < 0.0 {
162        return Err(RootError::other(
163            "normalization cross-section must be finite and non-negative",
164        ));
165    }
166    if !lumi.0.is_finite() || lumi.0 <= 0.0 {
167        return Err(RootError::other(
168            "normalization luminosity must be finite and positive",
169        ));
170    }
171    Ok((xsec * lumi) / sumw)
172}
173
174/// MC normalization in fb/fb^-1 units: `xsec * lumi / sumw`.
175pub fn mc_normalization_factor_fb(xsec: Fb, lumi: FbInv, sumw: f64) -> Result<f64> {
176    validate_sumw(sumw, "normalization")?;
177    if !xsec.0.is_finite() || xsec.0 < 0.0 {
178        return Err(RootError::other(
179            "normalization cross-section must be finite and non-negative",
180        ));
181    }
182    if !lumi.0.is_finite() || lumi.0 <= 0.0 {
183        return Err(RootError::other(
184            "normalization luminosity must be finite and positive",
185        ));
186    }
187    Ok((xsec * lumi) / sumw)
188}
189
190/// Run a validated analysis plan over all ROOT files in a sample table.
191pub fn run_interpreted_samples(
192    table: &SampleTable,
193    plan: &ResolvedPlan,
194) -> Result<NormalizedProcessHistograms> {
195    run_interpreted_samples_with_events(table, plan, |path| events(path, &plan.read_branches))
196}
197
198/// Run a validated analysis plan using a caller-supplied event source.
199///
200/// Tests can supply synthetic events while production callers use
201/// [`run_interpreted_samples`] to stream ROOT files.
202pub fn run_interpreted_samples_with_events<F, I>(
203    table: &SampleTable,
204    plan: &ResolvedPlan,
205    mut events_for_file: F,
206) -> Result<NormalizedProcessHistograms>
207where
208    F: FnMut(&Path) -> Result<I>,
209    I: IntoIterator<Item = Result<Event>>,
210{
211    let mut output = NormalizedProcessHistograms::new(table.signal_processes());
212
213    for (sample_index, sample) in table.samples.iter().enumerate() {
214        let factor = sample.normalization_factor(table.lumi)?;
215        let mut histograms = InterpretedHistograms::new(plan);
216        let systematic_variations = systematic_variations(&histograms);
217        let mut events_read = 0_usize;
218        let mut selected = 0_usize;
219
220        for file in &sample.files {
221            for event in events_for_file(file)? {
222                let event = event?;
223                events_read += 1;
224                if plan.spec.has_shape_correction() {
225                    for systematic in &systematic_variations {
226                        let row = interpret_and_fill_systematic(
227                            plan,
228                            &event,
229                            &mut histograms,
230                            systematic,
231                        )
232                        .map_err(|error| RootError::other(error.to_string()))?;
233                        if systematic == NOMINAL_VARIATION && row.is_some() {
234                            selected += 1;
235                        }
236                    }
237                } else if interpret_and_fill(plan, &event, &mut histograms)
238                    .map_err(|error| RootError::other(error.to_string()))?
239                    .is_some()
240                {
241                    selected += 1;
242                }
243            }
244        }
245
246        if !sample.is_data() {
247            histograms.scale(factor);
248        }
249        output.accumulate_sample(sample, &histograms);
250        output.sample_reports.push(SampleRunReport {
251            sample_index,
252            process: sample.process.clone(),
253            data: sample.is_data(),
254            signal: sample.is_signal(),
255            normalization_factor: factor,
256            events_read,
257            selected,
258        });
259    }
260
261    Ok(output)
262}
263
264/// Per-process, normalized histogram output from a sample-table run.
265#[derive(Debug, Clone, PartialEq)]
266pub struct NormalizedProcessHistograms {
267    signal_processes: BTreeSet<String>,
268    processes: BTreeMap<String, BTreeMap<String, HistSet1D<String>>>,
269    data_obs: BTreeMap<String, HistSet1D<String>>,
270    sample_reports: Vec<SampleRunReport>,
271}
272
273impl NormalizedProcessHistograms {
274    fn new(signal_processes: BTreeSet<String>) -> Self {
275        Self {
276            signal_processes,
277            processes: BTreeMap::new(),
278            data_obs: BTreeMap::new(),
279            sample_reports: Vec::new(),
280        }
281    }
282
283    pub fn processes(&self) -> &BTreeMap<String, BTreeMap<String, HistSet1D<String>>> {
284        &self.processes
285    }
286
287    pub fn data_obs(&self) -> &BTreeMap<String, HistSet1D<String>> {
288        &self.data_obs
289    }
290
291    pub fn sample_reports(&self) -> &[SampleRunReport] {
292        &self.sample_reports
293    }
294
295    /// Build a Combine datacard from the normalized nominal histograms.
296    pub fn to_datacard(&self) -> Result<MultiProcessDatacard<'_>> {
297        if self.signal_processes.len() != 1 {
298            return Err(RootError::other(format!(
299                "multi-process datacard needs exactly one signal process, found {}",
300                self.signal_processes.len()
301            )));
302        }
303        let signal_process = self
304            .signal_processes
305            .iter()
306            .next()
307            .expect("checked exactly one signal process");
308
309        let mut datacard = MultiProcessDatacard::new();
310        for (channel, data_set) in &self.data_obs {
311            let data_nominal = nominal_histogram(data_set, channel)?;
312            let mut combine_channel = MultiProcessChannel::new(channel, data_nominal);
313            let mut background_index = 1_i32;
314
315            for (process_name, histograms_by_channel) in &self.processes {
316                let Some(set) = histograms_by_channel.get(channel) else {
317                    continue;
318                };
319                let index = if process_name == signal_process {
320                    0
321                } else {
322                    let index = background_index;
323                    background_index += 1;
324                    index
325                };
326                let mut process =
327                    Process::new(process_name, index, nominal_histogram(set, channel)?);
328                for systematic in paired_shape_systematics(set) {
329                    let up = set.get(format!("{systematic}Up"));
330                    let down = set.get(format!("{systematic}Down"));
331                    process = process.with_shape_systematic(systematic, up, down);
332                }
333                combine_channel = combine_channel.with_process(process);
334            }
335
336            datacard = datacard.with_channel(combine_channel);
337        }
338
339        Ok(datacard)
340    }
341
342    /// Write `datacard.txt` and `shapes.root` for the normalized output.
343    pub fn write_datacard(&self, output_dir: &Path) -> Result<DatacardOutput> {
344        self.to_datacard()?.write(output_dir)
345    }
346
347    fn accumulate_sample(&mut self, sample: &Sample, histograms: &InterpretedHistograms) {
348        let target = if sample.is_data() {
349            &mut self.data_obs
350        } else {
351            self.processes.entry(sample.process.clone()).or_default()
352        };
353        for (name, set) in histograms.iter() {
354            target
355                .entry(name.clone())
356                .and_modify(|existing| existing.add(set))
357                .or_insert_with(|| set.clone());
358        }
359    }
360}
361
/// Summary for one processed sample row.
#[derive(Debug, Clone, PartialEq)]
pub struct SampleRunReport {
    /// Zero-based position of the row in the sample table.
    pub sample_index: usize,
    /// Process name of the row.
    pub process: String,
    /// Whether the row is a data sample.
    pub data: bool,
    /// Whether the row is an MC sample flagged `signal = true`.
    pub signal: bool,
    /// Factor from `Sample::normalization_factor`: `1.0` for data. MC
    /// histograms were scaled by this value before merging.
    pub normalization_factor: f64,
    /// Number of events read across all files of the sample.
    pub events_read: usize,
    /// Number of events the interpreter selected. With shape corrections,
    /// only fills of the `Nominal` variation are counted.
    pub selected: usize,
}
373
/// Serde mirror of the sample-table TOML before validation.
#[derive(Debug, serde::Deserialize)]
struct RawSampleTable {
    /// Integrated luminosity as `"<number> <unit>"`, e.g. `"10 fb^-1"`.
    lumi: String,
    /// `[[sample]]` rows. A missing key deserializes to an empty list, which
    /// `sample_table_from_raw` then rejects.
    #[serde(default)]
    sample: Vec<RawSample>,
}

/// Serde mirror of one `[[sample]]` row before validation.
#[derive(Debug, serde::Deserialize)]
struct RawSample {
    /// Process label; validated by `validate_label`.
    process: String,
    /// Marks the process as the datacard signal (MC only).
    #[serde(default)]
    signal: bool,
    /// Explicitly marks the row as data. A row without `xsec` is also treated as data.
    #[serde(default)]
    data: bool,
    /// Input files; must be non-empty.
    files: Vec<PathBuf>,
    /// Cross-section as `"<number> <unit>"` (pb or fb); presence marks the row as MC.
    xsec: Option<String>,
    /// Normalization denominator; required for MC, forbidden for data.
    sumw: Option<f64>,
}
392
393fn sample_table_from_raw(raw: RawSampleTable) -> Result<SampleTable> {
394    let lumi = parse_luminosity(&raw.lumi, "lumi")?;
395    if raw.sample.is_empty() {
396        return Err(RootError::other(
397            "sample table must contain at least one [[sample]] row",
398        ));
399    }
400    let samples = raw
401        .sample
402        .into_iter()
403        .enumerate()
404        .map(|(index, sample)| sample_from_raw(index, sample))
405        .collect::<Result<Vec<_>>>()?;
406
407    if !samples
408        .iter()
409        .any(|sample| matches!(sample.kind, SampleKind::Mc { signal: true, .. }))
410    {
411        return Err(RootError::other(
412            "sample table must contain at least one MC sample with signal=true",
413        ));
414    }
415
416    Ok(SampleTable { lumi, samples })
417}
418
419fn sample_from_raw(index: usize, raw: RawSample) -> Result<Sample> {
420    validate_label("process", &raw.process)?;
421    if raw.files.is_empty() {
422        return Err(RootError::other(format!(
423            "sample {index} process `{}` must list at least one file",
424            raw.process
425        )));
426    }
427
428    let kind = if raw.data || raw.xsec.is_none() {
429        if raw.data && raw.xsec.is_some() {
430            return Err(RootError::other(format!(
431                "data sample `{}` must not set xsec",
432                raw.process
433            )));
434        }
435        if raw.signal {
436            return Err(RootError::other(format!(
437                "data sample `{}` cannot set signal=true",
438                raw.process
439            )));
440        }
441        if raw.sumw.is_some() {
442            return Err(RootError::other(format!(
443                "data sample `{}` must not set sumw",
444                raw.process
445            )));
446        }
447        SampleKind::Data
448    } else {
449        let xsec = parse_cross_section(raw.xsec.as_deref().expect("checked xsec"), "xsec")?;
450        let sumw = raw.sumw.ok_or_else(|| {
451            RootError::other(format!(
452                "MC sample `{}` must set sumw with xsec",
453                raw.process
454            ))
455        })?;
456        validate_sumw(sumw, &format!("sample `{}`", raw.process))?;
457        SampleKind::Mc {
458            signal: raw.signal,
459            xsec,
460            sumw,
461        }
462    };
463
464    Ok(Sample {
465        process: raw.process,
466        files: raw.files,
467        kind,
468    })
469}
470
471fn parse_cross_section(input: &str, field: &str) -> Result<CrossSection> {
472    let (value, unit) = parse_quantity(input, field)?;
473    if !value.is_finite() || value < 0.0 {
474        return Err(RootError::other(format!(
475            "{field} must be finite and non-negative"
476        )));
477    }
478    match unit {
479        "Pb" | "pb" => Ok(CrossSection::Pb(Pb(value))),
480        "Fb" | "fb" => Ok(CrossSection::Fb(Fb(value))),
481        "PbInv" | "pb^-1" | "pb-1" | "1/pb" | "FbInv" | "fb^-1" | "fb-1" | "1/fb" => {
482            Err(RootError::other(format!(
483                "{field} unit `{unit}` is not a cross-section; expected Pb or Fb"
484            )))
485        }
486        _ => Err(RootError::other(format!(
487            "{field} unit `{unit}` is unsupported; expected Pb or Fb"
488        ))),
489    }
490}
491
492fn parse_luminosity(input: &str, field: &str) -> Result<IntegratedLuminosity> {
493    let (value, unit) = parse_quantity(input, field)?;
494    if !value.is_finite() || value <= 0.0 {
495        return Err(RootError::other(format!(
496            "{field} must be finite and positive"
497        )));
498    }
499    match unit {
500        "FbInv" | "fb^-1" | "fb-1" | "1/fb" => Ok(IntegratedLuminosity::FbInv(FbInv(value))),
501        "PbInv" | "pb^-1" | "pb-1" | "1/pb" => Ok(IntegratedLuminosity::PbInv(PbInv(value))),
502        "Pb" | "pb" | "Fb" | "fb" => Err(RootError::other(format!(
503            "{field} unit `{unit}` is not an integrated luminosity; expected FbInv or PbInv"
504        ))),
505        _ => Err(RootError::other(format!(
506            "{field} unit `{unit}` is unsupported; expected FbInv or PbInv"
507        ))),
508    }
509}
510
511fn parse_quantity<'a>(input: &'a str, field: &str) -> Result<(f64, &'a str)> {
512    let mut parts = input.split_whitespace();
513    let value = parts
514        .next()
515        .ok_or_else(|| RootError::other(format!("{field} is missing a numeric value")))?
516        .parse::<f64>()?;
517    let unit = parts
518        .next()
519        .ok_or_else(|| RootError::other(format!("{field} is missing a unit")))?;
520    if let Some(extra) = parts.next() {
521        return Err(RootError::other(format!(
522            "unexpected token `{extra}` in {field} quantity `{input}`"
523        )));
524    }
525    Ok((value, unit))
526}
527
528fn validate_sumw(sumw: f64, context: &str) -> Result<()> {
529    if !sumw.is_finite() || sumw <= 0.0 {
530        return Err(RootError::other(format!(
531            "{context} sumw must be finite and > 0"
532        )));
533    }
534    Ok(())
535}
536
537fn validate_label(kind: &str, value: &str) -> Result<()> {
538    if value.is_empty()
539        || value.chars().any(char::is_whitespace)
540        || value.contains('/')
541        || value.contains('$')
542    {
543        return Err(RootError::other(format!(
544            "{kind} `{value}` must be non-empty and contain no whitespace, `/`, or `$`"
545        )));
546    }
547    Ok(())
548}
549
550fn nominal_histogram<'a>(set: &'a HistSet1D<String>, context: &str) -> Result<&'a Hist1D> {
551    set.iter()
552        .find_map(|(systematic, hist)| (systematic == NOMINAL_VARIATION).then_some(hist))
553        .ok_or_else(|| RootError::other(format!("histogram `{context}` is missing Nominal")))
554}
555
556fn paired_shape_systematics(set: &HistSet1D<String>) -> Vec<String> {
557    let variations = set
558        .iter()
559        .map(|(name, _)| name.as_str())
560        .collect::<BTreeSet<_>>();
561    variations
562        .iter()
563        .filter_map(|name| name.strip_suffix("Up"))
564        .filter(|base| variations.contains(format!("{base}Down").as_str()))
565        .map(ToString::to_string)
566        .collect()
567}
568
569fn systematic_variations(histograms: &InterpretedHistograms) -> Vec<String> {
570    histograms
571        .iter()
572        .next()
573        .map(|(_, set)| {
574            set.iter()
575                .map(|(systematic, _)| systematic.clone())
576                .collect::<Vec<_>>()
577        })
578        .unwrap_or_else(|| vec![NOMINAL_VARIATION.to_string()])
579}