1use std::collections::{BTreeMap, BTreeSet};
9use std::fs;
10use std::path::{Path, PathBuf};
11
12use nano_analysis::{Fb, FbInv, Hist1D, HistSet1D, Pb, PbInv};
13use nano_core::Event;
14use nano_spec::interpret::{
15 interpret_and_fill, interpret_and_fill_systematic, InterpretedHistograms,
16};
17use nano_spec::ResolvedPlan;
18
19use crate::datacard::{DatacardOutput, MultiProcessChannel, MultiProcessDatacard, Process};
20use crate::{events, Result, RootError};
21
/// Variation key of the nominal histogram inside every `HistSet1D`.
///
/// Shape systematics are stored next to it under `<name>Up` / `<name>Down` keys.
const NOMINAL_VARIATION: &str = "Nominal";
23
/// A validated sample table: the integrated luminosity plus the samples to run.
///
/// Tables built by [`SampleTable::from_toml_str`] / [`SampleTable::from_path`] contain at
/// least one sample and at least one MC sample with `signal = true`.
#[derive(Debug, Clone, PartialEq)]
pub struct SampleTable {
    /// Integrated luminosity used to normalize every MC sample.
    lumi: IntegratedLuminosity,
    /// Samples in the order of the TOML `[[sample]]` rows.
    samples: Vec<Sample>,
}
30
31impl SampleTable {
32 pub fn from_toml_str(input: &str) -> Result<Self> {
34 let raw: RawSampleTable = toml::from_str(input)
35 .map_err(|error| RootError::parse(format!("failed to parse sample TOML: {error}")))?;
36 sample_table_from_raw(raw)
37 }
38
39 pub fn from_path(path: impl AsRef<Path>) -> Result<Self> {
41 let path = path.as_ref();
42 let input = fs::read_to_string(path)?;
43 Self::from_toml_str(&input)
44 }
45
46 pub fn lumi(&self) -> IntegratedLuminosity {
47 self.lumi
48 }
49
50 pub fn samples(&self) -> &[Sample] {
51 &self.samples
52 }
53
54 fn signal_processes(&self) -> BTreeSet<String> {
55 self.samples
56 .iter()
57 .filter(|sample| matches!(sample.kind, SampleKind::Mc { signal: true, .. }))
58 .map(|sample| sample.process.clone())
59 .collect()
60 }
61}
62
/// One `[[sample]]` row: a process label, its input files, and its data/MC kind.
#[derive(Debug, Clone, PartialEq)]
pub struct Sample {
    /// Process label; MC samples sharing a label are summed into one process.
    process: String,
    /// Input files; never empty for samples built from a sample table.
    files: Vec<PathBuf>,
    /// Data vs. MC, including the MC normalization inputs.
    kind: SampleKind,
}
70
71impl Sample {
72 pub fn process(&self) -> &str {
73 &self.process
74 }
75
76 pub fn files(&self) -> &[PathBuf] {
77 &self.files
78 }
79
80 pub fn is_data(&self) -> bool {
81 matches!(self.kind, SampleKind::Data)
82 }
83
84 pub fn is_signal(&self) -> bool {
85 matches!(self.kind, SampleKind::Mc { signal: true, .. })
86 }
87
88 pub fn normalization_factor(&self, lumi: IntegratedLuminosity) -> Result<f64> {
93 match self.kind {
94 SampleKind::Data => Ok(1.0),
95 SampleKind::Mc { xsec, sumw, .. } => {
96 mc_normalization_factor_pb(xsec.to_pb(), lumi.to_pb_inv(), sumw)
97 }
98 }
99 }
100}
101
/// Whether a sample is observed data or simulation (MC), with the MC normalization inputs.
#[derive(Debug, Clone, Copy, PartialEq)]
enum SampleKind {
    /// Observed data: normalization factor 1.0, accumulated into `data_obs`.
    Data,
    /// Simulated sample, scaled by `xsec * lumi / sumw`.
    Mc {
        /// Marks the sample's process as the datacard signal.
        signal: bool,
        /// Sample cross-section.
        xsec: CrossSection,
        /// Normalization denominator (presumably the sum of generator weights); finite and > 0.
        sumw: f64,
    },
}
111
/// A cross-section in the unit it was written with in the sample table (fb or pb).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CrossSection {
    /// Femtobarns.
    Fb(Fb),
    /// Picobarns.
    Pb(Pb),
}
118
119impl CrossSection {
120 pub fn to_pb(self) -> Pb {
121 match self {
122 Self::Fb(value) => value.to_pb(),
123 Self::Pb(value) => value,
124 }
125 }
126
127 pub fn to_fb(self) -> Fb {
128 match self {
129 Self::Fb(value) => value,
130 Self::Pb(value) => value.to_fb(),
131 }
132 }
133}
134
/// An integrated luminosity in the unit it was written with in the sample table
/// (fb^-1 or pb^-1).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum IntegratedLuminosity {
    /// Inverse femtobarns.
    FbInv(FbInv),
    /// Inverse picobarns.
    PbInv(PbInv),
}
141
142impl IntegratedLuminosity {
143 pub fn to_pb_inv(self) -> PbInv {
144 match self {
145 Self::FbInv(value) => value.to_pb_inv(),
146 Self::PbInv(value) => value,
147 }
148 }
149
150 pub fn to_fb_inv(self) -> FbInv {
151 match self {
152 Self::FbInv(value) => value,
153 Self::PbInv(value) => value.to_fb_inv(),
154 }
155 }
156}
157
158pub fn mc_normalization_factor_pb(xsec: Pb, lumi: PbInv, sumw: f64) -> Result<f64> {
160 validate_sumw(sumw, "normalization")?;
161 if !xsec.0.is_finite() || xsec.0 < 0.0 {
162 return Err(RootError::other(
163 "normalization cross-section must be finite and non-negative",
164 ));
165 }
166 if !lumi.0.is_finite() || lumi.0 <= 0.0 {
167 return Err(RootError::other(
168 "normalization luminosity must be finite and positive",
169 ));
170 }
171 Ok((xsec * lumi) / sumw)
172}
173
174pub fn mc_normalization_factor_fb(xsec: Fb, lumi: FbInv, sumw: f64) -> Result<f64> {
176 validate_sumw(sumw, "normalization")?;
177 if !xsec.0.is_finite() || xsec.0 < 0.0 {
178 return Err(RootError::other(
179 "normalization cross-section must be finite and non-negative",
180 ));
181 }
182 if !lumi.0.is_finite() || lumi.0 <= 0.0 {
183 return Err(RootError::other(
184 "normalization luminosity must be finite and positive",
185 ));
186 }
187 Ok((xsec * lumi) / sumw)
188}
189
190pub fn run_interpreted_samples(
192 table: &SampleTable,
193 plan: &ResolvedPlan,
194) -> Result<NormalizedProcessHistograms> {
195 run_interpreted_samples_with_events(table, plan, |path| events(path, &plan.read_branches))
196}
197
198pub fn run_interpreted_samples_with_events<F, I>(
203 table: &SampleTable,
204 plan: &ResolvedPlan,
205 mut events_for_file: F,
206) -> Result<NormalizedProcessHistograms>
207where
208 F: FnMut(&Path) -> Result<I>,
209 I: IntoIterator<Item = Result<Event>>,
210{
211 let mut output = NormalizedProcessHistograms::new(table.signal_processes());
212
213 for (sample_index, sample) in table.samples.iter().enumerate() {
214 let factor = sample.normalization_factor(table.lumi)?;
215 let mut histograms = InterpretedHistograms::new(plan);
216 let systematic_variations = systematic_variations(&histograms);
217 let mut events_read = 0_usize;
218 let mut selected = 0_usize;
219
220 for file in &sample.files {
221 for event in events_for_file(file)? {
222 let event = event?;
223 events_read += 1;
224 if plan.spec.has_shape_correction() {
225 for systematic in &systematic_variations {
226 let row = interpret_and_fill_systematic(
227 plan,
228 &event,
229 &mut histograms,
230 systematic,
231 )
232 .map_err(|error| RootError::other(error.to_string()))?;
233 if systematic == NOMINAL_VARIATION && row.is_some() {
234 selected += 1;
235 }
236 }
237 } else if interpret_and_fill(plan, &event, &mut histograms)
238 .map_err(|error| RootError::other(error.to_string()))?
239 .is_some()
240 {
241 selected += 1;
242 }
243 }
244 }
245
246 if !sample.is_data() {
247 histograms.scale(factor);
248 }
249 output.accumulate_sample(sample, &histograms);
250 output.sample_reports.push(SampleRunReport {
251 sample_index,
252 process: sample.process.clone(),
253 data: sample.is_data(),
254 signal: sample.is_signal(),
255 normalization_factor: factor,
256 events_read,
257 selected,
258 });
259 }
260
261 Ok(output)
262}
263
/// Histograms accumulated over all samples of a run, ready to be written as a datacard.
#[derive(Debug, Clone, PartialEq)]
pub struct NormalizedProcessHistograms {
    /// Processes that have at least one MC sample flagged `signal = true`.
    signal_processes: BTreeSet<String>,
    /// MC histograms (already scaled): process -> histogram/channel name -> variations.
    processes: BTreeMap<String, BTreeMap<String, HistSet1D<String>>>,
    /// Data histograms summed over all data samples: histogram/channel name -> variations.
    data_obs: BTreeMap<String, HistSet1D<String>>,
    /// Per-sample bookkeeping, in sample-table order.
    sample_reports: Vec<SampleRunReport>,
}
272
273impl NormalizedProcessHistograms {
274 fn new(signal_processes: BTreeSet<String>) -> Self {
275 Self {
276 signal_processes,
277 processes: BTreeMap::new(),
278 data_obs: BTreeMap::new(),
279 sample_reports: Vec::new(),
280 }
281 }
282
283 pub fn processes(&self) -> &BTreeMap<String, BTreeMap<String, HistSet1D<String>>> {
284 &self.processes
285 }
286
287 pub fn data_obs(&self) -> &BTreeMap<String, HistSet1D<String>> {
288 &self.data_obs
289 }
290
291 pub fn sample_reports(&self) -> &[SampleRunReport] {
292 &self.sample_reports
293 }
294
295 pub fn to_datacard(&self) -> Result<MultiProcessDatacard<'_>> {
297 if self.signal_processes.len() != 1 {
298 return Err(RootError::other(format!(
299 "multi-process datacard needs exactly one signal process, found {}",
300 self.signal_processes.len()
301 )));
302 }
303 let signal_process = self
304 .signal_processes
305 .iter()
306 .next()
307 .expect("checked exactly one signal process");
308
309 let mut datacard = MultiProcessDatacard::new();
310 for (channel, data_set) in &self.data_obs {
311 let data_nominal = nominal_histogram(data_set, channel)?;
312 let mut combine_channel = MultiProcessChannel::new(channel, data_nominal);
313 let mut background_index = 1_i32;
314
315 for (process_name, histograms_by_channel) in &self.processes {
316 let Some(set) = histograms_by_channel.get(channel) else {
317 continue;
318 };
319 let index = if process_name == signal_process {
320 0
321 } else {
322 let index = background_index;
323 background_index += 1;
324 index
325 };
326 let mut process =
327 Process::new(process_name, index, nominal_histogram(set, channel)?);
328 for systematic in paired_shape_systematics(set) {
329 let up = set.get(format!("{systematic}Up"));
330 let down = set.get(format!("{systematic}Down"));
331 process = process.with_shape_systematic(systematic, up, down);
332 }
333 combine_channel = combine_channel.with_process(process);
334 }
335
336 datacard = datacard.with_channel(combine_channel);
337 }
338
339 Ok(datacard)
340 }
341
342 pub fn write_datacard(&self, output_dir: &Path) -> Result<DatacardOutput> {
344 self.to_datacard()?.write(output_dir)
345 }
346
347 fn accumulate_sample(&mut self, sample: &Sample, histograms: &InterpretedHistograms) {
348 let target = if sample.is_data() {
349 &mut self.data_obs
350 } else {
351 self.processes.entry(sample.process.clone()).or_default()
352 };
353 for (name, set) in histograms.iter() {
354 target
355 .entry(name.clone())
356 .and_modify(|existing| existing.add(set))
357 .or_insert_with(|| set.clone());
358 }
359 }
360}
361
/// Bookkeeping recorded for one sample of a run.
#[derive(Debug, Clone, PartialEq)]
pub struct SampleRunReport {
    /// Position of the sample in the sample table.
    pub sample_index: usize,
    /// Process label of the sample.
    pub process: String,
    /// Whether the sample is observed data.
    pub data: bool,
    /// Whether the sample is an MC signal sample.
    pub signal: bool,
    /// Normalization factor of the sample; 1.0 for data, which is never rescaled.
    pub normalization_factor: f64,
    /// Number of events read across all of the sample's files.
    pub events_read: usize,
    /// Number of events accepted by the (nominal) selection.
    pub selected: usize,
}
373
/// Top-level TOML layout of a sample table, before validation.
#[derive(Debug, serde::Deserialize)]
struct RawSampleTable {
    /// Luminosity quantity such as `"138 fb^-1"`.
    lumi: String,
    /// `[[sample]]` rows. They default to empty so that a table without rows reaches
    /// the explicit validation error instead of a generic TOML error.
    #[serde(default)]
    sample: Vec<RawSample>,
}
380
/// One `[[sample]]` TOML row, before validation.
#[derive(Debug, serde::Deserialize)]
struct RawSample {
    /// Process label (no whitespace, `/`, or `$`).
    process: String,
    /// Marks an MC sample as signal; defaults to `false`.
    #[serde(default)]
    signal: bool,
    /// Marks the sample as observed data; defaults to `false`.
    #[serde(default)]
    data: bool,
    /// Input files; must not be empty.
    files: Vec<PathBuf>,
    /// Cross-section quantity such as `"1.2 pb"`; a sample without it is treated as data.
    xsec: Option<String>,
    /// MC normalization denominator; required with `xsec`, forbidden for data.
    sumw: Option<f64>,
}
392
393fn sample_table_from_raw(raw: RawSampleTable) -> Result<SampleTable> {
394 let lumi = parse_luminosity(&raw.lumi, "lumi")?;
395 if raw.sample.is_empty() {
396 return Err(RootError::other(
397 "sample table must contain at least one [[sample]] row",
398 ));
399 }
400 let samples = raw
401 .sample
402 .into_iter()
403 .enumerate()
404 .map(|(index, sample)| sample_from_raw(index, sample))
405 .collect::<Result<Vec<_>>>()?;
406
407 if !samples
408 .iter()
409 .any(|sample| matches!(sample.kind, SampleKind::Mc { signal: true, .. }))
410 {
411 return Err(RootError::other(
412 "sample table must contain at least one MC sample with signal=true",
413 ));
414 }
415
416 Ok(SampleTable { lumi, samples })
417}
418
419fn sample_from_raw(index: usize, raw: RawSample) -> Result<Sample> {
420 validate_label("process", &raw.process)?;
421 if raw.files.is_empty() {
422 return Err(RootError::other(format!(
423 "sample {index} process `{}` must list at least one file",
424 raw.process
425 )));
426 }
427
428 let kind = if raw.data || raw.xsec.is_none() {
429 if raw.data && raw.xsec.is_some() {
430 return Err(RootError::other(format!(
431 "data sample `{}` must not set xsec",
432 raw.process
433 )));
434 }
435 if raw.signal {
436 return Err(RootError::other(format!(
437 "data sample `{}` cannot set signal=true",
438 raw.process
439 )));
440 }
441 if raw.sumw.is_some() {
442 return Err(RootError::other(format!(
443 "data sample `{}` must not set sumw",
444 raw.process
445 )));
446 }
447 SampleKind::Data
448 } else {
449 let xsec = parse_cross_section(raw.xsec.as_deref().expect("checked xsec"), "xsec")?;
450 let sumw = raw.sumw.ok_or_else(|| {
451 RootError::other(format!(
452 "MC sample `{}` must set sumw with xsec",
453 raw.process
454 ))
455 })?;
456 validate_sumw(sumw, &format!("sample `{}`", raw.process))?;
457 SampleKind::Mc {
458 signal: raw.signal,
459 xsec,
460 sumw,
461 }
462 };
463
464 Ok(Sample {
465 process: raw.process,
466 files: raw.files,
467 kind,
468 })
469}
470
471fn parse_cross_section(input: &str, field: &str) -> Result<CrossSection> {
472 let (value, unit) = parse_quantity(input, field)?;
473 if !value.is_finite() || value < 0.0 {
474 return Err(RootError::other(format!(
475 "{field} must be finite and non-negative"
476 )));
477 }
478 match unit {
479 "Pb" | "pb" => Ok(CrossSection::Pb(Pb(value))),
480 "Fb" | "fb" => Ok(CrossSection::Fb(Fb(value))),
481 "PbInv" | "pb^-1" | "pb-1" | "1/pb" | "FbInv" | "fb^-1" | "fb-1" | "1/fb" => {
482 Err(RootError::other(format!(
483 "{field} unit `{unit}` is not a cross-section; expected Pb or Fb"
484 )))
485 }
486 _ => Err(RootError::other(format!(
487 "{field} unit `{unit}` is unsupported; expected Pb or Fb"
488 ))),
489 }
490}
491
492fn parse_luminosity(input: &str, field: &str) -> Result<IntegratedLuminosity> {
493 let (value, unit) = parse_quantity(input, field)?;
494 if !value.is_finite() || value <= 0.0 {
495 return Err(RootError::other(format!(
496 "{field} must be finite and positive"
497 )));
498 }
499 match unit {
500 "FbInv" | "fb^-1" | "fb-1" | "1/fb" => Ok(IntegratedLuminosity::FbInv(FbInv(value))),
501 "PbInv" | "pb^-1" | "pb-1" | "1/pb" => Ok(IntegratedLuminosity::PbInv(PbInv(value))),
502 "Pb" | "pb" | "Fb" | "fb" => Err(RootError::other(format!(
503 "{field} unit `{unit}` is not an integrated luminosity; expected FbInv or PbInv"
504 ))),
505 _ => Err(RootError::other(format!(
506 "{field} unit `{unit}` is unsupported; expected FbInv or PbInv"
507 ))),
508 }
509}
510
511fn parse_quantity<'a>(input: &'a str, field: &str) -> Result<(f64, &'a str)> {
512 let mut parts = input.split_whitespace();
513 let value = parts
514 .next()
515 .ok_or_else(|| RootError::other(format!("{field} is missing a numeric value")))?
516 .parse::<f64>()?;
517 let unit = parts
518 .next()
519 .ok_or_else(|| RootError::other(format!("{field} is missing a unit")))?;
520 if let Some(extra) = parts.next() {
521 return Err(RootError::other(format!(
522 "unexpected token `{extra}` in {field} quantity `{input}`"
523 )));
524 }
525 Ok((value, unit))
526}
527
528fn validate_sumw(sumw: f64, context: &str) -> Result<()> {
529 if !sumw.is_finite() || sumw <= 0.0 {
530 return Err(RootError::other(format!(
531 "{context} sumw must be finite and > 0"
532 )));
533 }
534 Ok(())
535}
536
537fn validate_label(kind: &str, value: &str) -> Result<()> {
538 if value.is_empty()
539 || value.chars().any(char::is_whitespace)
540 || value.contains('/')
541 || value.contains('$')
542 {
543 return Err(RootError::other(format!(
544 "{kind} `{value}` must be non-empty and contain no whitespace, `/`, or `$`"
545 )));
546 }
547 Ok(())
548}
549
550fn nominal_histogram<'a>(set: &'a HistSet1D<String>, context: &str) -> Result<&'a Hist1D> {
551 set.iter()
552 .find_map(|(systematic, hist)| (systematic == NOMINAL_VARIATION).then_some(hist))
553 .ok_or_else(|| RootError::other(format!("histogram `{context}` is missing Nominal")))
554}
555
556fn paired_shape_systematics(set: &HistSet1D<String>) -> Vec<String> {
557 let variations = set
558 .iter()
559 .map(|(name, _)| name.as_str())
560 .collect::<BTreeSet<_>>();
561 variations
562 .iter()
563 .filter_map(|name| name.strip_suffix("Up"))
564 .filter(|base| variations.contains(format!("{base}Down").as_str()))
565 .map(ToString::to_string)
566 .collect()
567}
568
569fn systematic_variations(histograms: &InterpretedHistograms) -> Vec<String> {
570 histograms
571 .iter()
572 .next()
573 .map(|(_, set)| {
574 set.iter()
575 .map(|(systematic, _)| systematic.clone())
576 .collect::<Vec<_>>()
577 })
578 .unwrap_or_else(|| vec![NOMINAL_VARIATION.to_string()])
579}