1use std::collections::{BTreeMap, BTreeSet};
11use std::fmt::Write as _;
12use std::fs;
13use std::path::{Path, PathBuf};
14
15use nano_analysis::Hist1D;
16
17use crate::{writer, Result, RootError};
18
/// File name of the datacard text that `write` creates inside the output directory.
const DATACARD_FILE: &str = "datacard.txt";
/// File name of the histogram file that `write` creates inside the output directory;
/// `write` also passes it to `to_text`, so it is the name on the datacard's `shapes` line.
const SHAPES_FILE: &str = "shapes.root";
21
/// Paths of the two files produced by a datacard `write` call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DatacardOutput {
    /// Path of the generated datacard text file.
    pub datacard_path: PathBuf,
    /// Path of the histogram file referenced from the datacard's `shapes` line.
    pub shapes_path: PathBuf,
}
28
/// A normalisation-only (`lnN`) systematic described by multiplicative up/down factors.
///
/// Datacard validation requires both factors to be finite and strictly positive;
/// the pair is written into the datacard as `down/up`.
#[derive(Debug, Clone, PartialEq)]
pub struct FlatWeightSystematic {
    /// Nuisance-parameter name as it appears in the datacard.
    pub name: String,
    /// Scale factor for the upward variation.
    pub up: f64,
    /// Scale factor for the downward variation.
    pub down: f64,
}

impl FlatWeightSystematic {
    /// Builds a systematic from its name and its up/down scale factors.
    pub fn new(name: impl Into<String>, up: f64, down: f64) -> Self {
        let name: String = name.into();
        FlatWeightSystematic { name, up, down }
    }
}
47
48#[derive(Debug, Clone, Copy, PartialEq)]
50pub struct ShapeVariation<'a> {
51 pub up: &'a Hist1D,
52 pub down: &'a Hist1D,
53}
54
55impl<'a> ShapeVariation<'a> {
56 pub fn new(up: &'a Hist1D, down: &'a Hist1D) -> Self {
57 Self { up, down }
58 }
59}
60
61#[derive(Debug, Clone, PartialEq)]
63pub struct Process<'a> {
64 name: String,
65 index: i32,
66 nominal: &'a Hist1D,
67 shape_variations: BTreeMap<String, ShapeVariation<'a>>,
68 flat_weight_systematics: BTreeMap<String, FlatWeightSystematic>,
69}
70
71impl<'a> Process<'a> {
72 pub fn new(name: impl Into<String>, index: i32, nominal: &'a Hist1D) -> Self {
75 Self {
76 name: name.into(),
77 index,
78 nominal,
79 shape_variations: BTreeMap::new(),
80 flat_weight_systematics: BTreeMap::new(),
81 }
82 }
83
84 pub fn with_shape_systematic(
86 mut self,
87 name: impl Into<String>,
88 up: &'a Hist1D,
89 down: &'a Hist1D,
90 ) -> Self {
91 self.shape_variations
92 .insert(name.into(), ShapeVariation::new(up, down));
93 self
94 }
95
96 pub fn with_flat_weight_systematic(mut self, systematic: FlatWeightSystematic) -> Self {
98 self.flat_weight_systematics
99 .insert(systematic.name.clone(), systematic);
100 self
101 }
102
103 pub fn name(&self) -> &str {
104 &self.name
105 }
106
107 pub fn index(&self) -> i32 {
108 self.index
109 }
110
111 pub fn nominal(&self) -> &Hist1D {
112 self.nominal
113 }
114
115 pub fn shape_variations(&self) -> &BTreeMap<String, ShapeVariation<'a>> {
116 &self.shape_variations
117 }
118
119 pub fn flat_weight_systematics(&self) -> &BTreeMap<String, FlatWeightSystematic> {
120 &self.flat_weight_systematics
121 }
122}
123
124#[derive(Debug, Clone, PartialEq)]
126pub struct MultiProcessChannel<'a> {
127 name: String,
128 data_obs: &'a Hist1D,
129 processes: Vec<Process<'a>>,
130}
131
132impl<'a> MultiProcessChannel<'a> {
133 pub fn new(name: impl Into<String>, data_obs: &'a Hist1D) -> Self {
135 Self {
136 name: name.into(),
137 data_obs,
138 processes: Vec::new(),
139 }
140 }
141
142 pub fn with_process(mut self, process: Process<'a>) -> Self {
144 self.processes.push(process);
145 self
146 }
147
148 pub fn name(&self) -> &str {
149 &self.name
150 }
151
152 pub fn data_obs(&self) -> &Hist1D {
153 self.data_obs
154 }
155
156 pub fn processes(&self) -> &[Process<'a>] {
157 &self.processes
158 }
159}
160
161#[derive(Debug, Clone, PartialEq)]
163pub struct Channel<'a> {
164 name: String,
165 nominal: &'a Hist1D,
166 data_obs: &'a Hist1D,
167 shape_variations: BTreeMap<String, ShapeVariation<'a>>,
168}
169
170impl<'a> Channel<'a> {
171 pub fn new(name: impl Into<String>, nominal: &'a Hist1D, data_obs: &'a Hist1D) -> Self {
173 Self {
174 name: name.into(),
175 nominal,
176 data_obs,
177 shape_variations: BTreeMap::new(),
178 }
179 }
180
181 pub fn with_shape_systematic(
183 mut self,
184 name: impl Into<String>,
185 up: &'a Hist1D,
186 down: &'a Hist1D,
187 ) -> Self {
188 self.shape_variations
189 .insert(name.into(), ShapeVariation::new(up, down));
190 self
191 }
192
193 pub fn name(&self) -> &str {
194 &self.name
195 }
196
197 pub fn nominal(&self) -> &Hist1D {
198 self.nominal
199 }
200
201 pub fn data_obs(&self) -> &Hist1D {
202 self.data_obs
203 }
204
205 pub fn shape_variations(&self) -> &BTreeMap<String, ShapeVariation<'a>> {
206 &self.shape_variations
207 }
208}
209
/// A Combine datacard whose channels may each contain several processes.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiProcessDatacard<'a> {
    /// Channels in the order they are written to the datacard.
    channels: Vec<MultiProcessChannel<'a>>,
}
215
216impl<'a> MultiProcessDatacard<'a> {
217 pub fn new() -> Self {
219 Self {
220 channels: Vec::new(),
221 }
222 }
223
224 pub fn with_channel(mut self, channel: MultiProcessChannel<'a>) -> Self {
227 self.channels.push(channel);
228 self
229 }
230
231 pub fn channels(&self) -> &[MultiProcessChannel<'a>] {
232 &self.channels
233 }
234
235 pub fn write(&self, output_dir: &Path) -> Result<DatacardOutput> {
237 self.validate()?;
238 fs::create_dir_all(output_dir)?;
239 let datacard_path = output_dir.join(DATACARD_FILE);
240 let shapes_path = output_dir.join(SHAPES_FILE);
241
242 let shape_inputs = self.shape_inputs();
243 let borrowed = shape_inputs
244 .iter()
245 .map(|(name, hist)| (name.as_str(), *hist))
246 .collect::<Vec<_>>();
247 writer::write_histograms(&shapes_path, &borrowed)?;
248
249 let text = self.to_text(SHAPES_FILE)?;
250 fs::write(&datacard_path, text)?;
251
252 Ok(DatacardOutput {
253 datacard_path,
254 shapes_path,
255 })
256 }
257
258 pub fn to_text(&self, shapes_file: &str) -> Result<String> {
260 self.validate()?;
261 validate_shapes_file(shapes_file)?;
262
263 let columns = self.columns();
264 let shape_systematics = self.shape_systematic_names();
265 let flat_systematics = self.flat_systematic_names();
266 let mut out = String::new();
267
268 writeln!(out, "imax {} number of channels", self.channels.len())?;
269 writeln!(
270 out,
271 "jmax {} number of processes minus 1",
272 self.unique_process_count() - 1
273 )?;
274 writeln!(
275 out,
276 "kmax {} number of nuisance parameters",
277 shape_systematics.len() + flat_systematics.len()
278 )?;
279 writeln!(out, "------------")?;
280 writeln!(
281 out,
282 "shapes * * {shapes_file} $CHANNEL/$PROCESS $CHANNEL/$PROCESS_$SYSTEMATIC"
283 )?;
284 writeln!(out, "------------")?;
285 writeln!(
286 out,
287 "bin {}",
288 join(self.channels.iter().map(|channel| channel.name()))
289 )?;
290 writeln!(
291 out,
292 "observation {}",
293 join(
294 self.channels
295 .iter()
296 .map(|channel| format_rate(rate(channel.data_obs())))
297 )
298 )?;
299 writeln!(out, "------------")?;
300 writeln!(
301 out,
302 "bin {}",
303 join(columns.iter().map(|(channel, _)| channel.name()))
304 )?;
305 writeln!(
306 out,
307 "process {}",
308 join(columns.iter().map(|(_, process)| process.name()))
309 )?;
310 writeln!(
311 out,
312 "process {}",
313 join(
314 columns
315 .iter()
316 .map(|(_, process)| process.index().to_string())
317 )
318 )?;
319 writeln!(
320 out,
321 "rate {}",
322 join(
323 columns
324 .iter()
325 .map(|(_, process)| format_rate(rate(process.nominal())))
326 )
327 )?;
328 writeln!(out, "------------")?;
329
330 for systematic in shape_systematics {
331 writeln!(
332 out,
333 "{systematic} shape {}",
334 join(columns.iter().map(|(_, process)| {
335 if process.shape_variations.contains_key(&systematic) {
336 "1"
337 } else {
338 "-"
339 }
340 }))
341 )?;
342 }
343
344 for systematic in flat_systematics {
345 writeln!(
346 out,
347 "{systematic} lnN {}",
348 join(columns.iter().map(|(_, process)| {
349 process
350 .flat_weight_systematics
351 .get(&systematic)
352 .map_or_else(|| "-".to_string(), format_lnn)
353 }))
354 )?;
355 }
356
357 Ok(out)
358 }
359
360 fn validate(&self) -> Result<()> {
361 if self.channels.is_empty() {
362 return Err(RootError::other(
363 "Combine datacard needs at least one channel",
364 ));
365 }
366
367 let mut channel_names = BTreeSet::new();
368 let mut process_indices = BTreeMap::<&str, i32>::new();
369 let mut all_shape_names = BTreeSet::new();
370 let mut all_flat_names = BTreeSet::new();
371
372 for channel in &self.channels {
373 validate_label("channel", &channel.name)?;
374 if !channel_names.insert(channel.name.as_str()) {
375 return Err(RootError::other(format!(
376 "duplicate Combine channel `{}`",
377 channel.name
378 )));
379 }
380 if channel.processes.is_empty() {
381 return Err(RootError::other(format!(
382 "Combine channel `{}` needs at least one process",
383 channel.name
384 )));
385 }
386 let signal_count = channel
387 .processes
388 .iter()
389 .filter(|process| process.index <= 0)
390 .count();
391 if signal_count != 1 {
392 return Err(RootError::other(format!(
393 "Combine channel `{}` must have exactly one signal process with index <= 0",
394 channel.name
395 )));
396 }
397
398 let mut process_names = BTreeSet::new();
399 for process in &channel.processes {
400 validate_label("process", &process.name)?;
401 if !process_names.insert(process.name.as_str()) {
402 return Err(RootError::other(format!(
403 "duplicate Combine process `{}` in channel `{}`",
404 process.name, channel.name
405 )));
406 }
407 if let Some(existing) = process_indices.insert(process.name.as_str(), process.index)
408 {
409 if existing != process.index {
410 return Err(RootError::other(format!(
411 "Combine process `{}` has inconsistent indices {existing} and {}",
412 process.name, process.index
413 )));
414 }
415 }
416
417 validate_compatible_histograms(process.nominal, channel.data_obs, &channel.name)?;
418 for (name, variation) in &process.shape_variations {
419 validate_label("shape systematic", name)?;
420 validate_compatible_histograms(process.nominal, variation.up, name)?;
421 validate_compatible_histograms(process.nominal, variation.down, name)?;
422 all_shape_names.insert(name.as_str());
423 }
424 let mut process_flat_names = BTreeSet::new();
425 for systematic in process.flat_weight_systematics.values() {
426 validate_flat_systematic(systematic)?;
427 if !process_flat_names.insert(systematic.name.as_str()) {
428 return Err(RootError::other(format!(
429 "duplicate flat weight systematic `{}` on process `{}` in channel `{}`",
430 systematic.name, process.name, channel.name
431 )));
432 }
433 all_flat_names.insert(systematic.name.as_str());
434 }
435 }
436 }
437
438 for name in all_shape_names {
439 if all_flat_names.contains(name) {
440 return Err(RootError::other(format!(
441 "systematic `{name}` is both shape and lnN"
442 )));
443 }
444 }
445
446 Ok(())
447 }
448
449 fn columns(&self) -> Vec<(&MultiProcessChannel<'a>, &Process<'a>)> {
450 self.channels
451 .iter()
452 .flat_map(|channel| {
453 channel
454 .processes
455 .iter()
456 .map(move |process| (channel, process))
457 })
458 .collect()
459 }
460
461 fn unique_process_count(&self) -> usize {
462 self.channels
463 .iter()
464 .flat_map(|channel| {
465 channel
466 .processes
467 .iter()
468 .map(|process| process.name.as_str())
469 })
470 .collect::<BTreeSet<_>>()
471 .len()
472 }
473
474 fn shape_systematic_names(&self) -> Vec<String> {
475 self.channels
476 .iter()
477 .flat_map(|channel| &channel.processes)
478 .flat_map(|process| process.shape_variations.keys().cloned())
479 .collect::<BTreeSet<_>>()
480 .into_iter()
481 .collect()
482 }
483
484 fn flat_systematic_names(&self) -> Vec<String> {
485 self.channels
486 .iter()
487 .flat_map(|channel| &channel.processes)
488 .flat_map(|process| process.flat_weight_systematics.keys().cloned())
489 .collect::<BTreeSet<_>>()
490 .into_iter()
491 .collect()
492 }
493
494 fn shape_inputs(&self) -> Vec<(String, &'a Hist1D)> {
495 let mut histograms = Vec::new();
496 for channel in &self.channels {
497 for process in &channel.processes {
498 histograms.push((shape_name(&channel.name, &process.name), process.nominal));
499 for (systematic, variation) in &process.shape_variations {
500 histograms.push((
501 shape_name(&channel.name, &format!("{}_{systematic}Up", process.name)),
502 variation.up,
503 ));
504 histograms.push((
505 shape_name(&channel.name, &format!("{}_{systematic}Down", process.name)),
506 variation.down,
507 ));
508 }
509 }
510 histograms.push((shape_name(&channel.name, "data_obs"), channel.data_obs));
511 }
512 histograms
513 }
514}
515
516impl<'a> Default for MultiProcessDatacard<'a> {
517 fn default() -> Self {
518 Self::new()
519 }
520}
521
/// A Combine datacard with a single process that appears in every channel.
#[derive(Debug, Clone, PartialEq)]
pub struct SingleProcessDatacard<'a> {
    /// Process name written in every column.
    process: String,
    /// Combine process index written in every column (`new` sets it to 0).
    process_index: i32,
    /// Channels in the order they are written to the datacard.
    channels: Vec<Channel<'a>>,
    /// `lnN` systematics, each applied with the same factors to every channel.
    flat_weight_systematics: Vec<FlatWeightSystematic>,
}
530
531impl<'a> SingleProcessDatacard<'a> {
532 pub fn new(process: impl Into<String>) -> Self {
534 Self {
535 process: process.into(),
536 process_index: 0,
537 channels: Vec::new(),
538 flat_weight_systematics: Vec::new(),
539 }
540 }
541
542 pub fn with_process_index(mut self, process_index: i32) -> Self {
544 self.process_index = process_index;
545 self
546 }
547
548 pub fn with_channel(mut self, channel: Channel<'a>) -> Self {
550 self.channels.push(channel);
551 self
552 }
553
554 pub fn with_flat_weight_systematic(mut self, systematic: FlatWeightSystematic) -> Self {
556 self.flat_weight_systematics.push(systematic);
557 self
558 }
559
560 pub fn process(&self) -> &str {
561 &self.process
562 }
563
564 pub fn channels(&self) -> &[Channel<'a>] {
565 &self.channels
566 }
567
568 pub fn flat_weight_systematics(&self) -> &[FlatWeightSystematic] {
569 &self.flat_weight_systematics
570 }
571
572 pub fn write(&self, output_dir: &Path) -> Result<DatacardOutput> {
574 self.validate()?;
575 fs::create_dir_all(output_dir)?;
576 let datacard_path = output_dir.join(DATACARD_FILE);
577 let shapes_path = output_dir.join(SHAPES_FILE);
578
579 let shape_inputs = self.shape_inputs();
580 let borrowed = shape_inputs
581 .iter()
582 .map(|(name, hist)| (name.as_str(), *hist))
583 .collect::<Vec<_>>();
584 writer::write_histograms(&shapes_path, &borrowed)?;
585
586 let text = self.to_text(SHAPES_FILE)?;
587 fs::write(&datacard_path, text)?;
588
589 Ok(DatacardOutput {
590 datacard_path,
591 shapes_path,
592 })
593 }
594
595 pub fn to_text(&self, shapes_file: &str) -> Result<String> {
597 self.validate()?;
598 validate_shapes_file(shapes_file)?;
599
600 let shape_systematics = self.shape_systematic_names();
601 let columns = self.channels.len();
602 let mut out = String::new();
603
604 writeln!(out, "imax {} number of channels", self.channels.len())?;
605 writeln!(out, "jmax 0 number of processes minus 1")?;
606 writeln!(
607 out,
608 "kmax {} number of nuisance parameters",
609 shape_systematics.len() + self.flat_weight_systematics.len()
610 )?;
611 writeln!(out, "------------")?;
612 writeln!(
613 out,
614 "shapes * * {shapes_file} $CHANNEL/$PROCESS $CHANNEL/$PROCESS_$SYSTEMATIC"
615 )?;
616 writeln!(out, "------------")?;
617 writeln!(
618 out,
619 "bin {}",
620 join(self.channels.iter().map(|channel| channel.name()))
621 )?;
622 writeln!(
623 out,
624 "observation {}",
625 join(
626 self.channels
627 .iter()
628 .map(|channel| format_rate(rate(channel.data_obs())))
629 )
630 )?;
631 writeln!(out, "------------")?;
632 writeln!(
633 out,
634 "bin {}",
635 join(self.channels.iter().map(|channel| channel.name()))
636 )?;
637 writeln!(out, "process {}", repeated(&self.process, columns))?;
638 writeln!(
639 out,
640 "process {}",
641 repeated(&self.process_index.to_string(), columns)
642 )?;
643 writeln!(
644 out,
645 "rate {}",
646 join(
647 self.channels
648 .iter()
649 .map(|channel| format_rate(rate(channel.nominal())))
650 )
651 )?;
652 writeln!(out, "------------")?;
653
654 for systematic in shape_systematics {
655 writeln!(
656 out,
657 "{systematic} shape {}",
658 join(self.channels.iter().map(|channel| {
659 if channel.shape_variations.contains_key(&systematic) {
660 "1"
661 } else {
662 "-"
663 }
664 }))
665 )?;
666 }
667
668 for systematic in &self.flat_weight_systematics {
669 writeln!(
670 out,
671 "{} lnN {}",
672 systematic.name,
673 repeated(&format_lnn(systematic), columns)
674 )?;
675 }
676
677 Ok(out)
678 }
679
680 fn validate(&self) -> Result<()> {
681 validate_label("process", &self.process)?;
682 if self.channels.is_empty() {
683 return Err(RootError::other(
684 "Combine datacard needs at least one channel",
685 ));
686 }
687
688 let mut channel_names = BTreeSet::new();
689 for channel in &self.channels {
690 validate_label("channel", &channel.name)?;
691 if !channel_names.insert(channel.name.as_str()) {
692 return Err(RootError::other(format!(
693 "duplicate Combine channel `{}`",
694 channel.name
695 )));
696 }
697 validate_compatible_histograms(channel.nominal, channel.data_obs, &channel.name)?;
698 for (name, variation) in &channel.shape_variations {
699 validate_label("shape systematic", name)?;
700 validate_compatible_histograms(channel.nominal, variation.up, name)?;
701 validate_compatible_histograms(channel.nominal, variation.down, name)?;
702 }
703 }
704
705 let mut flat_names = BTreeSet::new();
706 for systematic in &self.flat_weight_systematics {
707 validate_flat_systematic(systematic)?;
708 if !flat_names.insert(systematic.name.as_str()) {
709 return Err(RootError::other(format!(
710 "duplicate flat weight systematic `{}`",
711 systematic.name
712 )));
713 }
714 }
715
716 for shape in self.shape_systematic_names() {
717 if flat_names.contains(shape.as_str()) {
718 return Err(RootError::other(format!(
719 "systematic `{shape}` is both shape and lnN"
720 )));
721 }
722 }
723
724 Ok(())
725 }
726
727 fn shape_systematic_names(&self) -> Vec<String> {
728 self.channels
729 .iter()
730 .flat_map(|channel| channel.shape_variations.keys().cloned())
731 .collect::<BTreeSet<_>>()
732 .into_iter()
733 .collect()
734 }
735
736 fn shape_inputs(&self) -> Vec<(String, &'a Hist1D)> {
737 let mut histograms = Vec::new();
738 for channel in &self.channels {
739 histograms.push((shape_name(&channel.name, &self.process), channel.nominal));
740 histograms.push((shape_name(&channel.name, "data_obs"), channel.data_obs));
741 for (systematic, variation) in &channel.shape_variations {
742 histograms.push((
743 shape_name(&channel.name, &format!("{}_{systematic}Up", self.process)),
744 variation.up,
745 ));
746 histograms.push((
747 shape_name(&channel.name, &format!("{}_{systematic}Down", self.process)),
748 variation.down,
749 ));
750 }
751 }
752 histograms
753 }
754}
755
756fn validate_shapes_file(shapes_file: &str) -> Result<()> {
757 if shapes_file.trim().is_empty() || shapes_file.chars().any(char::is_whitespace) {
758 return Err(RootError::other(
759 "Combine shapes file name must be non-empty and contain no whitespace",
760 ));
761 }
762 Ok(())
763}
764
765fn validate_flat_systematic(systematic: &FlatWeightSystematic) -> Result<()> {
766 validate_label("flat weight systematic", &systematic.name)?;
767 if !(systematic.up.is_finite() && systematic.down.is_finite()) {
768 return Err(RootError::other(format!(
769 "flat weight systematic `{}` has non-finite up/down factor",
770 systematic.name
771 )));
772 }
773 if systematic.up <= 0.0 || systematic.down <= 0.0 {
774 return Err(RootError::other(format!(
775 "flat weight systematic `{}` must have positive up/down factors",
776 systematic.name
777 )));
778 }
779 Ok(())
780}
781
782fn validate_label(kind: &str, value: &str) -> Result<()> {
783 if value.is_empty()
784 || value.chars().any(char::is_whitespace)
785 || value.contains('/')
786 || value.contains('$')
787 {
788 return Err(RootError::other(format!(
789 "Combine {kind} `{value}` must be non-empty and contain no whitespace, `/`, or `$`"
790 )));
791 }
792 Ok(())
793}
794
795fn validate_compatible_histograms(reference: &Hist1D, other: &Hist1D, context: &str) -> Result<()> {
796 if reference.nbins() != other.nbins()
797 || reference.low() != other.low()
798 || reference.high() != other.high()
799 {
800 return Err(RootError::other(format!(
801 "histogram `{context}` has binning incompatible with the channel nominal histogram"
802 )));
803 }
804 Ok(())
805}
806
807fn rate(hist: &Hist1D) -> f64 {
808 hist.bins().iter().sum()
809}
810
/// Builds the shapes-file key `<channel>/<process>`, which matches the datacard's
/// `$CHANNEL/$PROCESS` pattern.
fn shape_name(channel: &str, process: &str) -> String {
    [channel, process].join("/")
}
814
815fn format_lnn(systematic: &FlatWeightSystematic) -> String {
816 format!(
817 "{}/{}",
818 format_rate(systematic.down),
819 format_rate(systematic.up)
820 )
821}
822
/// Formats a number for the datacard.
///
/// The value is rounded to 12 decimal places, then trailing zeros and any dangling
/// decimal point are removed. Both `0.0` and `-0.0` become `"0"`.
fn format_rate(value: f64) -> String {
    match value {
        v if v == 0.0 => "0".to_owned(),
        v => {
            let fixed = format!("{v:.12}");
            fixed.trim_end_matches('0').trim_end_matches('.').to_owned()
        }
    }
}
833
/// Returns `value` repeated `count` times, separated by single spaces
/// (empty when `count` is 0).
fn repeated(value: &str, count: usize) -> String {
    vec![value; count].join(" ")
}
837
/// Joins string-like items with single spaces into one datacard row fragment.
///
/// Appends into a single buffer, so no intermediate `String` is allocated per item.
fn join<T: AsRef<str>>(parts: impl IntoIterator<Item = T>) -> String {
    let mut joined = String::new();
    for (position, part) in parts.into_iter().enumerate() {
        if position > 0 {
            joined.push(' ');
        }
        joined.push_str(part.as_ref());
    }
    joined
}