Skip to content

API Reference

Sicifus

sicifus.Sicifus

Main API for Sicifus.

Source code in src/sicifus/api.py
  18
  19
  20
  21
  22
  23
  24
  25
  26
  27
  28
  29
  30
  31
  32
  33
  34
  35
  36
  37
  38
  39
  40
  41
  42
  43
  44
  45
  46
  47
  48
  49
  50
  51
  52
  53
  54
  55
  56
  57
  58
  59
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
class Sicifus:
    """
    Main API for Sicifus.
    """

    def __init__(self, db_path: str = "sicifus_db", xtb_work_dir: str = "./xtb_work"):
        """Set up database paths, helper components, and empty caches.

        Args:
            db_path: Root directory of the parquet database.
            xtb_work_dir: Scratch directory used by the XTB scorer.
        """
        db = Path(db_path)
        self.db_path = db
        self.backbone_path = db / "backbone"
        self.heavy_atoms_path = db / "heavy_atoms"
        self.hydrogens_path = db / "hydrogens"
        self.ligands_path = db / "ligands"
        self.metadata_path = db / "metadata"

        # Older databases stored everything under a single "all_atom" partition.
        self.legacy_all_atom_path = db / "all_atom"

        # Helper components used by the public API methods.
        self.loader = CIFLoader()
        self.aligner = StructuralAligner()
        self.toolkit = AnalysisToolkit()
        self.ligand_analyzer = LigandAnalyzer()
        self.xtb_scorer = XTBScorer(work_dir=xtb_work_dir)
        self.mutation_engine = MutationEngine(work_dir=str(db / "mutate_work"))

        # Lazily-scanned parquet frames; populated by load().
        self._backbone_lf: Optional[pl.LazyFrame] = None
        self._heavy_atoms_lf: Optional[pl.LazyFrame] = None
        self._hydrogens_lf: Optional[pl.LazyFrame] = None
        self._ligands_lf: Optional[pl.LazyFrame] = None
        self._metadata_lfs: Dict[str, pl.LazyFrame] = {}

        # Cached tree/cluster state produced by generate_tree()/cluster().
        self._linkage: Optional[np.ndarray] = None
        self._tree_labels: Optional[List[str]] = None
        self._rmsd_matrix: Optional[np.ndarray] = None
        self._clusters: Optional[pl.DataFrame] = None

    def ingest(self, input_folder: str, batch_size: int = 100,
               file_extension: str = "cif", protonate: bool = False):
        """
        Ingests structure files from a folder into the database.

        Args:
            input_folder: Folder containing structure files.
            batch_size: Number of files per parquet partition.
            file_extension: File extension to look for (e.g., "cif", "pdb").
            protonate: If True, uses PDBFixer (OpenMM) to add hydrogens to the
                       structure before parsing, giving consistent protonation
                       for energy calculations.
        """
        print(f"Ingesting {file_extension} files from {input_folder} to {self.db_path}...")
        if protonate:
            print("  Protonation enabled (PDBFixer). This may take longer.")

        self.loader.ingest_folder(
            input_folder, str(self.db_path), batch_size, file_extension,
            protonate=protonate,
        )
        # Refresh the lazy frames so the new data is immediately queryable.
        self.load()

    def load(self):
        """Loads the database (lazy): attach parquet scans for every partition that exists."""
        def _scan(folder: Path) -> pl.LazyFrame:
            # Glob-scan every parquet file in the partition directory.
            return pl.scan_parquet(str(folder / "*.parquet"))

        if self.backbone_path.exists():
            self._backbone_lf = _scan(self.backbone_path)

        # Prefer the split heavy-atom partition; fall back to legacy all_atom.
        if self.heavy_atoms_path.exists():
            self._heavy_atoms_lf = _scan(self.heavy_atoms_path)
        elif self.legacy_all_atom_path.exists():
            self._heavy_atoms_lf = _scan(self.legacy_all_atom_path)

        if self.hydrogens_path.exists():
            self._hydrogens_lf = _scan(self.hydrogens_path)

        if self.ligands_path.exists():
            self._ligands_lf = _scan(self.ligands_path)

        if self.metadata_path.exists():
            # Each metadata parquet file becomes its own named lazy frame.
            for pq in self.metadata_path.glob("*.parquet"):
                self._metadata_lfs[pq.stem] = pl.scan_parquet(str(pq))

    @property
    def backbone(self) -> pl.LazyFrame:
        """Backbone atoms as a LazyFrame; loads the database on first access."""
        if self._backbone_lf is None:
            self.load()
            if self._backbone_lf is None:
                raise ValueError("No backbone data found. Run ingest() first.")
        return self._backbone_lf

    @property
    def all_atom(self) -> pl.LazyFrame:
        """
        Returns protein heavy atoms (sidechains included).
        Hydrogens are excluded by default for performance, unless using legacy data.
        """
        if self._heavy_atoms_lf is None:
            self.load()
            if self._heavy_atoms_lf is None:
                raise ValueError(
                    "No heavy atom data found. Re-ingest your structures."
                )
        return self._heavy_atoms_lf

    @property
    def hydrogens(self) -> Optional[pl.LazyFrame]:
        """Returns protein hydrogens (if available).

        Unlike the other data properties, this does not raise when the
        partition is missing: legacy or non-protonated databases simply have
        no hydrogen data, so the result may be None.
        """
        if self._hydrogens_lf is None:
            self.load()
        # It's okay if this is None (e.g. legacy data or no protonation)
        return self._hydrogens_lf

    @property
    def ligands(self) -> pl.LazyFrame:
        """Ligand atoms as a LazyFrame; loads the database on first access."""
        if self._ligands_lf is None:
            self.load()
            if self._ligands_lf is None:
                raise ValueError("No ligand data found. Run ingest() first.")
        return self._ligands_lf

    def get_structure(self, structure_id: str) -> pl.DataFrame:
        """Retrieves a specific structure as a DataFrame."""
        wanted = pl.col("structure_id") == structure_id
        return self.backbone.filter(wanted).collect()

    def get_all_atoms(self, structure_id: str) -> pl.DataFrame:
        """Retrieves ALL protein atoms (including sidechains) for a structure."""
        wanted = pl.col("structure_id") == structure_id
        return self.all_atom.filter(wanted).collect()

    def get_ligands(self, structure_id: str) -> pl.DataFrame:
        """Retrieves ligands for a specific structure."""
        wanted = pl.col("structure_id") == structure_id
        return self.ligands.filter(wanted).collect()

    # ── Metadata ─────────────────────────────────────────────────────────

    def load_metadata(self, path: str, name: Optional[str] = None,
                      id_column: str = "id") -> pl.DataFrame:
        """
        Loads external metadata (CSV) and stores it in the database as parquet.
        The metadata is joined to structures via structure_id.

        Supports:
          - A single CSV file with an id column matching structure IDs.
          - A directory of CSVs — all are concatenated.

        Args:
            path: Path to a CSV file or a directory of CSVs.
            name: Name for this metadata source (used for storage and lookup).
                  Defaults to the filename stem with "." and "-" replaced by "_".
            id_column: Name of the column in the CSV that contains structure IDs.
                       Defaults to "id".

        Returns:
            The loaded metadata as a Polars DataFrame.
        """
        src = Path(path).expanduser()

        if src.is_file():
            meta_df = pl.read_csv(str(src))
            if name is None:
                name = src.stem.replace(".", "_").replace("-", "_")
        elif src.is_dir():
            csv_files = list(src.rglob("*.csv"))
            if not csv_files:
                raise FileNotFoundError(f"No CSV files found in {src}")
            # Diagonal concat tolerates differing column sets across files.
            meta_df = pl.concat([pl.read_csv(str(f)) for f in csv_files], how="diagonal")
            if name is None:
                name = src.name.replace(".", "_").replace("-", "_")
        else:
            raise FileNotFoundError(f"Path not found: {src}")

        # Normalize the id column name to structure_id for joining.
        if id_column in meta_df.columns and id_column != "structure_id":
            meta_df = meta_df.rename({id_column: "structure_id"})
        elif "structure_id" not in meta_df.columns:
            raise ValueError(
                f"Column '{id_column}' not found in CSV. "
                f"Available columns: {meta_df.columns}. "
                f"Set id_column= to the column containing structure IDs."
            )

        # Persist as parquet inside the database.
        self.metadata_path.mkdir(parents=True, exist_ok=True)
        target = self.metadata_path / f"{name}.parquet"
        meta_df.write_parquet(str(target))

        # Cache the lazy frame for subsequent .meta access.
        self._metadata_lfs[name] = pl.scan_parquet(str(target))

        n_rows = meta_df.height
        n_cols = len(meta_df.columns) - 1  # minus structure_id
        matched = 0
        if self._backbone_lf is not None:
            # Report how many metadata rows correspond to ingested structures.
            known_ids = self.backbone.select("structure_id").unique().collect().to_series()
            matched = meta_df.filter(pl.col("structure_id").is_in(known_ids)).height

        print(f"Loaded metadata '{name}': {n_rows} rows, {n_cols} columns")
        if matched > 0:
            print(f"  {matched}/{n_rows} rows match structures in the database")

        return meta_df

    @property
    def meta(self) -> pl.LazyFrame:
        """
        Returns all loaded metadata joined into a single LazyFrame on structure_id.
        If multiple metadata sources are loaded, they are joined together.
        """
        if not self._metadata_lfs:
            self.load()
        if not self._metadata_lfs:
            raise ValueError("No metadata loaded. Use load_metadata() first.")

        sources = list(self._metadata_lfs.values())

        # Full (outer) join so rows present in only one source are kept;
        # coalesce merges the duplicated structure_id key columns.
        combined = sources[0]
        for extra in sources[1:]:
            combined = combined.join(extra, on="structure_id", how="full", coalesce=True)
        return combined

    def meta_columns(self) -> List[str]:
        """Lists all available metadata columns (across all loaded sources)."""
        names = {
            col
            for lf in self._metadata_lfs.values()
            for col in lf.columns
            if col != "structure_id"
        }
        return sorted(names)

    def hist(self, column: str, bins: int = 30, title: Optional[str] = None,
             output_file: Optional[str] = None, **kwargs):
        """
        Plots a histogram of any metadata column.

        If cluster annotations exist, you can pass color_by="cluster" to 
        color the histogram by cluster assignment.

        Args:
            column: Column name from the metadata (e.g. "radius_of_gyration").
            bins: Number of histogram bins.
            title: Plot title. Defaults to the column name.
            output_file: Save to file instead of showing.
            **kwargs: Extra kwargs passed to matplotlib hist().

        Examples:
            db.hist("radius_of_gyration")
            db.hist("protein_length", bins=50)
        """
        # Collect the column from metadata
        df = self.meta.select(["structure_id", column]).collect().drop_nulls(column)

        if df.height == 0:
            print(f"No data found for column '{column}'. Available columns:")
            print(f"  {self.meta_columns()}")
            return

        values = df.get_column(column).to_numpy()

        fig, ax = plt.subplots(figsize=(10, 6))

        color_by = kwargs.pop("color_by", None)

        if color_by == "cluster" and self._clusters is not None:
            # Join with cluster assignments
            joined = df.join(self._clusters, on="structure_id", how="left")
            cluster_col = joined.get_column("cluster")
            unique_clusters = sorted(cluster_col.drop_nulls().unique().to_list())

            n_clust = len(unique_clusters)
            # plt.cm.get_cmap was deprecated in matplotlib 3.7 and removed in
            # 3.9; the colormaps registry + resampled() is the supported API.
            import matplotlib
            cmap = matplotlib.colormaps["tab20" if n_clust <= 20 else "hsv"].resampled(n_clust)

            for i, cid in enumerate(unique_clusters):
                cluster_vals = joined.filter(pl.col("cluster") == cid).get_column(column).to_numpy()
                # Spread colors evenly over the colormap (guard n_clust == 1).
                color = cmap(i / max(n_clust - 1, 1))
                ax.hist(cluster_vals, bins=bins, alpha=0.6, color=color, 
                        label=f"Cluster {cid}", **kwargs)
            ax.legend(fontsize=7, ncol=2)
        else:
            ax.hist(values, bins=bins, edgecolor='black', alpha=0.8, **kwargs)

        ax.set_xlabel(column, fontsize=11)
        ax.set_ylabel("Count", fontsize=11)
        ax.set_title(title or column, fontsize=13)
        plt.tight_layout()

        if output_file:
            plt.savefig(output_file, dpi=150, bbox_inches='tight')
            plt.close()
        else:
            plt.show()

    def scatter(self, x: str, y: str, title: Optional[str] = None,
                output_file: Optional[str] = None, **kwargs):
        """
        Scatter plot of two metadata columns.

        Args:
            x: Column name for x-axis.
            y: Column name for y-axis.
            title: Plot title.
            output_file: Save to file instead of showing.
            **kwargs: Extra kwargs passed to matplotlib scatter().

        Examples:
            db.scatter("protein_length", "radius_of_gyration")
        """
        data = self.meta.select(["structure_id", x, y]).collect().drop_nulls([x, y])

        if data.height == 0:
            print(f"No data found. Available columns: {self.meta_columns()}")
            return

        fig, ax = plt.subplots(figsize=(10, 6))

        color_by = kwargs.pop("color_by", None)
        use_clusters = color_by == "cluster" and self._clusters is not None

        if use_clusters:
            # Color each point by its cluster assignment.
            joined = data.join(self._clusters, on="structure_id", how="left")
            cluster_values = joined.get_column("cluster").to_numpy().astype(float)
            points = ax.scatter(joined.get_column(x).to_numpy(),
                                joined.get_column(y).to_numpy(),
                                c=cluster_values, cmap="tab20", s=10, alpha=0.7, **kwargs)
            plt.colorbar(points, label="Cluster")
        else:
            ax.scatter(data.get_column(x).to_numpy(),
                       data.get_column(y).to_numpy(), s=10, alpha=0.7, **kwargs)

        ax.set_xlabel(x, fontsize=11)
        ax.set_ylabel(y, fontsize=11)
        ax.set_title(title or f"{y} vs {x}", fontsize=13)
        plt.tight_layout()

        if output_file:
            plt.savefig(output_file, dpi=150, bbox_inches='tight')
            plt.close()
        else:
            plt.show()

    def align_all(self, reference_id: str, target_ids: Optional[List[str]] = None) -> pl.DataFrame:
        """
        Aligns all (or specified) structures to a reference structure.
        Returns a DataFrame with RMSD and alignment stats
        (columns: structure_id, reference_id, rmsd, aligned_residues).
        Structures that fail to align are reported and skipped.
        """
        ref_df = self.get_structure(reference_id)
        if ref_df.height == 0:
            raise ValueError(f"Reference structure {reference_id} not found.")

        if target_ids is None:
            # Get all unique IDs (this might be expensive for massive DB, better to use metadata)
            target_ids = self.backbone.select("structure_id").unique().collect().to_series().to_list()

        # Exclude the reference without mutating the caller's list in place.
        targets = [tid for tid in target_ids if tid != reference_id]

        results = []
        print(f"Aligning {len(targets)} structures to {reference_id}...")

        for tid in targets:
            target_df = self.get_structure(tid)
            if target_df.height > 0:
                try:
                    rmsd, n_aligned = self.aligner.align_and_superimpose(target_df, ref_df)
                    results.append({
                        "structure_id": tid,
                        "reference_id": reference_id,
                        "rmsd": rmsd,
                        "aligned_residues": n_aligned
                    })
                except Exception as e:
                    # Best-effort: report the failure and continue with the rest.
                    print(f"Failed to align {tid}: {e}")

        return pl.DataFrame(results)

    def get_aligned_structure(self, structure_id: str, reference_id: str) -> pl.DataFrame:
        """
        Returns the structure transformed to align with the reference.
        """
        mobile_df = self.get_structure(structure_id)
        ref_df = self.get_structure(reference_id)

        # Both structures must exist in the database.
        if mobile_df.height == 0 or ref_df.height == 0:
            raise ValueError("Structure not found.")

        aligned_df, rmsd = self.aligner.align_and_transform(mobile_df, ref_df)
        print(f"Aligned {structure_id} to {reference_id} with RMSD: {rmsd:.2f}")
        return aligned_df

    def generate_tree(self, structure_ids: Optional[List[str]] = None, output_file: Optional[str] = None, 
                      root_id: Optional[str] = None, newick_file: Optional[str] = None,
                      pruning_threshold: Optional[float] = None,
                      layout: str = "circular"):
        """
        Generates a structural phylogenetic tree. Unrooted by default.
        Branch lengths are RMSD values.

        This is the expensive step (O(N^2) alignments). After this, use tree_stats() 
        to inspect branch lengths, then annotate_clusters() to assign clusters cheaply.

        Args:
            structure_ids: List of structure IDs. If None, uses all structures (warning: O(N^2)).
            output_file: Save the tree plot to this file (e.g. "tree.png").
            root_id: Root the tree at this structure ID. If None, tree is unrooted.
            newick_file: Export to Newick format for iTOL or similar tools.
            pruning_threshold: Skip alignment for structurally dissimilar pairs (0.0-1.0).
            layout: Tree layout for the plot: "circular" (default, unrooted radial) or "rectangular".

        Returns:
            Biopython Tree object.
        """
        # NOTE: dead timing locals (t_total/t0 from time.perf_counter(), never
        # read) were removed here.
        if structure_ids is None:
            structure_ids = self.backbone.select("structure_id").unique().collect().to_series().to_list()

        if len(structure_ids) > 100:
            print(f"Warning: Generating tree for {len(structure_ids)} structures. This involves O(N^2) alignments.")

        # Load ALL backbone data in ONE scan, then group by structure_id
        all_data = self.backbone.filter(
            pl.col("structure_id").is_in(structure_ids)
        ).collect()

        structures = {}
        for sid, group in all_data.group_by("structure_id"):
            # group_by may yield the key as a 1-tuple depending on polars version.
            structures[sid[0] if isinstance(sid, tuple) else sid] = group

        matrix, labels = self.toolkit.compute_rmsd_matrix(structures, pruning_threshold=pruning_threshold)

        # Build the linkage matrix (fast, C-based) — needed for circular plot
        Z = self.toolkit.build_tree(matrix, labels)

        # Cache for later use (annotate_clusters / re-plotting)
        self._linkage = Z
        self._tree_labels = labels
        self._rmsd_matrix = matrix

        # Build the Biopython tree (fast — needed for Newick and clustering)
        tree_obj = self.toolkit.build_phylo_tree(matrix, labels, root_id)
        self._tree = tree_obj

        # Write Newick
        if newick_file:
            from Bio import Phylo
            Phylo.write(tree_obj, newick_file, "newick")

        # Plot the tree
        if output_file:
            if layout == "circular":
                self.toolkit.plot_circular_tree(
                    Z, labels, 
                    cluster_df=self._clusters,
                    output_file=output_file
                )
            else:
                self.toolkit.plot_tree(tree_obj, output_file=output_file)

        return tree_obj

    def cluster(self, structure_ids: Optional[List[str]] = None,
                distance_threshold: float = 2.0,
                coverage_threshold: float = 0.8,
                output_file: Optional[str] = None) -> pl.DataFrame:
        """Fast greedy structural clustering (no full tree required).

        Uses a 3Di k-mer prefilter to rapidly identify candidate centroids,
        then only computes RMSD for those candidates.  Much faster than
        building a full phylogenetic tree for large datasets.

        Args:
            structure_ids: Structures to cluster. If None, uses all.
            distance_threshold: Max RMSD (Å) for assigning to a centroid.
            coverage_threshold: Min length-ratio for comparing two structures.
            output_file: Save a summary bar-chart of cluster sizes.

        Returns:
            Polars DataFrame with columns
            ``[structure_id, cluster, centroid_id, rmsd_to_centroid]``.
        """
        t0 = time.perf_counter()

        if structure_ids is None:
            structure_ids = (
                self.backbone.select("structure_id")
                .unique().collect().to_series().to_list()
            )

        all_data = self.backbone.filter(
            pl.col("structure_id").is_in(structure_ids)
        ).collect()

        structures = {}
        for sid, group in all_data.group_by("structure_id"):
            # group_by may yield the key as a 1-tuple depending on polars version.
            structures[sid[0] if isinstance(sid, tuple) else sid] = group

        df = self.toolkit.cluster_fast(
            structures,
            distance_threshold=distance_threshold,
            coverage_threshold=coverage_threshold,
        )

        # Cache assignments so hist()/scatter() can color by cluster.
        self._clusters = df.select(["structure_id", "cluster"])

        elapsed = time.perf_counter() - t0
        print(f"Clustering completed in {elapsed:.1f}s")

        if output_file:
            # NOTE: the redundant function-scope matplotlib import was removed;
            # plt is already used at module scope (see hist/scatter).
            sizes = (
                df.group_by("cluster")
                .agg(pl.col("structure_id").count().alias("size"))
                .sort("size", descending=True)
            )
            fig, ax = plt.subplots(figsize=(10, 5))
            ax.bar(range(sizes.height), sizes["size"].to_list(), edgecolor="black", alpha=0.8)
            ax.set_xlabel("Cluster (sorted by size)")
            ax.set_ylabel("Members")
            # Mojibake fix: the threshold unit is the angstrom sign "Å"
            # (was double-encoded as "Ã…").
            ax.set_title(f"Fast Clustering — {df['cluster'].n_unique()} clusters "
                         f"(threshold={distance_threshold} Å)")
            plt.tight_layout()
            plt.savefig(output_file, dpi=150, bbox_inches="tight")
            plt.close()

        return df

    def annotate_clusters(self, distance_threshold: float, output_file: Optional[str] = None,
                          layout: str = "circular") -> pl.DataFrame:
        """
        Annotates the tree with cluster labels by cutting branches whose RMSD
        exceeds distance_threshold. Each resulting subtree becomes a cluster.

        This is cheap and instant — run it multiple times with different thresholds
        after generate_tree() to explore coarse vs fine clustering.

        Use tree_stats() first to see the branch length distribution and pick
        a meaningful threshold.

        Args:
            distance_threshold: Cut branches longer than this RMSD value.
                               e.g. 1.0 = subtrees separated by > 1 Å RMSD become different clusters.
            output_file: Optionally re-plot the tree with cluster colors.
            layout: "circular" (default) or "rectangular".

        Returns:
            Polars DataFrame with columns: structure_id, cluster
        """
        if not hasattr(self, '_tree') or self._tree is None:
            raise ValueError("No tree available. Run generate_tree() first.")

        assignments = self.toolkit.cluster_from_tree(self._tree, distance_threshold)
        self._clusters = assignments

        n_clust = assignments["cluster"].n_unique()
        n_structs = assignments.height

        # Break down singleton vs multi-member clusters for the report.
        sizes = assignments.group_by("cluster").agg(pl.col("structure_id").count().alias("size"))
        n_singletons = sizes.filter(pl.col("size") == 1).height
        n_multi = n_clust - n_singletons
        n_in_multi = n_structs - n_singletons

        print(f"Annotated {n_structs} structures into {n_clust} clusters (threshold={distance_threshold})")
        print(f"  {n_multi} clusters with 2+ members ({n_in_multi} structures)")
        print(f"  {n_singletons} singletons (distinct outliers)")

        # Optionally re-plot with cluster colors.
        if output_file:
            if layout == "circular":
                self.toolkit.plot_circular_tree(
                    self._linkage, self._tree_labels,
                    cluster_df=assignments,
                    output_file=output_file
                )
            else:
                self.toolkit.plot_tree(self._tree, output_file=output_file)

        return assignments

    # ── Tree inspection ───────────────────────────────────────────────────

    def tree_branch_lengths(self) -> np.ndarray:
        """
        Returns all branch lengths from the tree. Use this to understand the
        distribution and pick a meaningful distance_threshold for clustering.

        Example:
            bls = db.tree_branch_lengths()
            print(f"min={bls.min():.2f}, median={np.median(bls):.2f}, max={bls.max():.2f}")
            # Then pick a threshold that makes biological sense
        """
        if getattr(self, '_tree', None) is None:
            raise ValueError("No tree available. Run generate_tree() first.")

        # Keep only real (positive, non-None) branch lengths.
        return np.array([
            clade.branch_length
            for clade in self._tree.find_clades()
            if clade.branch_length is not None and clade.branch_length > 0
        ])

    def tree_stats(self):
        """
        Print summary statistics (min / percentiles / max / mean) of the
        tree's branch lengths (RMSD) to guide the choice of a good
        distance_threshold for clustering.
        """
        lengths = self.tree_branch_lengths()
        pct_points = [10, 25, 50, 75, 90, 95, 99]
        pct_values = np.percentile(lengths, pct_points)

        print(f"Tree branch length statistics ({len(lengths)} branches):")
        print(f"  min:    {lengths.min():.4f}")
        for point, value in zip(pct_points, pct_values):
            print(f"  {point:3d}th:  {value:.4f}")
        print(f"  max:    {lengths.max():.4f}")
        print(f"  mean:   {lengths.mean():.4f}")
        print()
        print("Tip: set distance_threshold to a percentile value.")
        print("  Higher threshold → fewer, larger clusters (coarser)")
        print("  Lower threshold  → more, smaller clusters (finer)")

    # ── Cluster querying API ──────────────────────────────────────────────

    @property
    def clusters(self) -> pl.DataFrame:
        """
        Cluster assignments DataFrame (structure_id, cluster), populated by
        annotate_clusters().

        Raises:
            ValueError: if annotate_clusters() has not been run yet.
        """
        if self._clusters is not None:
            return self._clusters
        raise ValueError(
            "No cluster assignments available. "
            "Run annotate_clusters(distance_threshold=...) first."
        )

    def get_cluster(self, cluster_id: int) -> List[str]:
        """
        Look up all structure IDs assigned to one cluster.

        Args:
            cluster_id: The cluster number to query.

        Returns:
            List of structure IDs in that cluster.
        """
        members = self.clusters.filter(pl.col("cluster") == cluster_id)
        return members.get_column("structure_id").to_list()

    def get_cluster_for(self, structure_id: str) -> int:
        """
        Look up which cluster a structure was assigned to.

        Args:
            structure_id: The structure to look up.

        Returns:
            The cluster number.

        Raises:
            ValueError: if the structure has no cluster assignment.
        """
        matches = (
            self.clusters
            .filter(pl.col("structure_id") == structure_id)
            .get_column("cluster")
            .to_list()
        )
        if matches:
            return matches[0]
        raise ValueError(f"Structure {structure_id} not found in cluster assignments.")

    def get_cluster_siblings(self, structure_id: str) -> List[str]:
        """
        List every structure that shares a cluster with the given one.

        Args:
            structure_id: The reference structure.

        Returns:
            List of structure IDs in the same cluster (the reference included).
        """
        return self.get_cluster(self.get_cluster_for(structure_id))

    def cluster_summary(self) -> pl.DataFrame:
        """
        Summarize all clusters: cluster ID, member count, and member IDs,
        sorted by cluster number.
        """
        member_col = pl.col("structure_id")
        summary = self.clusters.group_by("cluster").agg([
            member_col.count().alias("count"),
            member_col.alias("members"),
        ])
        return summary.sort("cluster")

    def get_clustered_ids(self, min_size: int = 2) -> List[str]:
        """
        Structure IDs belonging to clusters with at least min_size members.
        Filters out singletons (or small clusters) — the outliers on long branches.

        Use this to rebuild a pruned tree with only related structures:
            related = db.get_clustered_ids(min_size=2)
            db.generate_tree(structure_ids=related, ...)

        Args:
            min_size: Minimum cluster size to include (default 2 = drop singletons).

        Returns:
            List of structure IDs.
        """
        # Clusters that pass the size cutoff.
        keep = (
            self.clusters
            .group_by("cluster")
            .agg(pl.col("structure_id").count().alias("size"))
            .filter(pl.col("size") >= min_size)
            .get_column("cluster")
            .to_list()
        )
        # All members of those clusters.
        return (
            self.clusters
            .filter(pl.col("cluster").is_in(keep))
            .get_column("structure_id")
            .to_list()
        )

    def analyze_ligand_binding(self, ligand_name: str, structure_ids: Optional[List[str]] = None,
                               output_file: Optional[str] = None):
        """
        Collect binding residues for a specific ligand across structures and
        plot a histogram of the binding residue types.
        """
        # Default to every structure containing this ligand.
        if structure_ids is None:
            structure_ids = self.ligands.filter(pl.col("residue_name") == ligand_name)\
                                        .select("structure_id").unique().collect().to_series().to_list()

        print(f"Analyzing {ligand_name} binding in {len(structure_ids)} structures...")

        residues = []
        for sid in structure_ids:
            binding_df = self.ligand_analyzer.find_binding_residues(
                self.get_structure(sid), self.get_ligands(sid), ligand_name
            )
            if binding_df.height > 0:
                residues += binding_df["residue_name"].to_list()

        if not residues:
            print("No binding residues found.")
            return
        self.ligand_analyzer.plot_binding_histogram(
            residues,
            title=f"{ligand_name} — Binding Residue Distribution ({len(structure_ids)} structures)",
            output_file=output_file,
        )

    def analyze_pi_stacking(self, ligand_name: str, structure_ids: Optional[List[str]] = None,
                            output_file: Optional[str] = None,
                            charge: Optional[int] = None,
                            infer_bond_orders: bool = True) -> pl.DataFrame:
        """
        Detects pi-stacking interactions (sandwich, parallel displaced, T-shaped)
        between protein aromatic residues and aromatic rings in the specified ligand.

        Requires all_atom data (re-ingest if you only have backbone/CA).

        Produces a grouped bar chart of interaction types and residue breakdown.
        Returns a DataFrame of all detected interactions.

        The ``ligand_ring_atoms`` column uses the same canonical atom labels
        as ``analyze_ligand_contacts`` and the 2D depiction, so atom names
        are consistent across all analyses.

        Args:
            ligand_name: Three-letter ligand code (e.g. "GLC", "ATP").
            structure_ids: Optional list of structures to analyze. Defaults to all
                           structures containing the ligand.
            output_file: Save the plot instead of displaying.
            charge: Total formal charge of the ligand (passed to build_ligand_mol).
            infer_bond_orders: Whether to infer double/aromatic bonds (passed to
                    build_ligand_mol).

        Returns:
            DataFrame of detected interactions (one row each, tagged with a
            ``structure_id`` column); empty DataFrame when all_atom data is
            missing or nothing is detected.

        Example:
            pi_df = db.analyze_pi_stacking("ATP")
        """
        # Default to every structure that contains this ligand.
        if structure_ids is None:
            structure_ids = self.ligands.filter(pl.col("residue_name") == ligand_name)\
                                        .select("structure_id").unique().collect().to_series().to_list()

        print(f"Detecting pi-stacking with {ligand_name} in {len(structure_ids)} structures...")

        all_interactions = []
        # Distinct molecular graphs seen across structures; more than one means
        # canonical atom labels may not be comparable between structures.
        unique_smiles = set()

        for sid in structure_ids:
            try:
                all_atoms = self.get_all_atoms(sid)
            except ValueError:
                # Abort the whole analysis on missing all_atom data (same
                # behavior as analyze_ligand_contacts).
                print("  No all_atom data available. Re-ingest structures to enable pi-stacking analysis.")
                return pl.DataFrame()
            ligands = self.get_ligands(sid)

            pi_df = self.ligand_analyzer.detect_pi_stacking(all_atoms, ligands, ligand_name)

            if pi_df.height > 0:
                # ── Per-structure Canonical Mapping ──────────────────────────
                # Remap ring atoms to canonical labels for this specific structure
                lig_for_struct = ligands.filter(pl.col("residue_name") == ligand_name)

                _, pdb_names_i, canonical_labels_i, smi_i = \
                    self.ligand_analyzer.build_ligand_mol(
                        lig_for_struct,
                        charge=charge,
                        infer_bond_orders=infer_bond_orders,
                    )

                if smi_i:
                    unique_smiles.add(smi_i)

                if pdb_names_i and canonical_labels_i:
                    pdb_to_canonical = dict(zip(pdb_names_i, canonical_labels_i))

                    # "ligand_ring_atoms" holds a comma-separated list of PDB
                    # atom names; names without a canonical match are kept as-is.
                    def _remap_ring_atoms(pdb_csv: str) -> str:
                        return ",".join(
                            pdb_to_canonical.get(a.strip(), a.strip())
                            for a in pdb_csv.split(",")
                        )

                    remapped = [
                        _remap_ring_atoms(v)
                        for v in pi_df.get_column("ligand_ring_atoms").to_list()
                    ]
                    pi_df = pi_df.with_columns(
                        pl.Series("ligand_ring_atoms", remapped)
                    )

                # Tag rows with their source structure before pooling.
                pi_df = pi_df.with_columns(pl.lit(sid).alias("structure_id"))
                all_interactions.append(pi_df)

        if len(unique_smiles) > 1:
            print(f"Warning: Found {len(unique_smiles)} different molecular graphs (SMILES) for {ligand_name}.")
            print("  Canonical labeling might be inconsistent if structures have different connectivity/protonation.")

        if not all_interactions:
            print("No pi-stacking interactions detected.")
            return pl.DataFrame()

        result = pl.concat(all_interactions)

        # Per-type breakdown for the console summary.
        n_sandwich = result.filter(pl.col("interaction_type") == "sandwich").height
        n_parallel = result.filter(pl.col("interaction_type") == "parallel_displaced").height
        n_tshaped = result.filter(pl.col("interaction_type") == "t_shaped").height
        print(f"  Found {result.height} interactions: "
              f"{n_sandwich} sandwich, {n_parallel} parallel displaced, {n_tshaped} T-shaped")

        # Plot
        self.ligand_analyzer.plot_pi_stacking(
            result.to_dicts(),
            title=f"{ligand_name} — Pi-Stacking Interactions ({len(structure_ids)} structures)",
            output_file=output_file,
        )
        return result

    def analyze_ligand_contacts(self, ligand_name: str, distance_cutoff: float = 3.3,
                                structure_ids: Optional[List[str]] = None,
                                output_file: Optional[str] = None,
                                ligand_2d_file: Optional[str] = None,
                                charge: Optional[int] = None,
                                infer_bond_orders: bool = True) -> pl.DataFrame:
        """
        Identifies atom-level protein-ligand contacts within a distance cutoff.
        Default 3.3 Å is appropriate for hydrogen bonding.

        Requires all_atom data (re-ingest if you only have backbone/CA).

        Produces TWO visualizations:
          1) Stacked bar chart: which ligand atoms form the most contacts.
          2) 2D ligand depiction (RDKit): atoms color-coded by contact count
             (red = many, blue = few) so you can cross-reference the chart.

        Returns a DataFrame of all contacts.

        Args:
            ligand_name: Three-letter ligand code.
            distance_cutoff: Max distance in Å (default 3.3 for H-bonds).
            structure_ids: Optional list of structures.
            output_file: Save the contacts bar chart to file.
            ligand_2d_file: Save the 2D ligand depiction to file.
                            If not set, auto-generates filename from output_file
                            or displays inline.
            charge: Total formal charge of the ligand (e.g. -3 for citrate).
                    Helps RDKit infer correct protonation / bond orders in
                    the 2D depiction.
            infer_bond_orders: If True, RDKit will try to determine double
                    and aromatic bonds from 3D geometry. Set to False if the
                    2D depiction shows incorrect double bonds — it will display
                    connectivity only (all single bonds), which is still useful.

        Example:
            contacts = db.analyze_ligand_contacts("GLC", distance_cutoff=3.3)
            # If the 2D shows wrong double bonds, pass the charge or disable inference:
            contacts = db.analyze_ligand_contacts("CIT", charge=-3)
            contacts = db.analyze_ligand_contacts("LIG", infer_bond_orders=False)
        """
        # Default to every structure containing this ligand.
        if structure_ids is None:
            structure_ids = self.ligands.filter(pl.col("residue_name") == ligand_name)\
                                        .select("structure_id").unique().collect().to_series().to_list()

        # NOTE: "Å" here was previously mojibake ("Ã…") from a mis-decoded
        # UTF-8 file; fixed to match the plot titles below.
        print(f"Analyzing {ligand_name} atom contacts (<{distance_cutoff}Å) in {len(structure_ids)} structures...")

        all_contacts = []
        representative_ligand_atoms = None
        # Distinct molecular graphs seen across structures; >1 means canonical
        # labels may not be comparable between structures (warned below).
        unique_smiles = set()

        # Pre-built mol/labels for the 2D depiction (from the first structure)
        _mol = None
        _canonical_labels = None

        for sid in structure_ids:
            try:
                all_atoms = self.get_all_atoms(sid)
            except ValueError:
                print("  No all_atom data available. Re-ingest structures to enable contact analysis.")
                return pl.DataFrame()
            ligands = self.get_ligands(sid)

            contacts_df = self.ligand_analyzer.find_ligand_atom_contacts(
                all_atoms, ligands, ligand_name, distance_cutoff=distance_cutoff
            )

            if contacts_df.height > 0:
                # ── Per-structure Canonical Mapping ──────────────────────────
                # Compute canonical labels for THIS specific structure instance
                # so that even if PDB atom names differ (e.g. O1 vs O-1),
                # chemically equivalent atoms get the same label (e.g. O4).
                lig_for_struct = ligands.filter(pl.col("residue_name") == ligand_name)

                # Build mol for this structure
                mol_i, pdb_names_i, canonical_labels_i, smi_i = \
                    self.ligand_analyzer.build_ligand_mol(
                        lig_for_struct,
                        charge=charge,
                        infer_bond_orders=infer_bond_orders,
                    )

                if smi_i:
                    unique_smiles.add(smi_i)

                # Save the first valid mol for the 2D plot
                if _mol is None and mol_i is not None:
                    _mol = mol_i
                    _canonical_labels = canonical_labels_i
                    representative_ligand_atoms = lig_for_struct

                # Map PDB names -> Canonical labels for this structure
                if pdb_names_i and canonical_labels_i:
                    pdb_to_canonical = dict(zip(pdb_names_i, canonical_labels_i))

                    # Apply mapping immediately to this structure's contacts
                    mapping_df = pl.DataFrame({
                        "ligand_atom": list(pdb_to_canonical.keys()),
                        "canonical_atom": list(pdb_to_canonical.values()),
                    })
                    contacts_df = contacts_df.join(mapping_df, on="ligand_atom", how="left")
                    contacts_df = contacts_df.with_columns(
                        pl.col("canonical_atom").fill_null(pl.col("ligand_atom"))
                    )
                else:
                    # Fallback if RDKit failed or no atoms
                    contacts_df = contacts_df.with_columns(
                        pl.col("ligand_atom").alias("canonical_atom")
                    )

                contacts_df = contacts_df.with_columns(pl.lit(sid).alias("structure_id"))
                all_contacts.append(contacts_df)

        if len(unique_smiles) > 1:
            print(f"Warning: Found {len(unique_smiles)} different molecular graphs (SMILES) for {ligand_name}.")
            print("  Canonical labeling might be inconsistent if structures have different connectivity/protonation.")

        if not all_contacts:
            print("No contacts found at this cutoff.")
            return pl.DataFrame()

        result = pl.concat(all_contacts)

        # Summary: N/O/S-to-N/O/S pairs are potential hydrogen-bond partners.
        n_hbond_candidates = result.filter(
            pl.col("ligand_element").is_in(["N", "O", "S"]) & 
            pl.col("protein_element").is_in(["N", "O", "S"])
        ).height
        print(f"  Found {result.height} contacts ({n_hbond_candidates} potential H-bonds: N/O/S pairs)")

        # Plot 1: stacked bar chart — uses canonical_atom column
        self.ligand_analyzer.plot_ligand_contacts(
            result,
            title=f"{ligand_name} — Atom Contacts <{distance_cutoff}Å ({len(structure_ids)} structures)",
            output_file=output_file,
        )

        # Plot 2: 2D ligand depiction — reuses the representative mol + canonical
        # labels so atom numbering is identical to the bar chart.
        if representative_ligand_atoms is not None:
            lig2d_out = ligand_2d_file
            if lig2d_out is None and output_file is not None:
                from pathlib import Path as _P
                stem = _P(output_file).stem
                lig2d_out = str(_P(output_file).parent / f"{stem}_ligand2d.png")

            self.ligand_analyzer.plot_ligand_2d(
                representative_ligand_atoms,
                contacts_df=result,
                title=f"{ligand_name} — 2D Structure (colored by contacts)",
                output_file=lig2d_out,
                charge=charge,
                infer_bond_orders=infer_bond_orders,
                prebuilt_mol=_mol,
                prebuilt_canonical_labels=_canonical_labels,
            )
        return result

    def get_binding_pockets(self, ligand_name: str, distance_cutoff: float = 8.0,
                            structure_ids: Optional[List[str]] = None) -> pl.DataFrame:
        """
        Returns a DataFrame where each row is a structure and columns are residue counts
        in the binding pocket (e.g. ALA, TRP, etc.).

        This is useful for filtering structures based on pocket composition.

        Args:
            ligand_name: Three-letter ligand code.
            distance_cutoff: Radius in Angstroms (default 8.0).
            structure_ids: Optional list of structures.

        Returns:
            DataFrame with structure_id and counts for each residue type found.
            Missing residues are filled with 0.
        """
        # Hoisted out of the loop — this was previously re-imported on every
        # iteration.
        from collections import Counter

        if structure_ids is None:
            structure_ids = self.ligands.filter(pl.col("residue_name") == ligand_name)\
                                        .select("structure_id").unique().collect().to_series().to_list()

        # "Å" fixed from mojibake ("Ã…") to match the other console messages.
        print(f"Extracting {ligand_name} binding pockets (<{distance_cutoff}Å) from {len(structure_ids)} structures...")

        pocket_data = []

        for sid in structure_ids:
            try:
                all_atoms = self.get_all_atoms(sid)
            except ValueError:
                # Skip structures without all-atom data
                continue

            ligands = self.get_ligands(sid)

            # Get list of residues in pocket (e.g. ["ALA", "HIS", "ALA"])
            pocket_residues = self.ligand_analyzer.get_pocket_residues(
                all_atoms, ligands, ligand_name, distance_cutoff=distance_cutoff
            )

            if pocket_residues:
                counts = Counter(pocket_residues)
                row = {"structure_id": sid}
                row.update(counts)
                pocket_data.append(row)

        if not pocket_data:
            return pl.DataFrame()

        # Create DataFrame; rows missing a residue type get nulls here.
        df = pl.from_dicts(pocket_data)

        # Ensure all columns are numeric (except structure_id) and fill nulls with 0
        fill_cols = [c for c in df.columns if c != "structure_id"]
        df = df.with_columns([
            pl.col(c).fill_null(0).cast(pl.Int32) for c in fill_cols
        ])

        return df

    def analyze_binding_pocket(self, ligand_name: str, distance_cutoff: float = 8.0,
                               structure_ids: Optional[List[str]] = None,
                               output_file: Optional[str] = None) -> Dict[str, int]:
        """
        Aggregate the amino acid composition of the binding pocket (residues
        within a defined radius of the ligand) across all structures and plot
        a histogram of residue counts (X-axis = 20 amino acids).

        Args:
            ligand_name: Three-letter ligand code.
            distance_cutoff: Radius in Angstroms (default 8.0).
            structure_ids: Optional list of structures.
            output_file: Save the histogram to file.

        Returns:
            Dictionary of residue counts (e.g. {'ALA': 10, 'HIS': 5}).
        """
        # Per-structure pocket composition.
        pockets = self.get_binding_pockets(ligand_name, distance_cutoff, structure_ids)
        if pockets.height == 0:
            print("No pocket residues found.")
            return {}

        # Sum each residue column across structures.
        count_cols = [c for c in pockets.columns if c != "structure_id"]
        totals = pockets.select([pl.col(c).sum() for c in count_cols]).row(0, named=True)

        self.ligand_analyzer.plot_binding_pocket_composition(
            totals,
            title=f"{ligand_name} — Binding Pocket Composition (<{distance_cutoff}Å, {pockets.height} structures)",
            output_file=output_file
        )
        return totals

    def calculate_relative_energy(self, df: pl.DataFrame, group_by: Optional[str] = None) -> pl.DataFrame:
        """
        Convert total energies (Hartree) into relative energies (kcal/mol).
        Intended for comparing conformers of the *same* system (same atoms).

        Args:
            df: DataFrame with an "energy" column (Hartree).
            group_by: Optional column to group by before finding the minimum,
                      e.g. "ligand_name" to find the best conformer per ligand.

        Returns:
            DataFrame with a new "relative_energy_kcal" column.
        """
        # Thin delegate — the conversion lives in the toolkit.
        return self.toolkit.calculate_relative_energy(df, group_by)

    def score_ligand_energy(self, ligand_name: str, structure_id: str, 
                            distance_cutoff: float = 6.0, charge: int = 0,
                            uhf: int = 0, solvent: str = "water",
                            debug: bool = False) -> Dict[str, float]:
        """
        Calculates the semi-empirical QM energy (GFN2-xTB) of the ligand in its binding pocket.

        This extracts the ligand and surrounding protein residues, freezes the protein atoms,
        and optimizes the ligand geometry to find its local energy minimum.

        It also calculates the **Interaction Energy**:
            E_int = E_complex - (E_protein + E_ligand)

        This allows for comparison across different binding pockets (different protein sequences),
        as it normalizes for the size/composition of the pocket.

        Args:
            ligand_name: Three-letter ligand code.
            structure_id: ID of the structure to analyze.
            distance_cutoff: Radius around ligand to include in the calculation (default 6.0 Å).
            charge: Total system charge (default 0).
            uhf: Number of unpaired electrons (default 0).
            solvent: Implicit solvent model (default "water").
            debug: If True, saves input/output structures to debug_structures/ folder.

        Returns:
            Dictionary with:
              - "energy": Total energy of complex (Hartree)
              - "gap": HOMO-LUMO gap (eV)
              - "interaction_energy": Interaction energy (kcal/mol)
              - "e_complex", "e_protein", "e_ligand": Component energies (Hartree)
            Returns empty dict if calculation fails or xtb is missing.
        """
        try:
            # Load heavy atoms (default)
            all_atoms = self.get_all_atoms(structure_id)

            # Load hydrogens if available and merge
            if self.hydrogens is not None:
                hydrogens = self.hydrogens.filter(pl.col("structure_id") == structure_id).collect()
                if hydrogens.height > 0:
                    all_atoms = pl.concat([all_atoms, hydrogens])

            ligands = self.get_ligands(structure_id)
        except ValueError:
            print(f"Error: Could not retrieve atoms for {structure_id}. Ensure all_atom data exists.")
            return {}

        # Re-derive the pocket atoms here (extract_pocket_xyz below does not
        # return the DataFrame); this cheap re-filter supplies pocket_atoms
        # for charge estimation.
        from scipy.spatial.distance import cdist
        target_ligand = ligands.filter(pl.col("residue_name") == ligand_name)
        if target_ligand.height > 0:
            lig_coords = target_ligand.select(["x", "y", "z"]).to_numpy()
            prot_coords = all_atoms.select(["x", "y", "z"]).to_numpy()
            dists = cdist(prot_coords, lig_coords)
            min_dists = np.min(dists, axis=1)
            close_mask = min_dists < distance_cutoff
            close_atoms = all_atoms.filter(close_mask)
            # Expand to complete residues for accurate charge estimation
            residue_id_cols = ["chain", "residue_number", "residue_name"]
            available_cols = [c for c in residue_id_cols if c in all_atoms.columns]
            if available_cols and close_atoms.height > 0:
                touched = close_atoms.select(available_cols).unique()
                pocket_atoms = all_atoms.join(touched, on=available_cols, how="semi")
            else:
                pocket_atoms = close_atoms
        else:
            pocket_atoms = None

        xyz_complex, xyz_protein, xyz_ligand, n_prot = self.xtb_scorer.extract_pocket_xyz(
            all_atoms, ligands, ligand_name, distance_cutoff
        )

        if xyz_complex is None:
            return {}

        # "Å" fixed from mojibake ("Ã…") to match the other console messages.
        print(f"Running xTB scoring for {structure_id} (pocket size: {distance_cutoff}Å)...")
        results = self.xtb_scorer.run_scoring(
            xyz_complex, xyz_protein, xyz_ligand,
            n_protein_atoms=n_prot, charge=charge, uhf=uhf, solvent=solvent,
            pocket_atoms=pocket_atoms,
            save_structures=debug,
            structure_id=structure_id
        )

        if "interaction_energy" in results:
            print(f"  Interaction Energy: {results['interaction_energy']:.2f} kcal/mol")
        elif "energy" in results:
            print(f"  Total Energy: {results['energy']:.6f} Eh")

        return results

    # ── Mutation & Stability (industry-standard) ─────────────────────────────────

    def _structure_to_pdb(self, structure_id: str) -> str:
        """Rebuild a PDB-format string for a structure from the Parquet database.

        Heavy atoms come from get_all_atoms(); stored hydrogens for the
        structure are appended when present.
        """
        from .mutate import _df_to_pdb_string
        atoms = self.get_all_atoms(structure_id)
        if self.hydrogens is not None:
            h_df = self.hydrogens.filter(pl.col("structure_id") == structure_id).collect()
            if h_df.height > 0:
                atoms = pl.concat([atoms, h_df])
        return _df_to_pdb_string(atoms)

    def repair_structure(self, structure_id: str, **kwargs) -> RepairResult:
        """Repair a structure: fix missing atoms, add hydrogens, minimise.

        Fixes clashes and rebuilds missing atoms.  Requires ``pdbfixer`` and
        ``openmm`` (install with ``pip install sicifus[energy]``).

        Args:
            structure_id: ID of the structure in the database.
            **kwargs: Forwarded to :meth:`MutationEngine.repair`.

        Returns:
            RepairResult with repaired PDB and energy change.
        """
        return self.mutation_engine.repair(
            self._structure_to_pdb(structure_id), **kwargs
        )

    def calculate_stability(self, structure_id: str, **kwargs) -> StabilityResult:
        """Compute total potential energy with per-force-term decomposition.

        Energy-minimisation-based stability calculation on the reconstructed
        structure.

        Args:
            structure_id: ID of the structure in the database.
            **kwargs: Forwarded to :meth:`MutationEngine.calculate_stability`.

        Returns:
            StabilityResult with total energy (kcal/mol) and per-force-term breakdown.
        """
        return self.mutation_engine.calculate_stability(
            self._structure_to_pdb(structure_id), **kwargs
        )

    def mutate_structure(
        self, structure_id: str, mutations: List[Union[Mutation, str]], **kwargs
    ) -> MutationResult:
        """Apply point mutations, minimise, and compute ddG.

        Args:
            structure_id: ID of the structure in the database.
            mutations: Mutation objects or shorthand strings
                       (e.g. ``'G13L'`` for Gly at position 13 to Leu).
            **kwargs: Forwarded to :meth:`MutationEngine.mutate`.

        Returns:
            MutationResult with wild-type energy, mutant energy, ddG, and
            mutant PDB strings.
        """
        return self.mutation_engine.mutate(
            self._structure_to_pdb(structure_id), mutations, **kwargs
        )

    def load_mutations(self, csv_path: str) -> pl.DataFrame:
        """Read a mutation list from a CSV file.

        The CSV must have a ``mutation`` column (e.g. ``G13L``).  An optional
        ``chain`` column provides chain IDs; if absent, defaults to ``'A'``.
        Extra columns are preserved as metadata.

        Args:
            csv_path: Path to a CSV file.

        Returns:
            Polars DataFrame ready for :meth:`mutate_batch`.
        """
        # Static on MutationEngine — no instance state involved.
        return MutationEngine.load_mutations(csv_path)

    def mutate_batch(
        self, structure_id: str, mutations_df: pl.DataFrame, **kwargs
    ) -> pl.DataFrame:
        """Run every mutation in a DataFrame against one structure.

        Each row is an independent single-point mutation; extra input columns
        are carried through to the result.

        Args:
            structure_id: ID of the structure in the database.
            mutations_df: DataFrame with ``mutation`` and ``chain`` columns
                          (as returned by :meth:`load_mutations`).
            **kwargs: Forwarded to :meth:`MutationEngine.mutate_batch`.

        Returns:
            DataFrame with input columns plus
            ``[wt_energy, mutant_energy, ddg_kcal_mol]``.
        """
        return self.mutation_engine.mutate_batch(
            self._structure_to_pdb(structure_id), mutations_df, **kwargs
        )

    def calculate_binding_energy(
        self, structure_id: str, chains_a: List[str], chains_b: List[str], **kwargs
    ) -> BindingResult:
        """Compute the binding energy between two groups of chains.

        Intended for protein-protein complexes.

        Args:
            structure_id: ID of the structure in the database.
            chains_a: Chain IDs for the first group (e.g. ``['A']``).
            chains_b: Chain IDs for the second group (e.g. ``['B']``).
            **kwargs: Forwarded to :meth:`MutationEngine.calculate_binding_energy`.

        Returns:
            BindingResult with binding energy and interface residues.
        """
        return self.mutation_engine.calculate_binding_energy(
            self._structure_to_pdb(structure_id), chains_a, chains_b, **kwargs
        )

    def alanine_scan(
        self, structure_id: str, chain: str, positions: Optional[List[int]] = None,
        **kwargs
    ) -> pl.DataFrame:
        """Systematic alanine-scanning mutagenesis over one chain.

        Every selected residue is mutated to alanine and its ddG is
        reported.

        Args:
            structure_id: ID of the structure in the database.
            chain: Chain ID to scan.
            positions: Residue numbers to scan; ``None`` scans every
                eligible residue.
            **kwargs: Forwarded to :meth:`MutationEngine.alanine_scan`.

        Returns:
            DataFrame with columns [chain, position, wt_residue, ddg_kcal_mol].
        """
        scan_pdb = self._structure_to_pdb(structure_id)
        return self.mutation_engine.alanine_scan(scan_pdb, chain, positions, **kwargs)

    def position_scan(
        self, structure_id: str, chain: str, positions: List[int], **kwargs
    ) -> pl.DataFrame:
        """Try all 20 amino acids at each requested position.

        Produces the data for a position-specific scoring matrix.

        Args:
            structure_id: ID of the structure in the database.
            chain: Chain ID.
            positions: Residue numbers to scan.
            **kwargs: Forwarded to :meth:`MutationEngine.position_scan`.

        Returns:
            DataFrame with columns
            [chain, position, wt_residue, mut_residue, ddg_kcal_mol].
        """
        scan_pdb = self._structure_to_pdb(structure_id)
        return self.mutation_engine.position_scan(scan_pdb, chain, positions, **kwargs)

    def per_residue_energy(self, structure_id: str, **kwargs) -> pl.DataFrame:
        """Per-residue energy decomposition via alanine subtraction.

        The contribution of each residue is approximated by comparing the
        structure's energy with and without that residue's side chain.

        Args:
            structure_id: ID of the structure in the database.
            **kwargs: Forwarded to :meth:`MutationEngine.per_residue_energy`.

        Returns:
            DataFrame with columns
            [chain, residue_number, residue_name, energy_contribution_kcal_mol].
        """
        pdb_string = self._structure_to_pdb(structure_id)
        return self.mutation_engine.per_residue_energy(pdb_string, **kwargs)

    # ------------------------------------------------------------------
    # Mutation Visualization
    # ------------------------------------------------------------------

    def plot_mutation_results(
        self,
        results_df: pl.DataFrame,
        output_file: Optional[str] = None,
        plot_type: str = "ddg",
        **kwargs
    ) -> pl.DataFrame:
        """Plot mutation-analysis results.

        Args:
            results_df: DataFrame from mutate_batch(), alanine_scan(), or
                position_scan().
            output_file: Where to save the figure; ``None`` shows it
                interactively.
            plot_type: Either ``"ddg"`` or ``"distribution"``.
            **kwargs: Passed through to the underlying plot function.

        Returns:
            The processed DataFrame used for plotting.

        Raises:
            ValueError: If *plot_type* is not a recognized option.
        """
        # Dispatch table keeps the supported plot kinds in one place.
        plotters = {
            "ddg": plot_ddg,
            "distribution": plot_ddg_distribution,
        }
        plotter = plotters.get(plot_type)
        if plotter is None:
            raise ValueError(f"Unknown plot_type: {plot_type}. Choose 'ddg' or 'distribution'")
        return plotter(results_df, output_file, **kwargs)

    def plot_position_scan(
        self,
        scan_df: pl.DataFrame,
        output_file: Optional[str] = None,
        **kwargs
    ) -> pl.DataFrame:
        """Render a position-scan result as a heatmap.

        Args:
            scan_df: DataFrame from position_scan().
            output_file: Where to save the figure; ``None`` shows it
                interactively.
            **kwargs: Passed through to plot_position_scan_heatmap().

        Returns:
            Pivoted DataFrame (rows = amino acids, columns = positions).
        """
        return plot_position_scan_heatmap(scan_df, output_file, **kwargs)

    def plot_alanine_scan_results(
        self,
        scan_df: pl.DataFrame,
        output_file: Optional[str] = None,
        **kwargs
    ) -> pl.DataFrame:
        """Plot the outcome of an alanine scan.

        Args:
            scan_df: DataFrame from alanine_scan().
            output_file: Where to save the figure; ``None`` shows it
                interactively.
            **kwargs: Passed through to plot_alanine_scan().

        Returns:
            The sorted DataFrame used for plotting.
        """
        return plot_alanine_scan(scan_df, output_file, **kwargs)

    def plot_energy_breakdown(
        self,
        energy_terms_df: pl.DataFrame,
        output_file: Optional[str] = None,
        **kwargs
    ) -> pl.DataFrame:
        """Plot the energy-term breakdown for wild type versus mutant.

        Args:
            energy_terms_df: DataFrame from MutationResult.energy_terms.
            output_file: Where to save the figure; ``None`` shows it
                interactively.
            **kwargs: Passed through to plot_energy_terms().

        Returns:
            Processed DataFrame with the per-term contributions.
        """
        return plot_energy_terms(energy_terms_df, output_file, **kwargs)

    # ------------------------------------------------------------------
    # Interface Mutagenesis (NEW)
    # ------------------------------------------------------------------

    def mutate_interface(
        self,
        structure_id: str,
        mutations: Dict[str, List[Union[Mutation, str]]],
        chains_a: List[str],
        chains_b: List[str],
        **kwargs
    ) -> InterfaceMutationResult:
        """Mutate a protein-protein interface and compute ΔΔG_binding.

        Args:
            structure_id: ID of the complex in the database.
            mutations: Mapping of chain ID to its list of mutations, e.g.
                ``{"A": ["F13A", "W14L"], "B": ["Y25F"]}``.
            chains_a: Chain IDs of the first binding partner (e.g. ``['A']``).
            chains_b: Chain IDs of the second binding partner (e.g. ``['B']``).
            **kwargs: Forwarded to MutationEngine.mutate_interface().

        Returns:
            InterfaceMutationResult with ΔΔG_binding and component energies.
        """
        complex_pdb = self._structure_to_pdb(structure_id)
        return self.mutation_engine.mutate_interface(
            complex_pdb, mutations, chains_a, chains_b, **kwargs
        )

    # ------------------------------------------------------------------
    # Disulfide Bond Analysis (NEW)
    # ------------------------------------------------------------------

    def detect_disulfides(self, structure_id: str, **kwargs) -> pl.DataFrame:
        """Find disulfide bonds present in a structure.

        Args:
            structure_id: ID of the structure in the database.
            **kwargs: Forwarded to MutationEngine.detect_disulfides().

        Returns:
            DataFrame with columns:
            [chain1, residue1, resname1, chain2, residue2, resname2, distance].
        """
        pdb_string = self._structure_to_pdb(structure_id)
        return self.mutation_engine.detect_disulfides(pdb_string, **kwargs)

    def analyze_mutation_disulfide_impact(
        self,
        structure_id: str,
        mutations: List[Union[Mutation, str]],
        **kwargs
    ) -> Dict[str, object]:
        """Analyze how mutations affect disulfide bonds.

        Args:
            structure_id: ID of the structure in the database.
            mutations: List of Mutation objects or strings (e.g. ``'C42A'``).
            **kwargs: Forwarded to MutationEngine.analyze_mutation_disulfide_impact().

        Returns:
            Dict with wt_disulfides, mutant_disulfides, broken_bonds, new_bonds.
        """
        # Return annotation was previously ``Dict[str, any]``, which names the
        # builtin ``any`` function rather than a type; ``object`` is the
        # honest "anything, narrow before use" annotation.
        pdb_text = self._structure_to_pdb(structure_id)
        return self.mutation_engine.analyze_mutation_disulfide_impact(
            pdb_text, mutations, **kwargs
        )

    # ------------------------------------------------------------------
    # Residue Interaction Networks (NEW)
    # ------------------------------------------------------------------

    def compute_interaction_network(
        self,
        structure_id: str,
        distance_cutoff: float = 5.0,
        interaction_types: Optional[List[str]] = None,
    ):
        """Build the residue interaction network for one structure.

        Args:
            structure_id: ID of the structure in the database.
            distance_cutoff: Maximum distance (Å) for a residue contact
                (default 5.0).
            interaction_types: Optional filter restricting which interaction
                types are included.

        Returns:
            NetworkX graph with residue nodes and interaction edges.
        """
        # NOTE(review): get_structure() appears to return an already-collected
        # DataFrame elsewhere in this API — confirm that .collect() here is
        # operating on a LazyFrame and is not a redundant/invalid call.
        residues = self.get_structure(structure_id).collect()
        return self.toolkit.compute_residue_interaction_network(
            residues, distance_cutoff, interaction_types
        )

    def analyze_network_centrality(self, G, top_n: int = 10) -> pl.DataFrame:
        """Rank residues by network centrality.

        Args:
            G: NetworkX graph from compute_interaction_network().
            top_n: How many top-ranked residues to return (default 10).

        Returns:
            DataFrame with centrality metrics for the top residues.
        """
        # Thin pass-through; all the graph work lives in the toolkit.
        return self.toolkit.analyze_network_centrality(G, top_n)

    def plot_interaction_network(self, G, output_file: Optional[str] = None, **kwargs):
        """Draw a residue interaction network.

        Args:
            G: NetworkX graph from compute_interaction_network().
            output_file: Where to save the figure; ``None`` shows it
                interactively.
            **kwargs: Passed through to toolkit.plot_interaction_network().
        """
        # Delegate rendering entirely to the toolkit.
        return self.toolkit.plot_interaction_network(G, output_file, **kwargs)

backbone property

all_atom property

Returns protein heavy atoms (sidechains included). Hydrogens are excluded by default for performance, unless using legacy data.

ligands property

meta property

Returns all loaded metadata joined into a single LazyFrame on structure_id. If multiple metadata sources are loaded, they are joined together.

clusters property

Returns the cluster assignments DataFrame (structure_id, cluster). Available after calling annotate_clusters().

ingest(input_folder, batch_size=100, file_extension='cif', protonate=False)

Ingests structure files from a folder into the database.

Parameters:

Name Type Description Default
input_folder str

Folder containing structure files.

required
batch_size int

Number of files per parquet partition.

100
file_extension str

File extension to look for (e.g., "cif", "pdb").

'cif'
protonate bool

If True, uses PDBFixer (OpenMM) to add hydrogens to the structure before parsing. This ensures consistent protonation for energy calculations.

False
Source code in src/sicifus/api.py
def ingest(self, input_folder: str, batch_size: int = 100, file_extension: str = "cif",
           protonate: bool = False):
    """
    Ingests structure files from a folder into the database.

    Args:
        input_folder: Folder containing structure files.
        batch_size: Number of files per parquet partition.
        file_extension: File extension to look for (e.g., "cif", "pdb").
        protonate: If True, uses PDBFixer (OpenMM) to add hydrogens before
                   parsing, giving consistent protonation for energy
                   calculations.
    """
    print(f"Ingesting {file_extension} files from {input_folder} to {self.db_path}...")
    if protonate:
        print("  Protonation enabled (PDBFixer). This may take longer.")

    # The loader expects the database path as a plain string.
    destination = str(self.db_path)
    self.loader.ingest_folder(input_folder, destination, batch_size, file_extension, protonate=protonate)
    # Re-scan the partitions so the new data is immediately queryable.
    self.load()

load()

Loads the database (lazy).

Source code in src/sicifus/api.py
def load(self):
    """Lazily attach every on-disk parquet partition to this instance."""
    def _scan(folder):
        # One lazy scan over every parquet file in a partition folder.
        return pl.scan_parquet(str(folder / "*.parquet"))

    if self.backbone_path.exists():
        self._backbone_lf = _scan(self.backbone_path)

    # Heavy atoms are the preferred layout; fall back to legacy all-atom data.
    if self.heavy_atoms_path.exists():
        self._heavy_atoms_lf = _scan(self.heavy_atoms_path)
    elif self.legacy_all_atom_path.exists():
        self._heavy_atoms_lf = _scan(self.legacy_all_atom_path)

    if self.hydrogens_path.exists():
        self._hydrogens_lf = _scan(self.hydrogens_path)

    if self.ligands_path.exists():
        self._ligands_lf = _scan(self.ligands_path)
    if self.metadata_path.exists():
        # Each metadata source is a single parquet file keyed by its stem.
        for parquet_file in self.metadata_path.glob("*.parquet"):
            self._metadata_lfs[parquet_file.stem] = pl.scan_parquet(str(parquet_file))

get_structure(structure_id)

Retrieves a specific structure as a DataFrame.

Source code in src/sicifus/api.py
def get_structure(self, structure_id: str) -> pl.DataFrame:
    """Return the backbone atoms of one structure as a collected DataFrame."""
    wanted = pl.col("structure_id") == structure_id
    return self.backbone.filter(wanted).collect()

get_all_atoms(structure_id)

Retrieves ALL protein atoms (including sidechains) for a structure.

Source code in src/sicifus/api.py
def get_all_atoms(self, structure_id: str) -> pl.DataFrame:
    """Return every protein atom (sidechains included) for one structure."""
    wanted = pl.col("structure_id") == structure_id
    return self.all_atom.filter(wanted).collect()

get_ligands(structure_id)

Retrieves ligands for a specific structure.

Source code in src/sicifus/api.py
def get_ligands(self, structure_id: str) -> pl.DataFrame:
    """Return the ligand atoms belonging to one structure."""
    wanted = pl.col("structure_id") == structure_id
    return self.ligands.filter(wanted).collect()

load_metadata(path, name=None, id_column='id')

Loads external metadata (CSV) and stores it in the database as parquet. The metadata is joined to structures via structure_id.

Supports
  • A single CSV file with an id column matching structure IDs.
  • A directory of CSVs — all are concatenated.

Parameters:

Name Type Description Default
path str

Path to a CSV file or a directory of CSVs.

required
name Optional[str]

Name for this metadata source (used for storage and lookup). Defaults to the filename stem (e.g. "3ca3.summarize" → "3ca3_summarize").

None
id_column str

Name of the column in the CSV that contains structure IDs. Defaults to "id".

'id'

Returns:

Type Description
DataFrame

The loaded metadata as a Polars DataFrame.

Source code in src/sicifus/api.py
def load_metadata(self, path: str, name: Optional[str] = None, 
                  id_column: str = "id") -> pl.DataFrame:
    """
    Loads external metadata (CSV) and stores it in the database as parquet.
    The metadata is joined to structures via structure_id.

    Supports:
      - A single CSV file with an id column matching structure IDs.
      - A directory of CSVs — all are concatenated.

    Args:
        path: Path to a CSV file or a directory of CSVs.
        name: Name for this metadata source (used for storage and lookup).
              Defaults to the filename stem (e.g. "3ca3.summarize" → "3ca3_summarize").
        id_column: Name of the column in the CSV that contains structure IDs.
                   Defaults to "id".

    Returns:
        The loaded metadata as a Polars DataFrame.

    Raises:
        FileNotFoundError: If the path does not exist, or a directory
            contains no CSV files.
        ValueError: If neither *id_column* nor ``structure_id`` is present.
    """
    p = Path(path).expanduser()

    if p.is_file():
        df = pl.read_csv(str(p))
        if name is None:
            # Dots/dashes would be awkward in a parquet filename / dict key.
            name = p.stem.replace(".", "_").replace("-", "_")
    elif p.is_dir():
        # Recursive search: sub-directories of CSVs are included too.
        csvs = list(p.rglob("*.csv"))
        if not csvs:
            raise FileNotFoundError(f"No CSV files found in {p}")
        dfs = [pl.read_csv(str(f)) for f in csvs]
        # "diagonal" concat tolerates CSVs with differing column sets,
        # filling missing columns with nulls.
        df = pl.concat(dfs, how="diagonal")
        if name is None:
            name = p.name.replace(".", "_").replace("-", "_")
    else:
        raise FileNotFoundError(f"Path not found: {p}")

    # Rename the id column to structure_id for consistency
    if id_column in df.columns and id_column != "structure_id":
        df = df.rename({id_column: "structure_id"})
    elif "structure_id" not in df.columns:
        # Neither the requested id column nor a pre-named structure_id exists.
        raise ValueError(
            f"Column '{id_column}' not found in CSV. "
            f"Available columns: {df.columns}. "
            f"Set id_column= to the column containing structure IDs."
        )

    # Store as parquet
    self.metadata_path.mkdir(parents=True, exist_ok=True)
    out_path = self.metadata_path / f"{name}.parquet"
    df.write_parquet(str(out_path))

    # Cache the lazy frame
    self._metadata_lfs[name] = pl.scan_parquet(str(out_path))

    n_rows = df.height
    n_cols = len(df.columns) - 1  # minus structure_id
    matched = 0
    # Report how many metadata rows actually join to loaded structures
    # (only possible if the backbone partition has been loaded).
    if self._backbone_lf is not None:
        all_ids = self.backbone.select("structure_id").unique().collect().to_series()
        matched = df.filter(pl.col("structure_id").is_in(all_ids)).height

    print(f"Loaded metadata '{name}': {n_rows} rows, {n_cols} columns")
    if matched > 0:
        print(f"  {matched}/{n_rows} rows match structures in the database")

    return df

meta_columns()

Lists all available metadata columns (across all loaded sources).

Source code in src/sicifus/api.py
def meta_columns(self) -> List[str]:
    """Lists all available metadata columns (across all loaded sources)."""
    # Union of columns from every loaded metadata source, minus the join key.
    names = {
        column
        for lf in self._metadata_lfs.values()
        for column in lf.columns
        if column != "structure_id"
    }
    return sorted(names)

hist(column, bins=30, title=None, output_file=None, **kwargs)

Plots a histogram of any metadata column.

If cluster annotations exist, you can pass color_by="cluster" to color the histogram by cluster assignment.

Parameters:

Name Type Description Default
column str

Column name from the metadata (e.g. "radius_of_gyration").

required
bins int

Number of histogram bins.

30
title Optional[str]

Plot title. Defaults to the column name.

None
output_file Optional[str]

Save to file instead of showing.

None
**kwargs

Extra kwargs passed to matplotlib hist().

{}

Examples:

db.hist("radius_of_gyration") db.hist("protein_length", bins=50)

Source code in src/sicifus/api.py
def hist(self, column: str, bins: int = 30, title: Optional[str] = None,
         output_file: Optional[str] = None, **kwargs):
    """
    Plots a histogram of any metadata column.

    If cluster annotations exist, you can pass color_by="cluster" to 
    color the histogram by cluster assignment.

    Args:
        column: Column name from the metadata (e.g. "radius_of_gyration").
        bins: Number of histogram bins.
        title: Plot title. Defaults to the column name.
        output_file: Save to file instead of showing.
        **kwargs: Extra kwargs passed to matplotlib hist().

    Examples:
        db.hist("radius_of_gyration")
        db.hist("protein_length", bins=50)
    """
    # Collect the column from metadata
    df = self.meta.select(["structure_id", column]).collect().drop_nulls(column)

    if df.height == 0:
        print(f"No data found for column '{column}'. Available columns:")
        print(f"  {self.meta_columns()}")
        return

    values = df.get_column(column).to_numpy()

    fig, ax = plt.subplots(figsize=(10, 6))

    # color_by is our own option, not a matplotlib hist() kwarg — pop it
    # before forwarding kwargs.
    color_by = kwargs.pop("color_by", None)

    if color_by == "cluster" and self._clusters is not None:
        # Join with cluster assignments
        joined = df.join(self._clusters, on="structure_id", how="left")
        cluster_col = joined.get_column("cluster")
        unique_clusters = sorted(cluster_col.drop_nulls().unique().to_list())

        n_clust = len(unique_clusters)
        # plt.cm.get_cmap was removed in matplotlib 3.9; plt.get_cmap is the
        # supported equivalent (resampled to n_clust colors).
        cmap = plt.get_cmap("tab20" if n_clust <= 20 else "hsv", n_clust)

        for i, cid in enumerate(unique_clusters):
            cluster_vals = joined.filter(pl.col("cluster") == cid).get_column(column).to_numpy()
            # max(..., 1) guards the single-cluster case against division by 0.
            color = cmap(i / max(n_clust - 1, 1))
            ax.hist(cluster_vals, bins=bins, alpha=0.6, color=color, 
                    label=f"Cluster {cid}", **kwargs)
        ax.legend(fontsize=7, ncol=2)
    else:
        ax.hist(values, bins=bins, edgecolor='black', alpha=0.8, **kwargs)

    ax.set_xlabel(column, fontsize=11)
    ax.set_ylabel("Count", fontsize=11)
    ax.set_title(title or column, fontsize=13)
    plt.tight_layout()

    if output_file:
        plt.savefig(output_file, dpi=150, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

scatter(x, y, title=None, output_file=None, **kwargs)

Scatter plot of two metadata columns.

Parameters:

Name Type Description Default
x str

Column name for x-axis.

required
y str

Column name for y-axis.

required
title Optional[str]

Plot title.

None
output_file Optional[str]

Save to file instead of showing.

None
**kwargs

Extra kwargs passed to matplotlib scatter().

{}

Examples:

db.scatter("protein_length", "radius_of_gyration")

Source code in src/sicifus/api.py
def scatter(self, x: str, y: str, title: Optional[str] = None,
            output_file: Optional[str] = None, **kwargs):
    """
    Scatter plot of two metadata columns.

    Args:
        x: Column name for x-axis.
        y: Column name for y-axis.
        title: Plot title.
        output_file: Save to file instead of showing.
        **kwargs: Extra kwargs passed to matplotlib scatter().

    Examples:
        db.scatter("protein_length", "radius_of_gyration")
    """
    # Pull both columns in one pass and drop rows missing either value.
    df = self.meta.select(["structure_id", x, y]).collect().drop_nulls([x, y])

    if df.height == 0:
        print(f"No data found. Available columns: {self.meta_columns()}")
        return

    fig, ax = plt.subplots(figsize=(10, 6))

    # color_by is our own option, not a matplotlib kwarg.
    color_by = kwargs.pop("color_by", None)

    if color_by == "cluster" and self._clusters is not None:
        joined = df.join(self._clusters, on="structure_id", how="left")
        cluster_values = joined.get_column("cluster").to_numpy().astype(float)
        xs = joined.get_column(x).to_numpy()
        ys = joined.get_column(y).to_numpy()
        sc = ax.scatter(xs, ys, c=cluster_values, cmap="tab20", s=10, alpha=0.7, **kwargs)
        plt.colorbar(sc, label="Cluster")
    else:
        xs = df.get_column(x).to_numpy()
        ys = df.get_column(y).to_numpy()
        ax.scatter(xs, ys, s=10, alpha=0.7, **kwargs)

    ax.set_xlabel(x, fontsize=11)
    ax.set_ylabel(y, fontsize=11)
    ax.set_title(title or f"{y} vs {x}", fontsize=13)
    plt.tight_layout()

    if output_file:
        plt.savefig(output_file, dpi=150, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

align_all(reference_id, target_ids=None)

Aligns all (or specified) structures to a reference structure. Returns a DataFrame with RMSD and alignment stats.

Source code in src/sicifus/api.py
def align_all(self, reference_id: str, target_ids: Optional[List[str]] = None) -> pl.DataFrame:
    """
    Aligns all (or specified) structures to a reference structure.
    Returns a DataFrame with RMSD and alignment stats.

    Args:
        reference_id: Structure ID everything is aligned against.
        target_ids: Structures to align. If None, aligns every structure
            in the database (potentially expensive for a large DB).

    Returns:
        DataFrame with columns
        [structure_id, reference_id, rmsd, aligned_residues]; structures
        that fail to align are printed and skipped.

    Raises:
        ValueError: If the reference structure is not found.
    """
    ref_df = self.get_structure(reference_id)
    if ref_df.height == 0:
        raise ValueError(f"Reference structure {reference_id} not found.")

    if target_ids is None:
        # Get all unique IDs (this might be expensive for massive DB, better to use metadata)
        target_ids = self.backbone.select("structure_id").unique().collect().to_series().to_list()

    # Build a filtered copy rather than calling target_ids.remove(): the
    # original code mutated the caller's list in place, and .remove() only
    # dropped the first occurrence of the reference.
    targets = [tid for tid in target_ids if tid != reference_id]

    results = []
    print(f"Aligning {len(targets)} structures to {reference_id}...")

    for tid in targets:
        target_df = self.get_structure(tid)
        if target_df.height > 0:
            try:
                rmsd, n_aligned = self.aligner.align_and_superimpose(target_df, ref_df)
                results.append({
                    "structure_id": tid,
                    "reference_id": reference_id,
                    "rmsd": rmsd,
                    "aligned_residues": n_aligned
                })
            except Exception as e:
                # Best-effort batch: report the failure and continue.
                print(f"Failed to align {tid}: {e}")

    return pl.DataFrame(results)

get_aligned_structure(structure_id, reference_id)

Returns the structure transformed to align with the reference.

Source code in src/sicifus/api.py
def get_aligned_structure(self, structure_id: str, reference_id: str) -> pl.DataFrame:
    """
    Returns the structure transformed to align with the reference.
    """
    mobile_df = self.get_structure(structure_id)
    ref_df = self.get_structure(reference_id)

    # Either side missing means we cannot align.
    if mobile_df.height == 0 or ref_df.height == 0:
        raise ValueError("Structure not found.")

    transformed_df, rmsd = self.aligner.align_and_transform(mobile_df, ref_df)
    print(f"Aligned {structure_id} to {reference_id} with RMSD: {rmsd:.2f}")
    return transformed_df

generate_tree(structure_ids=None, output_file=None, root_id=None, newick_file=None, pruning_threshold=None, layout='circular')

Generates a structural phylogenetic tree. Unrooted by default. Branch lengths are RMSD values.

This is the expensive step (O(N^2) alignments). After this, use tree_stats() to inspect branch lengths, then annotate_clusters() to assign clusters cheaply.

Parameters:

Name Type Description Default
structure_ids Optional[List[str]]

List of structure IDs. If None, uses all structures (warning: O(N^2)).

None
output_file Optional[str]

Save the tree plot to this file (e.g. "tree.png").

None
root_id Optional[str]

Root the tree at this structure ID. If None, tree is unrooted.

None
newick_file Optional[str]

Export to Newick format for iTOL or similar tools.

None
pruning_threshold Optional[float]

Skip alignment for structurally dissimilar pairs (0.0-1.0).

None
layout str

Tree layout for the plot: "circular" (default, unrooted radial) or "rectangular".

'circular'

Returns:

Type Description

Biopython Tree object.

Source code in src/sicifus/api.py
def generate_tree(self, structure_ids: Optional[List[str]] = None, output_file: Optional[str] = None, 
                  root_id: Optional[str] = None, newick_file: Optional[str] = None,
                  pruning_threshold: Optional[float] = None,
                  layout: str = "circular"):
    """
    Generates a structural phylogenetic tree. Unrooted by default.
    Branch lengths are RMSD values.

    This is the expensive step (O(N^2) alignments). After this, use tree_stats() 
    to inspect branch lengths, then annotate_clusters() to assign clusters cheaply.

    Args:
        structure_ids: List of structure IDs. If None, uses all structures (warning: O(N^2)).
        output_file: Save the tree plot to this file (e.g. "tree.png").
        root_id: Root the tree at this structure ID. If None, tree is unrooted.
        newick_file: Export to Newick format for iTOL or similar tools.
        pruning_threshold: Skip alignment for structurally dissimilar pairs (0.0-1.0).
        layout: Tree layout for the plot: "circular" (default, unrooted radial) or "rectangular".

    Returns:
        Biopython Tree object.
    """
    # Dead timing locals (t_total, t0) removed: they were assigned but never read.
    if structure_ids is None:
        structure_ids = self.backbone.select("structure_id").unique().collect().to_series().to_list()

    if len(structure_ids) > 100:
        print(f"Warning: Generating tree for {len(structure_ids)} structures. This involves O(N^2) alignments.")

    # Load ALL backbone data in ONE scan, then group by structure_id
    all_data = self.backbone.filter(
        pl.col("structure_id").is_in(structure_ids)
    ).collect()

    structures = {}
    for sid, group in all_data.group_by("structure_id"):
        # polars may yield the group key as a 1-tuple; unwrap it.
        structures[sid[0] if isinstance(sid, tuple) else sid] = group

    matrix, labels = self.toolkit.compute_rmsd_matrix(structures, pruning_threshold=pruning_threshold)

    # Build the linkage matrix (fast, C-based) — needed for circular plot
    Z = self.toolkit.build_tree(matrix, labels)

    # Cache for later use (annotate_clusters / re-plotting reuse these).
    self._linkage = Z
    self._tree_labels = labels
    self._rmsd_matrix = matrix

    # Build the Biopython tree (fast — needed for Newick and clustering)
    tree_obj = self.toolkit.build_phylo_tree(matrix, labels, root_id)
    self._tree = tree_obj

    # Write Newick
    if newick_file:
        from Bio import Phylo
        Phylo.write(tree_obj, newick_file, "newick")

    # Plot the tree
    if output_file:
        if layout == "circular":
            self.toolkit.plot_circular_tree(
                Z, labels, 
                cluster_df=self._clusters,
                output_file=output_file
            )
        else:
            self.toolkit.plot_tree(tree_obj, output_file=output_file)

    return tree_obj

cluster(structure_ids=None, distance_threshold=2.0, coverage_threshold=0.8, output_file=None)

Fast greedy structural clustering (no full tree required).

Uses a 3Di k-mer prefilter to rapidly identify candidate centroids, then only computes RMSD for those candidates. Much faster than building a full phylogenetic tree for large datasets.

Parameters:

Name Type Description Default
structure_ids Optional[List[str]]

Structures to cluster. If None, uses all.

None
distance_threshold float

Max RMSD (Å) for assigning to a centroid.

2.0
coverage_threshold float

Min length-ratio for comparing two structures.

0.8
output_file Optional[str]

Save a summary bar-chart of cluster sizes.

None

Returns:

Type Description
DataFrame

Polars DataFrame with columns

DataFrame

[structure_id, cluster, centroid_id, rmsd_to_centroid].

Source code in src/sicifus/api.py
def cluster(self, structure_ids: Optional[List[str]] = None,
            distance_threshold: float = 2.0,
            coverage_threshold: float = 0.8,
            output_file: Optional[str] = None) -> pl.DataFrame:
    """Fast greedy structural clustering (no full tree required).

    Uses a 3Di k-mer prefilter to rapidly identify candidate centroids,
    then only computes RMSD for those candidates.  Much faster than
    building a full phylogenetic tree for large datasets.

    Args:
        structure_ids: Structures to cluster. If None, uses all.
        distance_threshold: Max RMSD (Å) for assigning to a centroid.
        coverage_threshold: Min length-ratio for comparing two structures.
        output_file: Save a summary bar-chart of cluster sizes.

    Returns:
        Polars DataFrame with columns
        ``[structure_id, cluster, centroid_id, rmsd_to_centroid]``.
    """
    t0 = time.perf_counter()

    if structure_ids is None:
        structure_ids = (
            self.backbone.select("structure_id")
            .unique().collect().to_series().to_list()
        )

    # Single scan for all requested structures, then split per structure.
    all_data = self.backbone.filter(
        pl.col("structure_id").is_in(structure_ids)
    ).collect()

    structures = {}
    for sid, group in all_data.group_by("structure_id"):
        # polars may yield the group key as a 1-tuple; unwrap it.
        structures[sid[0] if isinstance(sid, tuple) else sid] = group

    df = self.toolkit.cluster_fast(
        structures,
        distance_threshold=distance_threshold,
        coverage_threshold=coverage_threshold,
    )

    # Cache assignments so hist()/scatter() can color by cluster.
    self._clusters = df.select(["structure_id", "cluster"])

    elapsed = time.perf_counter() - t0
    print(f"Clustering completed in {elapsed:.1f}s")

    if output_file:
        import matplotlib.pyplot as plt
        sizes = (
            df.group_by("cluster")
            .agg(pl.col("structure_id").count().alias("size"))
            .sort("size", descending=True)
        )
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.bar(range(sizes.height), sizes["size"].to_list(), edgecolor="black", alpha=0.8)
        ax.set_xlabel("Cluster (sorted by size)")
        ax.set_ylabel("Members")
        # Title previously rendered mojibake "Ã…" for the Angstrom sign.
        ax.set_title(f"Fast Clustering — {df['cluster'].n_unique()} clusters "
                     f"(threshold={distance_threshold} Å)")
        plt.tight_layout()
        plt.savefig(output_file, dpi=150, bbox_inches="tight")
        plt.close()

    return df

annotate_clusters(distance_threshold, output_file=None, layout='circular')

Annotates the tree with cluster labels by cutting branches whose RMSD exceeds distance_threshold. Each resulting subtree becomes a cluster.

This is cheap and instant — run it multiple times with different thresholds after generate_tree() to explore coarse vs fine clustering.

Use tree_stats() first to see the branch length distribution and pick a meaningful threshold.

Parameters:

Name Type Description Default
distance_threshold float

Cut branches longer than this RMSD value. e.g. 1.0 = subtrees separated by > 1 Å RMSD become different clusters.

required
output_file Optional[str]

Optionally re-plot the tree with cluster colors.

None
layout str

"circular" (default) or "rectangular".

'circular'

Returns:

Type Description
DataFrame

Polars DataFrame with columns: structure_id, cluster

Source code in src/sicifus/api.py
def annotate_clusters(self, distance_threshold: float, output_file: Optional[str] = None,
                      layout: str = "circular") -> pl.DataFrame:
    """
    Annotates the tree with cluster labels by cutting branches whose RMSD
    exceeds distance_threshold. Each resulting subtree becomes a cluster.

    This is cheap and instant — run it multiple times with different thresholds
    after generate_tree() to explore coarse vs fine clustering.

    Use tree_stats() first to see the branch length distribution and pick
    a meaningful threshold.

    Args:
        distance_threshold: Cut branches longer than this RMSD value.
                           e.g. 1.0 = subtrees separated by > 1 Å RMSD become different clusters.
        output_file: Optionally re-plot the tree with cluster colors.
        layout: "circular" (default) or "rectangular".

    Returns:
        Polars DataFrame with columns: structure_id, cluster

    Raises:
        ValueError: If generate_tree() has not been run yet.
    """
    if not hasattr(self, '_tree') or self._tree is None:
        raise ValueError("No tree available. Run generate_tree() first.")

    # Store assignments on the instance — presumably backing the `clusters`
    # accessor used by get_cluster()/cluster_summary(); verify against the class.
    self._clusters = self.toolkit.cluster_from_tree(self._tree, distance_threshold)
    n_clust = self._clusters["cluster"].n_unique()
    n_structs = self._clusters.height

    # Report singleton vs multi-member breakdown
    cluster_sizes = self._clusters.group_by("cluster").agg(pl.col("structure_id").count().alias("size"))
    n_singletons = cluster_sizes.filter(pl.col("size") == 1).height
    n_multi = n_clust - n_singletons
    n_in_multi = n_structs - n_singletons

    print(f"Annotated {n_structs} structures into {n_clust} clusters (threshold={distance_threshold})")
    print(f"  {n_multi} clusters with 2+ members ({n_in_multi} structures)")
    print(f"  {n_singletons} singletons (distinct outliers)")

    # Re-plot with cluster colors if requested
    if output_file:
        if layout == "circular":
            self.toolkit.plot_circular_tree(
                self._linkage, self._tree_labels,
                cluster_df=self._clusters,
                output_file=output_file
            )
        else:
            # NOTE(review): the rectangular fallback does not pass the cluster
            # assignments, so this plot is NOT colored by cluster — confirm
            # whether plot_tree supports a cluster_df argument.
            self.toolkit.plot_tree(self._tree, output_file=output_file)

    return self._clusters

tree_branch_lengths()

Returns all branch lengths from the tree. Use this to understand the distribution and pick a meaningful distance_threshold for clustering.

Example

bls = db.tree_branch_lengths() print(f"min={bls.min():.2f}, median={np.median(bls):.2f}, max={bls.max():.2f}")

Then pick a threshold that makes biological sense

Source code in src/sicifus/api.py
def tree_branch_lengths(self) -> np.ndarray:
    """
    Return all positive branch lengths (RMSD units) from the current tree.

    Use this to understand the branch-length distribution and pick a
    meaningful distance_threshold for clustering.

    Returns:
        np.ndarray of branch lengths; unset (None) and zero-length
        branches (e.g. the root) are excluded.

    Raises:
        ValueError: If generate_tree() has not been run yet.

    Example:
        bls = db.tree_branch_lengths()
        print(f"min={bls.min():.2f}, median={np.median(bls):.2f}, max={bls.max():.2f}")
        # Then pick a threshold that makes biological sense
    """
    if not hasattr(self, '_tree') or self._tree is None:
        raise ValueError("No tree available. Run generate_tree() first.")

    # Comprehension instead of the manual append loop (same filter, same order).
    return np.array([
        clade.branch_length
        for clade in self._tree.find_clades()
        if clade.branch_length is not None and clade.branch_length > 0
    ])

tree_stats()

Prints summary statistics of the tree's branch lengths (RMSD). Helps you pick a good distance_threshold for clustering.

Source code in src/sicifus/api.py
def tree_stats(self):
    """
    Print summary statistics of the tree's branch lengths (RMSD) to help
    choose a sensible distance_threshold for clustering.
    """
    lengths = self.tree_branch_lengths()
    pct_points = [10, 25, 50, 75, 90, 95, 99]
    pct_values = np.percentile(lengths, pct_points)

    print(f"Tree branch length statistics ({len(lengths)} branches):")
    print(f"  min:    {lengths.min():.4f}")
    for pct, val in zip(pct_points, pct_values):
        print(f"  {pct:3d}th:  {val:.4f}")
    print(f"  max:    {lengths.max():.4f}")
    print(f"  mean:   {lengths.mean():.4f}")
    print()
    print("Tip: set distance_threshold to a percentile value.")
    print("  Higher threshold → fewer, larger clusters (coarser)")
    print("  Lower threshold  → more, smaller clusters (finer)")

get_cluster(cluster_id)

Returns all structure IDs belonging to a specific cluster.

Parameters:

Name Type Description Default
cluster_id int

The cluster number to query.

required

Returns:

Type Description
List[str]

List of structure IDs in that cluster.

Source code in src/sicifus/api.py
def get_cluster(self, cluster_id: int) -> List[str]:
    """
    Return every structure ID assigned to a specific cluster.

    Args:
        cluster_id: The cluster number to query.

    Returns:
        List of structure IDs in that cluster.
    """
    members = self.clusters.filter(pl.col("cluster") == cluster_id)
    return members.get_column("structure_id").to_list()

get_cluster_for(structure_id)

Returns the cluster ID that a specific structure belongs to.

Parameters:

Name Type Description Default
structure_id str

The structure to look up.

required

Returns:

Type Description
int

The cluster number.

Source code in src/sicifus/api.py
def get_cluster_for(self, structure_id: str) -> int:
    """
    Return the cluster ID that a specific structure belongs to.

    Args:
        structure_id: The structure to look up.

    Returns:
        The cluster number.

    Raises:
        ValueError: If the structure has no cluster assignment.
    """
    hits = (
        self.clusters
        .filter(pl.col("structure_id") == structure_id)
        .get_column("cluster")
        .to_list()
    )
    if hits:
        return hits[0]
    raise ValueError(f"Structure {structure_id} not found in cluster assignments.")

get_cluster_siblings(structure_id)

Returns all structure IDs in the same cluster as the given structure.

Parameters:

Name Type Description Default
structure_id str

The reference structure.

required

Returns:

Type Description
List[str]

List of structure IDs in the same cluster (including the reference).

Source code in src/sicifus/api.py
def get_cluster_siblings(self, structure_id: str) -> List[str]:
    """
    Return all structure IDs sharing a cluster with the given structure.

    Args:
        structure_id: The reference structure.

    Returns:
        List of structure IDs in the same cluster (the reference
        structure itself is included).
    """
    # Look up the assignment, then expand it to the full membership list.
    return self.get_cluster(self.get_cluster_for(structure_id))

cluster_summary()

Returns a summary of all clusters: cluster ID, count, and member IDs.

Source code in src/sicifus/api.py
def cluster_summary(self) -> pl.DataFrame:
    """
    Summarise every cluster: its ID, member count, and member IDs.

    Returns:
        Polars DataFrame with columns: cluster, count, members,
        sorted by cluster ID.
    """
    aggregations = [
        pl.col("structure_id").count().alias("count"),
        pl.col("structure_id").alias("members"),
    ]
    return self.clusters.group_by("cluster").agg(aggregations).sort("cluster")

get_clustered_ids(min_size=2)

Returns structure IDs that belong to clusters with at least min_size members. Filters out singletons (or small clusters) — the outliers on long branches.

Parameters:

Name Type Description Default
min_size int

Minimum cluster size to include (default 2 = drop singletons).

2

Returns:

Type Description
List[str]

List of structure IDs.

Source code in src/sicifus/api.py
def get_clustered_ids(self, min_size: int = 2) -> List[str]:
    """
    Returns structure IDs that belong to clusters with at least min_size members.
    Filters out singletons (or small clusters) — the outliers on long branches.

    Use this to rebuild a pruned tree with only related structures:
        related = db.get_clustered_ids(min_size=2)
        db.generate_tree(structure_ids=related, ...)

    Args:
        min_size: Minimum cluster size to include (default 2 = drop singletons).

    Returns:
        List of structure IDs.
    """
    # Single pass: a window expression attaches each row's cluster size,
    # replacing the previous group_by + is_in round trip. Row order (and
    # therefore the returned ID order) is preserved.
    return (
        self.clusters
        .filter(pl.col("structure_id").count().over("cluster") >= min_size)
        .get_column("structure_id")
        .to_list()
    )

analyze_ligand_binding(ligand_name, structure_ids=None, output_file=None)

Analyzes binding residues for a specific ligand across structures. Plots a histogram of binding residue types.

Source code in src/sicifus/api.py
def analyze_ligand_binding(self, ligand_name: str, structure_ids: Optional[List[str]] = None,
                           output_file: Optional[str] = None):
    """
    Analyse binding residues for a specific ligand across structures and
    plot a histogram of the binding residue types.

    Args:
        ligand_name: Three-letter ligand code.
        structure_ids: Optional subset of structures; defaults to every
                       structure containing the ligand.
        output_file: Save the histogram to file instead of displaying.
    """
    if structure_ids is None:
        structure_ids = self.ligands.filter(pl.col("residue_name") == ligand_name)\
                                    .select("structure_id").unique().collect().to_series().to_list()

    print(f"Analyzing {ligand_name} binding in {len(structure_ids)} structures...")

    # Collect binding residue names from every structure.
    binding_residues = []
    for sid in structure_ids:
        backbone = self.get_structure(sid)
        ligand_atoms = self.get_ligands(sid)
        hits = self.ligand_analyzer.find_binding_residues(backbone, ligand_atoms, ligand_name)
        if hits.height > 0:
            binding_residues += hits["residue_name"].to_list()

    if not binding_residues:
        print("No binding residues found.")
        return

    self.ligand_analyzer.plot_binding_histogram(
        binding_residues,
        title=f"{ligand_name} — Binding Residue Distribution ({len(structure_ids)} structures)",
        output_file=output_file,
    )

analyze_pi_stacking(ligand_name, structure_ids=None, output_file=None, charge=None, infer_bond_orders=True)

Detects pi-stacking interactions (sandwich, parallel displaced, T-shaped) between protein aromatic residues and aromatic rings in the specified ligand.

Requires all_atom data (re-ingest if you only have backbone/CA).

Produces a grouped bar chart of interaction types and residue breakdown. Returns a DataFrame of all detected interactions.

The ligand_ring_atoms column uses the same canonical atom labels as analyze_ligand_contacts and the 2D depiction, so atom names are consistent across all analyses.

Parameters:

Name Type Description Default
ligand_name str

Three-letter ligand code (e.g. "GLC", "ATP").

required
structure_ids Optional[List[str]]

Optional list of structures to analyze. Defaults to all structures containing the ligand.

None
output_file Optional[str]

Save the plot instead of displaying.

None
charge Optional[int]

Total formal charge of the ligand (passed to build_ligand_mol).

None
infer_bond_orders bool

Whether to infer double/aromatic bonds (passed to build_ligand_mol).

True
Example

pi_df = db.analyze_pi_stacking("ATP")

Source code in src/sicifus/api.py
def analyze_pi_stacking(self, ligand_name: str, structure_ids: Optional[List[str]] = None,
                        output_file: Optional[str] = None,
                        charge: Optional[int] = None,
                        infer_bond_orders: bool = True) -> pl.DataFrame:
    """
    Detects pi-stacking interactions (sandwich, parallel displaced, T-shaped)
    between protein aromatic residues and aromatic rings in the specified ligand.

    Requires all_atom data (re-ingest if you only have backbone/CA).

    Produces a grouped bar chart of interaction types and residue breakdown.
    Returns a DataFrame of all detected interactions.

    The ``ligand_ring_atoms`` column uses the same canonical atom labels
    as ``analyze_ligand_contacts`` and the 2D depiction, so atom names
    are consistent across all analyses.

    Args:
        ligand_name: Three-letter ligand code (e.g. "GLC", "ATP").
        structure_ids: Optional list of structures to analyze. Defaults to all
                       structures containing the ligand.
        output_file: Save the plot instead of displaying.
        charge: Total formal charge of the ligand (passed to build_ligand_mol).
        infer_bond_orders: Whether to infer double/aromatic bonds (passed to
                build_ligand_mol).

    Returns:
        Polars DataFrame of detected interactions (one row per interaction,
        with a structure_id column); empty DataFrame if all_atom data is
        missing or no interactions are found.

    Example:
        pi_df = db.analyze_pi_stacking("ATP")
    """
    # Default to every structure that contains this ligand.
    if structure_ids is None:
        structure_ids = self.ligands.filter(pl.col("residue_name") == ligand_name)\
                                    .select("structure_id").unique().collect().to_series().to_list()

    print(f"Detecting pi-stacking with {ligand_name} in {len(structure_ids)} structures...")

    all_interactions = []
    # Track distinct molecular graphs so we can warn if canonical labels
    # may not be comparable across structures.
    unique_smiles = set()

    for sid in structure_ids:
        try:
            all_atoms = self.get_all_atoms(sid)
        except ValueError:
            # NOTE(review): bails out on the FIRST structure lacking all-atom
            # data — remaining structures are not checked.
            print("  No all_atom data available. Re-ingest structures to enable pi-stacking analysis.")
            return pl.DataFrame()
        ligands = self.get_ligands(sid)

        pi_df = self.ligand_analyzer.detect_pi_stacking(all_atoms, ligands, ligand_name)

        if pi_df.height > 0:
            # ── Per-structure Canonical Mapping ──────────────────────────
            # Remap ring atoms to canonical labels for this specific structure
            lig_for_struct = ligands.filter(pl.col("residue_name") == ligand_name)

            _, pdb_names_i, canonical_labels_i, smi_i = \
                self.ligand_analyzer.build_ligand_mol(
                    lig_for_struct,
                    charge=charge,
                    infer_bond_orders=infer_bond_orders,
                )

            if smi_i:
                unique_smiles.add(smi_i)

            if pdb_names_i and canonical_labels_i:
                pdb_to_canonical = dict(zip(pdb_names_i, canonical_labels_i))

                # Closure is applied immediately below (within this iteration),
                # so the loop-variable late-binding pitfall does not apply.
                # Unmapped names fall through unchanged.
                def _remap_ring_atoms(pdb_csv: str) -> str:
                    return ",".join(
                        pdb_to_canonical.get(a.strip(), a.strip())
                        for a in pdb_csv.split(",")
                    )

                remapped = [
                    _remap_ring_atoms(v)
                    for v in pi_df.get_column("ligand_ring_atoms").to_list()
                ]
                pi_df = pi_df.with_columns(
                    pl.Series("ligand_ring_atoms", remapped)
                )

            # Tag rows with their source structure before accumulating.
            pi_df = pi_df.with_columns(pl.lit(sid).alias("structure_id"))
            all_interactions.append(pi_df)

    if len(unique_smiles) > 1:
        print(f"Warning: Found {len(unique_smiles)} different molecular graphs (SMILES) for {ligand_name}.")
        print("  Canonical labeling might be inconsistent if structures have different connectivity/protonation.")

    if not all_interactions:
        print("No pi-stacking interactions detected.")
        return pl.DataFrame()

    result = pl.concat(all_interactions)

    # Per-type counts for the console summary.
    n_sandwich = result.filter(pl.col("interaction_type") == "sandwich").height
    n_parallel = result.filter(pl.col("interaction_type") == "parallel_displaced").height
    n_tshaped = result.filter(pl.col("interaction_type") == "t_shaped").height
    print(f"  Found {result.height} interactions: "
          f"{n_sandwich} sandwich, {n_parallel} parallel displaced, {n_tshaped} T-shaped")

    # Plot
    self.ligand_analyzer.plot_pi_stacking(
        result.to_dicts(),
        title=f"{ligand_name} — Pi-Stacking Interactions ({len(structure_ids)} structures)",
        output_file=output_file,
    )
    return result

analyze_ligand_contacts(ligand_name, distance_cutoff=3.3, structure_ids=None, output_file=None, ligand_2d_file=None, charge=None, infer_bond_orders=True)

Identifies atom-level protein-ligand contacts within a distance cutoff. Default 3.3 Å is appropriate for hydrogen bonding.

Requires all_atom data (re-ingest if you only have backbone/CA).

Produces TWO visualizations

1) Stacked bar chart: which ligand atoms form the most contacts. 2) 2D ligand depiction (RDKit): atoms color-coded by contact count (red = many, blue = few) so you can cross-reference the chart.

Returns a DataFrame of all contacts.

Parameters:

Name Type Description Default
ligand_name str

Three-letter ligand code.

required
distance_cutoff float

Max distance in Å (default 3.3 for H-bonds).

3.3
structure_ids Optional[List[str]]

Optional list of structures.

None
output_file Optional[str]

Save the contacts bar chart to file.

None
ligand_2d_file Optional[str]

Save the 2D ligand depiction to file. If not set, auto-generates filename from output_file or displays inline.

None
charge Optional[int]

Total formal charge of the ligand (e.g. -3 for citrate). Helps RDKit infer correct protonation / bond orders in the 2D depiction.

None
infer_bond_orders bool

If True, RDKit will try to determine double and aromatic bonds from 3D geometry. Set to False if the 2D depiction shows incorrect double bonds — it will display connectivity only (all single bonds), which is still useful.

True
Example

contacts = db.analyze_ligand_contacts("GLC", distance_cutoff=3.3)

If the 2D shows wrong double bonds, pass the charge or disable inference:

contacts = db.analyze_ligand_contacts("CIT", charge=-3) contacts = db.analyze_ligand_contacts("LIG", infer_bond_orders=False)

Source code in src/sicifus/api.py
def analyze_ligand_contacts(self, ligand_name: str, distance_cutoff: float = 3.3,
                            structure_ids: Optional[List[str]] = None,
                            output_file: Optional[str] = None,
                            ligand_2d_file: Optional[str] = None,
                            charge: Optional[int] = None,
                            infer_bond_orders: bool = True) -> pl.DataFrame:
    """
    Identifies atom-level protein-ligand contacts within a distance cutoff.
    Default 3.3 Å is appropriate for hydrogen bonding.

    Requires all_atom data (re-ingest if you only have backbone/CA).

    Produces TWO visualizations:
      1) Stacked bar chart: which ligand atoms form the most contacts.
      2) 2D ligand depiction (RDKit): atoms color-coded by contact count
         (red = many, blue = few) so you can cross-reference the chart.

    Returns a DataFrame of all contacts.

    Args:
        ligand_name: Three-letter ligand code.
        distance_cutoff: Max distance in Å (default 3.3 for H-bonds).
        structure_ids: Optional list of structures.
        output_file: Save the contacts bar chart to file.
        ligand_2d_file: Save the 2D ligand depiction to file.
                        If not set, auto-generates filename from output_file
                        or displays inline.
        charge: Total formal charge of the ligand (e.g. -3 for citrate).
                Helps RDKit infer correct protonation / bond orders in
                the 2D depiction.
        infer_bond_orders: If True, RDKit will try to determine double
                and aromatic bonds from 3D geometry. Set to False if the
                2D depiction shows incorrect double bonds — it will display
                connectivity only (all single bonds), which is still useful.

    Example:
        contacts = db.analyze_ligand_contacts("GLC", distance_cutoff=3.3)
        # If the 2D shows wrong double bonds, pass the charge or disable inference:
        contacts = db.analyze_ligand_contacts("CIT", charge=-3)
        contacts = db.analyze_ligand_contacts("LIG", infer_bond_orders=False)
    """
    if structure_ids is None:
        structure_ids = self.ligands.filter(pl.col("residue_name") == ligand_name)\
                                    .select("structure_id").unique().collect().to_series().to_list()

    # "Å" fixed from mojibake "Ã…" — consistent with the plot title below.
    print(f"Analyzing {ligand_name} atom contacts (<{distance_cutoff}Å) in {len(structure_ids)} structures...")

    all_contacts = []
    representative_ligand_atoms = None
    unique_smiles = set()

    # Pre-built mol/labels for the 2D depiction (from the first structure)
    _mol = None
    _canonical_labels = None

    for sid in structure_ids:
        try:
            all_atoms = self.get_all_atoms(sid)
        except ValueError:
            print("  No all_atom data available. Re-ingest structures to enable contact analysis.")
            return pl.DataFrame()
        ligands = self.get_ligands(sid)

        contacts_df = self.ligand_analyzer.find_ligand_atom_contacts(
            all_atoms, ligands, ligand_name, distance_cutoff=distance_cutoff
        )

        if contacts_df.height > 0:
            # ── Per-structure Canonical Mapping ──────────────────────────
            # Compute canonical labels for THIS specific structure instance
            # so that even if PDB atom names differ (e.g. O1 vs O-1),
            # chemically equivalent atoms get the same label (e.g. O4).
            lig_for_struct = ligands.filter(pl.col("residue_name") == ligand_name)

            # Build mol for this structure
            mol_i, pdb_names_i, canonical_labels_i, smi_i = \
                self.ligand_analyzer.build_ligand_mol(
                    lig_for_struct,
                    charge=charge,
                    infer_bond_orders=infer_bond_orders,
                )

            if smi_i:
                unique_smiles.add(smi_i)

            # Save the first valid mol for the 2D plot
            if _mol is None and mol_i is not None:
                _mol = mol_i
                _canonical_labels = canonical_labels_i
                representative_ligand_atoms = lig_for_struct

            # Map PDB names -> Canonical labels for this structure
            if pdb_names_i and canonical_labels_i:
                pdb_to_canonical = dict(zip(pdb_names_i, canonical_labels_i))

                # Apply mapping immediately to this structure's contacts
                mapping_df = pl.DataFrame({
                    "ligand_atom": list(pdb_to_canonical.keys()),
                    "canonical_atom": list(pdb_to_canonical.values()),
                })
                contacts_df = contacts_df.join(mapping_df, on="ligand_atom", how="left")
                contacts_df = contacts_df.with_columns(
                    pl.col("canonical_atom").fill_null(pl.col("ligand_atom"))
                )
            else:
                # Fallback if RDKit failed or no atoms
                contacts_df = contacts_df.with_columns(
                    pl.col("ligand_atom").alias("canonical_atom")
                )

            contacts_df = contacts_df.with_columns(pl.lit(sid).alias("structure_id"))
            all_contacts.append(contacts_df)

    if len(unique_smiles) > 1:
        print(f"Warning: Found {len(unique_smiles)} different molecular graphs (SMILES) for {ligand_name}.")
        print("  Canonical labeling might be inconsistent if structures have different connectivity/protonation.")

    if not all_contacts:
        print("No contacts found at this cutoff.")
        return pl.DataFrame()

    result = pl.concat(all_contacts)

    # Summary
    n_hbond_candidates = result.filter(
        pl.col("ligand_element").is_in(["N", "O", "S"]) & 
        pl.col("protein_element").is_in(["N", "O", "S"])
    ).height
    print(f"  Found {result.height} contacts ({n_hbond_candidates} potential H-bonds: N/O/S pairs)")

    # Plot 1: stacked bar chart — uses canonical_atom column
    self.ligand_analyzer.plot_ligand_contacts(
        result,
        title=f"{ligand_name} — Atom Contacts <{distance_cutoff}Å ({len(structure_ids)} structures)",
        output_file=output_file,
    )

    # Plot 2: 2D ligand depiction — reuses the representative mol + canonical
    # labels so atom numbering is identical to the bar chart.
    if representative_ligand_atoms is not None:
        lig2d_out = ligand_2d_file
        if lig2d_out is None and output_file is not None:
            from pathlib import Path as _P
            stem = _P(output_file).stem
            lig2d_out = str(_P(output_file).parent / f"{stem}_ligand2d.png")

        self.ligand_analyzer.plot_ligand_2d(
            representative_ligand_atoms,
            contacts_df=result,
            title=f"{ligand_name} — 2D Structure (colored by contacts)",
            output_file=lig2d_out,
            charge=charge,
            infer_bond_orders=infer_bond_orders,
            prebuilt_mol=_mol,
            prebuilt_canonical_labels=_canonical_labels,
        )
    return result

get_binding_pockets(ligand_name, distance_cutoff=8.0, structure_ids=None)

Returns a DataFrame where each row is a structure and columns are residue counts in the binding pocket (e.g. ALA, TRP, etc.).

This is useful for filtering structures based on pocket composition.

Parameters:

Name Type Description Default
ligand_name str

Three-letter ligand code.

required
distance_cutoff float

Radius in Angstroms (default 8.0).

8.0
structure_ids Optional[List[str]]

Optional list of structures.

None

Returns:

Type Description
DataFrame

DataFrame with structure_id and counts for each residue type found.

DataFrame

Missing residues are filled with 0.

Source code in src/sicifus/api.py
def get_binding_pockets(self, ligand_name: str, distance_cutoff: float = 8.0,
                        structure_ids: Optional[List[str]] = None) -> pl.DataFrame:
    """
    Returns a DataFrame where each row is a structure and columns are residue counts
    in the binding pocket (e.g. ALA, TRP, etc.).

    This is useful for filtering structures based on pocket composition.

    Args:
        ligand_name: Three-letter ligand code.
        distance_cutoff: Radius in Angstroms (default 8.0).
        structure_ids: Optional list of structures.

    Returns:
        DataFrame with structure_id and counts for each residue type found.
        Missing residues are filled with 0. Empty DataFrame if no pockets found.
    """
    from collections import Counter  # hoisted: was re-imported on every loop iteration

    if structure_ids is None:
        structure_ids = self.ligands.filter(pl.col("residue_name") == ligand_name)\
                                    .select("structure_id").unique().collect().to_series().to_list()

    # "Å" fixed from mojibake "Ã…" — consistent with the plot labels elsewhere.
    print(f"Extracting {ligand_name} binding pockets (<{distance_cutoff}Å) from {len(structure_ids)} structures...")

    pocket_data = []

    for sid in structure_ids:
        try:
            all_atoms = self.get_all_atoms(sid)
        except ValueError:
            # Skip structures without all-atom data
            continue

        ligands = self.get_ligands(sid)

        # Get list of residues in pocket (e.g. ["ALA", "HIS", "ALA"])
        pocket_residues = self.ligand_analyzer.get_pocket_residues(
            all_atoms, ligands, ligand_name, distance_cutoff=distance_cutoff
        )

        if pocket_residues:
            counts = Counter(pocket_residues)
            row = {"structure_id": sid}
            row.update(counts)
            pocket_data.append(row)

    if not pocket_data:
        return pl.DataFrame()

    # Create DataFrame and fill missing columns with 0
    df = pl.from_dicts(pocket_data)

    # Ensure all columns are numeric (except structure_id) and fill nulls with 0
    fill_cols = [c for c in df.columns if c != "structure_id"]
    df = df.with_columns([
        pl.col(c).fill_null(0).cast(pl.Int32) for c in fill_cols
    ])

    return df

analyze_binding_pocket(ligand_name, distance_cutoff=8.0, structure_ids=None, output_file=None)

Analyzes the amino acid composition of the binding pocket (residues within a defined radius of the ligand) across all structures.

Produces a histogram of residue counts (X-axis = 20 amino acids).

Parameters:

Name Type Description Default
ligand_name str

Three-letter ligand code.

required
distance_cutoff float

Radius in Angstroms (default 8.0).

8.0
structure_ids Optional[List[str]]

Optional list of structures.

None
output_file Optional[str]

Save the histogram to file.

None

Returns:

Type Description
Dict[str, int]

Dictionary of residue counts (e.g. {'ALA': 10, 'HIS': 5}).

Source code in src/sicifus/api.py
def analyze_binding_pocket(self, ligand_name: str, distance_cutoff: float = 8.0,
                           structure_ids: Optional[List[str]] = None,
                           output_file: Optional[str] = None) -> Dict[str, int]:
    """
    Analyse the amino acid composition of the ligand binding pocket
    (residues within a defined radius of the ligand) across all structures,
    and plot a histogram of residue counts (X-axis = 20 amino acids).

    Args:
        ligand_name: Three-letter ligand code.
        distance_cutoff: Radius in Angstroms (default 8.0).
        structure_ids: Optional list of structures.
        output_file: Save the histogram to file.

    Returns:
        Dictionary of residue counts (e.g. {'ALA': 10, 'HIS': 5}).
    """
    # Per-structure pocket composition, one row per structure.
    pockets = self.get_binding_pockets(ligand_name, distance_cutoff, structure_ids)
    if pockets.height == 0:
        print("No pocket residues found.")
        return {}

    # Aggregate: sum every residue-count column over all structures.
    count_cols = [c for c in pockets.columns if c != "structure_id"]
    totals = pockets.select([pl.col(c).sum() for c in count_cols]).row(0, named=True)

    self.ligand_analyzer.plot_binding_pocket_composition(
        totals,
        title=f"{ligand_name} — Binding Pocket Composition (<{distance_cutoff}Å, {pockets.height} structures)",
        output_file=output_file
    )
    return totals

repair_structure(structure_id, **kwargs)

Repair a structure: fix missing atoms, add hydrogens, minimise.

Repairs protein structure by fixing clashes and adding missing atoms. Requires pdbfixer and openmm (install with pip install sicifus[energy]).

Parameters:

Name Type Description Default
structure_id str

ID of the structure in the database.

required
**kwargs

Forwarded to :meth:MutationEngine.repair.

{}

Returns:

Type Description
RepairResult

RepairResult with repaired PDB and energy change.

Source code in src/sicifus/api.py
def repair_structure(self, structure_id: str, **kwargs) -> RepairResult:
    """Repair a structure: fix missing atoms, add hydrogens, minimise.

    Fixes clashes and adds missing atoms to the protein structure.
    Requires ``pdbfixer`` and ``openmm`` (install with
    ``pip install sicifus[energy]``).

    Args:
        structure_id: ID of the structure in the database.
        **kwargs: Forwarded to :meth:`MutationEngine.repair`.

    Returns:
        RepairResult with repaired PDB and energy change.
    """
    # Serialize the stored structure to PDB text, then delegate.
    return self.mutation_engine.repair(self._structure_to_pdb(structure_id), **kwargs)

calculate_stability(structure_id, **kwargs)

Calculate total potential energy with per-term decomposition.

Calculates protein stability using energy minimization.

Parameters:

Name Type Description Default
structure_id str

ID of the structure in the database.

required
**kwargs

Forwarded to :meth:MutationEngine.calculate_stability.

{}

Returns:

Type Description
StabilityResult

StabilityResult with total energy (kcal/mol) and per-force-term breakdown.

Source code in src/sicifus/api.py
def calculate_stability(self, structure_id: str, **kwargs) -> StabilityResult:
    """Calculate total potential energy with per-term decomposition.

    Runs an energy-minimisation-based stability calculation on the
    stored structure.

    Args:
        structure_id: ID of the structure in the database.
        **kwargs: Forwarded to :meth:`MutationEngine.calculate_stability`.

    Returns:
        StabilityResult with total energy (kcal/mol) and per-force-term breakdown.
    """
    # Serialize the stored structure to PDB text, then delegate.
    return self.mutation_engine.calculate_stability(
        self._structure_to_pdb(structure_id), **kwargs
    )

mutate_structure(structure_id, mutations, **kwargs)

Apply point mutations, minimise, and compute ddG.

Parameters:

Name Type Description Default
structure_id str

ID of the structure in the database.

required
mutations List[Union[Mutation, str]]

List of Mutation objects or strings (e.g. 'G13L' for Gly at position 13 to Leu).

required
**kwargs

Forwarded to :meth:MutationEngine.mutate.

{}

Returns:

Type Description
MutationResult

MutationResult with wild-type energy, mutant energy, ddG, and

MutationResult

mutant PDB strings.

Source code in src/sicifus/api.py
def mutate_structure(
    self, structure_id: str, mutations: List[Union[Mutation, str]], **kwargs
) -> MutationResult:
    """Apply point mutations to a stored structure and score the ddG.

    The structure is exported to PDB text; the mutation engine applies the
    mutations, minimises, and returns the energetic comparison.

    Args:
        structure_id: ID of the structure in the database.
        mutations: List of Mutation objects or strings
                   (e.g. ``'G13L'`` for Gly at position 13 to Leu).
        **kwargs: Forwarded to :meth:`MutationEngine.mutate`.

    Returns:
        MutationResult with wild-type energy, mutant energy, ddG, and
        mutant PDB strings.
    """
    pdb = self._structure_to_pdb(structure_id)
    return self.mutation_engine.mutate(pdb, mutations, **kwargs)

load_mutations(csv_path)

Load a mutation list from a CSV file.

The CSV must have a mutation column (e.g. G13L). An optional chain column provides chain IDs; if absent, defaults to 'A'. Extra columns are preserved as metadata.

Parameters:

Name Type Description Default
csv_path str

Path to a CSV file.

required

Returns:

Type Description
DataFrame

Polars DataFrame ready for mutate_batch().

Source code in src/sicifus/api.py
def load_mutations(self, csv_path: str) -> pl.DataFrame:
    """Read a mutation list from a CSV file into a Polars DataFrame.

    The file must contain a ``mutation`` column (e.g. ``G13L``).  A
    ``chain`` column is optional and defaults to ``'A'`` when missing;
    any additional columns are preserved as metadata.

    Args:
        csv_path: Path to a CSV file.

    Returns:
        Polars DataFrame ready for :meth:`mutate_batch`.
    """
    # Parsing is stateless, so delegate straight to the engine-level loader.
    return MutationEngine.load_mutations(csv_path)

mutate_batch(structure_id, mutations_df, **kwargs)

Run every mutation in a DataFrame against a structure.

Each row is an independent single-point mutation. Extra columns from the input are carried through to the result.

Parameters:

Name Type Description Default
structure_id str

ID of the structure in the database.

required
mutations_df DataFrame

DataFrame with mutation and chain columns (as returned by load_mutations()).

required
**kwargs

Forwarded to :meth:MutationEngine.mutate_batch.

{}

Returns:

Type Description
DataFrame

DataFrame with input columns plus

DataFrame

[wt_energy, mutant_energy, ddg_kcal_mol].

Source code in src/sicifus/api.py
def mutate_batch(
    self, structure_id: str, mutations_df: pl.DataFrame, **kwargs
) -> pl.DataFrame:
    """Score every mutation row in a DataFrame against one structure.

    Each row is treated as an independent single-point mutation; extra
    input columns are carried through unchanged into the result.

    Args:
        structure_id: ID of the structure in the database.
        mutations_df: DataFrame with ``mutation`` and ``chain`` columns
                      (as returned by :meth:`load_mutations`).
        **kwargs: Forwarded to :meth:`MutationEngine.mutate_batch`.

    Returns:
        DataFrame with input columns plus
        ``[wt_energy, mutant_energy, ddg_kcal_mol]``.
    """
    pdb = self._structure_to_pdb(structure_id)
    return self.mutation_engine.mutate_batch(pdb, mutations_df, **kwargs)

calculate_binding_energy(structure_id, chains_a, chains_b, **kwargs)

Calculate binding energy between two groups of chains.

Calculates binding energy for protein-protein complexes.

Parameters:

Name Type Description Default
structure_id str

ID of the structure in the database.

required
chains_a List[str]

Chain IDs for the first group (e.g. ['A']).

required
chains_b List[str]

Chain IDs for the second group (e.g. ['B']).

required
**kwargs

Forwarded to :meth:MutationEngine.calculate_binding_energy.

{}

Returns:

Type Description
BindingResult

BindingResult with binding energy and interface residues.

Source code in src/sicifus/api.py
def calculate_binding_energy(
    self, structure_id: str, chains_a: List[str], chains_b: List[str], **kwargs
) -> BindingResult:
    """Compute the binding energy between two chain groups of a complex.

    Args:
        structure_id: ID of the structure in the database.
        chains_a: Chain IDs for the first group (e.g. ``['A']``).
        chains_b: Chain IDs for the second group (e.g. ``['B']``).
        **kwargs: Forwarded to :meth:`MutationEngine.calculate_binding_energy`.

    Returns:
        BindingResult with binding energy and interface residues.
    """
    pdb = self._structure_to_pdb(structure_id)
    return self.mutation_engine.calculate_binding_energy(
        pdb, chains_a, chains_b, **kwargs
    )

alanine_scan(structure_id, chain, positions=None, **kwargs)

Alanine scan: mutate each position to Ala and report ddG.

Performs systematic alanine scanning mutagenesis.

Parameters:

Name Type Description Default
structure_id str

ID of the structure in the database.

required
chain str

Chain ID to scan.

required
positions Optional[List[int]]

Specific residue numbers. If None, scans all eligible residues.

None
**kwargs

Forwarded to :meth:MutationEngine.alanine_scan.

{}

Returns:

Type Description
DataFrame

DataFrame with columns [chain, position, wt_residue, ddg_kcal_mol].

Source code in src/sicifus/api.py
def alanine_scan(
    self, structure_id: str, chain: str, positions: Optional[List[int]] = None,
    **kwargs
) -> pl.DataFrame:
    """Systematic alanine scan: mutate each position to Ala and report ddG.

    Args:
        structure_id: ID of the structure in the database.
        chain: Chain ID to scan.
        positions: Specific residue numbers. If None, scans all eligible residues.
        **kwargs: Forwarded to :meth:`MutationEngine.alanine_scan`.

    Returns:
        DataFrame with columns [chain, position, wt_residue, ddg_kcal_mol].
    """
    pdb = self._structure_to_pdb(structure_id)
    return self.mutation_engine.alanine_scan(pdb, chain, positions, **kwargs)

position_scan(structure_id, chain, positions, **kwargs)

Scan all 20 amino acids at specified positions.

Generates position-specific scoring matrix.

Parameters:

Name Type Description Default
structure_id str

ID of the structure in the database.

required
chain str

Chain ID.

required
positions List[int]

List of residue numbers to scan.

required
**kwargs

Forwarded to :meth:MutationEngine.position_scan.

{}

Returns:

Type Description
DataFrame

DataFrame with columns

DataFrame

[chain, position, wt_residue, mut_residue, ddg_kcal_mol].

Source code in src/sicifus/api.py
def position_scan(
    self, structure_id: str, chain: str, positions: List[int], **kwargs
) -> pl.DataFrame:
    """Scan all 20 amino acids at the given positions (PSSM-style).

    Args:
        structure_id: ID of the structure in the database.
        chain: Chain ID.
        positions: List of residue numbers to scan.
        **kwargs: Forwarded to :meth:`MutationEngine.position_scan`.

    Returns:
        DataFrame with columns
        [chain, position, wt_residue, mut_residue, ddg_kcal_mol].
    """
    pdb = self._structure_to_pdb(structure_id)
    return self.mutation_engine.position_scan(pdb, chain, positions, **kwargs)

per_residue_energy(structure_id, **kwargs)

Approximate per-residue energy contribution via Ala-subtraction.

Computes per-residue energy decomposition.

Parameters:

Name Type Description Default
structure_id str

ID of the structure in the database.

required
**kwargs

Forwarded to :meth:MutationEngine.per_residue_energy.

{}

Returns:

Type Description
DataFrame

DataFrame with columns

DataFrame

[chain, residue_number, residue_name, energy_contribution_kcal_mol].

Source code in src/sicifus/api.py
def per_residue_energy(self, structure_id: str, **kwargs) -> pl.DataFrame:
    """Approximate each residue's energy contribution via Ala-subtraction.

    Args:
        structure_id: ID of the structure in the database.
        **kwargs: Forwarded to :meth:`MutationEngine.per_residue_energy`.

    Returns:
        DataFrame with columns
        [chain, residue_number, residue_name, energy_contribution_kcal_mol].
    """
    pdb = self._structure_to_pdb(structure_id)
    return self.mutation_engine.per_residue_energy(pdb, **kwargs)

Mutation Engine

sicifus.MutationEngine

Industry-standard protein mutation and stability engine using OpenMM + PDBFixer.

Provides structure repair, in silico mutagenesis, stability scoring, binding energy calculation, alanine scanning, and positional scanning without requiring the commercial protein design tools.

Source code in src/sicifus/mutate.py
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
class MutationEngine:
    """
    Industry-standard protein mutation and stability engine using OpenMM + PDBFixer.

    Provides structure repair, in silico mutagenesis, stability scoring,
    binding energy calculation, alanine scanning, and positional scanning
    without requiring the commercial protein design tools.
    """

    HARTREE_TO_KCAL = 627.509
    KJ_TO_KCAL = 1.0 / 4.184

    def __init__(
        self,
        forcefield: str = "amber14-all.xml",
        water_model: str = "implicit",
        platform: str = "CPU",
        work_dir: str = "./mutate_work",
    ):
        """Configure the engine; heavy OpenMM objects are created lazily.

        Args:
            forcefield: OpenMM force field XML (default AMBER14).
            water_model: 'implicit' for GBn2 implicit solvent (fast,
                         industry-standard) or an explicit water XML like
                         'amber14/tip3pfb.xml'.
            platform: OpenMM platform ('CPU', 'CUDA', 'OpenCL').
            work_dir: Directory for temporary files (created if absent).
        """
        self.forcefield_name = forcefield
        self.water_model = water_model
        self.platform_name = platform

        # Make sure the scratch directory exists up front.
        self.work_dir = Path(work_dir)
        self.work_dir.mkdir(exist_ok=True, parents=True)

        # Lazily-built OpenMM handles; see _get_forcefield/_get_platform.
        self._ff = None
        self._platform = None

    def _get_forcefield(self):
        """Return the cached OpenMM ForceField, constructing it on first use."""
        if self._ff is None:
            from openmm.app import ForceField
            # 'implicit' selects the GBn2 generalised-Born solvent XML;
            # any other value is treated as an explicit water-model XML file.
            solvent = ("implicit/gbn2.xml" if self.water_model == "implicit"
                       else self.water_model)
            self._ff = ForceField(self.forcefield_name, solvent)
        return self._ff

    def _get_platform(self):
        """Return the cached OpenMM Platform, resolving it lazily by name."""
        if self._platform is not None:
            return self._platform
        from openmm import Platform
        self._platform = Platform.getPlatformByName(self.platform_name)
        return self._platform

    # ------------------------------------------------------------------
    # Core: create an OpenMM system and minimise
    # ------------------------------------------------------------------

    def _build_system(self, topology, positions, constrain_backbone: bool = False,
                       skip_hydrogens: bool = False):
        """Create an OpenMM System from topology/positions.

        Returns (system, topology, positions) — topology/positions may have
        been modified (hydrogens added via Modeller).

        Args:
            topology: OpenMM Topology to model.
            positions: Particle positions matching ``topology``.
            constrain_backbone: If True, add harmonic positional restraints
                                (k = 1000 kJ/mol/nm^2) on backbone heavy atoms
                                (CA, C, N, O) so minimisation relaxes only
                                sidechains.
            skip_hydrogens: If True, skip ``addHydrogens`` — use when the
                            structure is already fully protonated (e.g. from a
                            repair cache).
        """
        # Cleanup vs. original: dropped unused imports (ForceField, openmm)
        # and unified the two unit aliases ('unit'/'u') into one.
        from openmm.app import Modeller
        from openmm import app as app_mod
        import openmm.unit as unit

        ff = self._get_forcefield()

        modeller = Modeller(topology, positions)
        if not skip_hydrogens:
            modeller.addHydrogens(ff, pH=7.0)
        topology = modeller.getTopology()
        positions = modeller.getPositions()

        if self.water_model == "implicit":
            # Implicit solvent has no periodic box, so no cutoff is needed.
            system = ff.createSystem(
                topology,
                nonbondedMethod=app_mod.NoCutoff,
                constraints=app_mod.HBonds,
            )
        else:
            # Explicit solvent: PME electrostatics with a 1 nm cutoff.
            system = ff.createSystem(
                topology,
                nonbondedMethod=app_mod.PME,
                nonbondedCutoff=1.0 * unit.nanometers,
                constraints=app_mod.HBonds,
            )

        if constrain_backbone:
            from openmm import CustomExternalForce

            restraint = CustomExternalForce(
                "0.5*k*((x-x0)^2+(y-y0)^2+(z-z0)^2)"
            )
            restraint.addGlobalParameter(
                "k", 1000.0 * unit.kilojoules_per_mole / unit.nanometers**2)
            restraint.addPerParticleParameter("x0")
            restraint.addPerParticleParameter("y0")
            restraint.addPerParticleParameter("z0")

            def _nm(coord):
                # Positions may be plain floats or openmm unit Quantities.
                return (coord.value_in_unit(unit.nanometers)
                        if hasattr(coord, "value_in_unit") else float(coord))

            bb_names = {"CA", "C", "N", "O"}
            for atom in topology.atoms():
                if atom.name in bb_names:
                    pos_i = positions[atom.index]
                    restraint.addParticle(
                        atom.index, [_nm(pos_i[0]), _nm(pos_i[1]), _nm(pos_i[2])])
            system.addForce(restraint)

        return system, topology, positions

    def _minimise(self, system, topology, positions, max_iterations: int = 500,
                   tolerance: float = 1.0):
        """Energy-minimise and return (positions, energy_kj).

        Energy is returned in kJ/mol as reported by OpenMM.

        Args:
            max_iterations: Minimisation step budget.  ``0`` means "do not
                            minimise"; the current energy is evaluated as-is.
                            (Bug fix: OpenMM itself interprets
                            ``maxIterations=0`` as "iterate until converged",
                            so callers using 0 to score an *un-minimised*
                            structure — e.g. ``repair``'s before-energy — were
                            silently getting the fully minimised energy.)
            tolerance: Convergence tolerance in kJ/mol/nm (default 1.0).
        """
        from openmm import LangevinMiddleIntegrator, Context, LocalEnergyMinimizer
        import openmm.unit as unit

        integrator = LangevinMiddleIntegrator(
            300 * unit.kelvin, 1.0 / unit.picosecond, 0.002 * unit.picoseconds
        )
        platform = self._get_platform()
        context = Context(system, integrator, platform)
        context.setPositions(positions)

        if max_iterations > 0:
            LocalEnergyMinimizer.minimize(
                context,
                tolerance=tolerance * unit.kilojoules_per_mole / unit.nanometer,
                maxIterations=max_iterations,
            )

        state = context.getState(getPositions=True, getEnergy=True)
        energy_kj = state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
        return state.getPositions(), energy_kj

    def _decompose_energy(self, system, topology, positions) -> Dict[str, float]:
        """Evaluate each force-group independently and return per-term energies (kcal/mol)."""
        from openmm import LangevinMiddleIntegrator, Context
        import openmm.unit as unit

        # Give every force its own group so each can be queried in isolation.
        for idx, force in enumerate(system.getForces()):
            force.setForceGroup(idx)

        integrator = LangevinMiddleIntegrator(
            300 * unit.kelvin, 1.0 / unit.picosecond, 0.002 * unit.picoseconds
        )
        context = Context(system, integrator, self._get_platform())
        context.setPositions(positions)

        def _kcal(state):
            # Convert the state's potential energy from kJ/mol to kcal/mol.
            kj = state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
            return round(kj * self.KJ_TO_KCAL, 4)

        terms: Dict[str, float] = {}
        for idx, force in enumerate(system.getForces()):
            terms[type(force).__name__] = _kcal(
                context.getState(getEnergy=True, groups={idx}))

        # Grand total across all groups.
        terms["total"] = _kcal(context.getState(getEnergy=True))
        return terms

    # ------------------------------------------------------------------
    # Prepare (repair-once cache)
    # ------------------------------------------------------------------

    def _prepare_structure(self, source, max_iterations: int = 2000,
                           tolerance: float = 1.0) -> _RepairCache:
        """Protonate and thoroughly minimise a structure once.

        This is the "repair once, mutate from repaired" pattern:
        PDBFixer fills missing atoms/hydrogens, then the structure is
        minimised to convergence.  The resulting cache can be reused
        for many mutations without repeated WT hydrogen placement.

        Args:
            source: PDB file path, PDB string, or Polars DataFrame.
            max_iterations: Minimisation step budget (default 2000).
            tolerance: Convergence tolerance in kJ/mol/nm (default 1.0).

        Returns:
            A ``_RepairCache`` holding the repaired PDB text, topology,
            minimised positions, the OpenMM system, and the final energy
            in both kJ/mol and kcal/mol.
        """
        # (Removed unused ``from pdbfixer import PDBFixer`` — the fixer is
        # obtained via the _pdb_string_to_fixer helper.)
        pdb_text = _load_pdb(source)
        fixer = _pdb_string_to_fixer(pdb_text)
        fixer.findMissingResidues()
        fixer.findMissingAtoms()
        fixer.addMissingAtoms()
        fixer.addMissingHydrogens(7.0)  # protonate at physiological pH

        system, top, pos = self._build_system(fixer.topology, fixer.positions)
        pos_min, e_kj = self._minimise(system, top, pos,
                                        max_iterations=max_iterations,
                                        tolerance=tolerance)

        pdb_out = _topology_to_pdb(top, pos_min)

        return _RepairCache(
            pdb_string=pdb_out,
            topology=top,
            positions=pos_min,
            system=system,
            energy_kj=e_kj,
            energy_kcal=round(e_kj * self.KJ_TO_KCAL, 4),
        )

    def prepare(self, source, max_iterations: int = 2000,
                tolerance: float = 1.0) -> _RepairCache:
        """Prepare a wild-type structure for efficient batch mutations.

        Protonates, fills missing atoms, and minimises the structure
        thoroughly.  Pass the returned cache to :meth:`mutate` or
        :meth:`mutate_batch` to avoid redundant WT processing.

        Args:
            source: PDB file path, PDB string, or Polars DataFrame.
            max_iterations: Minimisation steps (default 2000).
            tolerance: Convergence tolerance in kJ/mol/nm (default 1.0).

        Returns:
            A ``_RepairCache`` that can be passed to ``mutate()`` and
            ``mutate_batch()`` for deterministic WT energy.
        """
        # Public alias over the internal repair-once helper.
        return self._prepare_structure(source, max_iterations, tolerance)

    # ------------------------------------------------------------------
    # Repair (RepairPDB equivalent)
    # ------------------------------------------------------------------

    def repair(
        self,
        source,
        pH: float = 7.0,
        max_iterations: int = 2000,
    ) -> RepairResult:
        """Repair a structure: fix missing atoms/residues, add hydrogens, minimise.

        Repairs protein structure by fixing clashes and adding missing atoms.

        Args:
            source: PDB file path, PDB string, or Polars DataFrame of atoms.
            pH: pH for protonation (default 7.0).
            max_iterations: Maximum minimisation steps.

        Returns:
            RepairResult with repaired PDB and before/after energies (kcal/mol).
        """
        from openmm import LangevinMiddleIntegrator, Context
        import openmm.unit as unit

        pdb_text = _load_pdb(source)
        fixer = _pdb_string_to_fixer(pdb_text)

        fixer.findMissingResidues()
        fixer.findMissingAtoms()
        fixer.addMissingAtoms()
        fixer.addMissingHydrogens(pH)

        # Build the system ONCE (the original built it twice) and score the
        # un-minimised structure directly.  Bug fix: the original computed
        # ``energy_before`` via ``_minimise(..., max_iterations=0)``, but
        # OpenMM interprets maxIterations=0 as "minimise until converged",
        # so the reported "before" energy was actually post-minimisation.
        system, top, pos = self._build_system(fixer.topology, fixer.positions)
        integrator = LangevinMiddleIntegrator(
            300 * unit.kelvin, 1.0 / unit.picosecond, 0.002 * unit.picoseconds
        )
        context = Context(system, integrator, self._get_platform())
        context.setPositions(pos)
        e_before = context.getState(getEnergy=True).getPotentialEnergy() \
            .value_in_unit(unit.kilojoules_per_mole)
        # Release the scoring context before the minimisation context is made.
        del context, integrator

        pos_min, e_after = self._minimise(system, top, pos, max_iterations)

        pdb_out = _topology_to_pdb(top, pos_min)

        return RepairResult(
            topology=top,
            positions=pos_min,
            energy_before=round(e_before * self.KJ_TO_KCAL, 4),
            energy_after=round(e_after * self.KJ_TO_KCAL, 4),
            pdb_string=pdb_out,
        )

    # ------------------------------------------------------------------
    # Stability
    # ------------------------------------------------------------------

    def calculate_stability(
        self,
        source,
        max_iterations: int = 2000,
    ) -> StabilityResult:
        """Calculate total potential energy with per-term decomposition.

        The structure is repaired (missing atoms and hydrogens filled in),
        minimised, and then every force term is scored independently.

        Args:
            source: PDB file path, PDB string, or Polars DataFrame.
            max_iterations: Minimisation steps before scoring.

        Returns:
            StabilityResult with total energy and per-force-term breakdown (kcal/mol).
        """
        fixer = _pdb_string_to_fixer(_load_pdb(source))
        fixer.findMissingResidues()
        fixer.findMissingAtoms()
        fixer.addMissingAtoms()
        fixer.addMissingHydrogens(7.0)

        system, top, pos = self._build_system(fixer.topology, fixer.positions)
        minimised_pos, _ = self._minimise(system, top, pos, max_iterations)

        energy_terms = self._decompose_energy(system, top, minimised_pos)

        return StabilityResult(
            total_energy=energy_terms.get("total", 0.0),
            energy_terms=energy_terms,
            pdb_string=_topology_to_pdb(top, minimised_pos),
        )

    # ------------------------------------------------------------------
    # Mutate (BuildModel equivalent)
    # ------------------------------------------------------------------

    def mutate(
        self,
        source,
        mutations: List[Union[Mutation, str]],
        chain: str = "A",
        n_runs: int = 3,
        max_iterations: int = 2000,
        constrain_backbone: bool = True,
        keep_statistics: bool = True,
        use_mean: bool = False,
        _repair_cache: Optional[_RepairCache] = None,
    ) -> MutationResult:
        """Apply one or more point mutations, minimise, and compute ddG.

        Mutations can be ``Mutation`` objects or short strings like ``'G13L'``.
        Multiple mutations in the same call are applied simultaneously.

        Args:
            source: PDB file path, PDB string, or Polars DataFrame.
                    Ignored when ``_repair_cache`` is provided.
            mutations: List of Mutation objects or strings (e.g. ``'G13L'``).
            chain: Default chain ID applied when parsing mutation strings
                   (default ``'A'``).
            n_runs: Number of independent minimisation runs for the mutant
                    (default 3).
            max_iterations: Minimisation steps per run (default 2000).
            constrain_backbone: If True, restrain backbone atoms during mutant
                                minimisation, allowing only sidechain flexibility.
            keep_statistics: If True, collect and return statistical summary
                            (mean, SD, CI) from all runs (default True).
            use_mean: If True, use mean energy for ddG calculation (industry-standard).
                     If False, use best (minimum) energy (default False).
            _repair_cache: Pre-processed WT from :meth:`prepare`.  When
                           provided the WT energy comes from the cache,
                           eliminating redundant hydrogen placement and
                           minimisation.

        Returns:
            MutationResult with wild-type energy, mutant energy, ddG,
            mutant PDB strings, and a full energy-term DataFrame. If
            keep_statistics=True, also includes mean, SD, CI, and convergence
            metrics.
        """
        from pdbfixer import PDBFixer
        from openmm.app import PDBFile

        parsed: List[Mutation] = []
        for m in mutations:
            if isinstance(m, str):
                parsed.append(Mutation.from_str(m, chain=chain))
            else:
                parsed.append(m)

        # --- Wild-type: use cache or build from scratch ---
        if _repair_cache is not None:
            pdb_text = _repair_cache.pdb_string
            e_wt = _repair_cache.energy_kcal
            pos_wt_min = _repair_cache.positions
            top_wt = _repair_cache.topology
            sys_wt = _repair_cache.system
        else:
            pdb_text = _load_pdb(source)
            fixer_wt = _pdb_string_to_fixer(pdb_text)
            fixer_wt.findMissingResidues()
            fixer_wt.findMissingAtoms()
            fixer_wt.addMissingAtoms()
            fixer_wt.addMissingHydrogens(7.0)

            sys_wt, top_wt, pos_wt = self._build_system(
                fixer_wt.topology, fixer_wt.positions)
            pos_wt_min, e_wt_kj = self._minimise(
                sys_wt, top_wt, pos_wt, max_iterations)
            e_wt = round(e_wt_kj * self.KJ_TO_KCAL, 4)
            pdb_text = _topology_to_pdb(top_wt, pos_wt_min)

        # --- Build mutant from the (repaired) WT PDB ---
        fixer_mut = _pdb_string_to_fixer(pdb_text)
        fixer_mut.findMissingResidues()

        pdbfixer_mutations = []
        for m in parsed:
            pdbfixer_mutations.append(f"{m.wt_residue}-{m.position}-{m.mut_residue}")

        chains_used = {m.chain for m in parsed}
        for chain_id in chains_used:
            fixer_mut.applyMutations(pdbfixer_mutations, chain_id)

        fixer_mut.findMissingAtoms()
        fixer_mut.addMissingAtoms()
        fixer_mut.addMissingHydrogens(7.0)

        mut_label = "+".join(m.label for m in parsed)

        best_e_mut = None
        best_pos_mut = None
        best_top_mut = None
        best_sys_mut = None
        all_run_energies_list = []

        for run in range(n_runs):
            sys_m, top_m, pos_m = self._build_system(
                fixer_mut.topology, fixer_mut.positions,
                constrain_backbone=constrain_backbone,
                skip_hydrogens=(_repair_cache is not None),
            )
            pos_m_min, e_m_kj = self._minimise(
                sys_m, top_m, pos_m, max_iterations)
            e_m = e_m_kj * self.KJ_TO_KCAL

            # Store all run energies
            all_run_energies_list.append(e_m)

            if best_e_mut is None or e_m < best_e_mut:
                best_e_mut = e_m
                best_pos_mut = pos_m_min
                best_top_mut = top_m
                best_sys_mut = sys_m

        # Compute primary ddG (based on use_mean flag)
        if use_mean and keep_statistics:
            primary_energy = float(np.mean(all_run_energies_list))
            ddg = primary_energy - e_wt
        else:
            primary_energy = best_e_mut
            ddg = best_e_mut - e_wt

        # Decompose using the exact system/topology that produced the
        # minimised positions (no redundant _build_system call).
        mut_terms = self._decompose_energy(best_sys_mut, best_top_mut, best_pos_mut)
        wt_terms = self._decompose_energy(sys_wt, top_wt, pos_wt_min)

        term_rows = []
        for key in sorted(set(wt_terms.keys()) | set(mut_terms.keys())):
            wt_val = wt_terms.get(key, 0.0)
            mt_val = mut_terms.get(key, 0.0)
            term_rows.append({
                "term": key,
                "wt_energy": wt_val,
                "mutant_energy": mt_val,
                "delta": round(mt_val - wt_val, 4),
            })
        terms_df = pl.DataFrame(term_rows)

        pdb_mut = _topology_to_pdb(best_top_mut, best_pos_mut)

        # Compute statistics if requested
        stats_dict = None
        if keep_statistics and n_runs > 1:
            stats_dict = _compute_energy_statistics(all_run_energies_list, e_wt, n_runs)

            # Warn if convergence is poor
            if stats_dict["cv"] > 0.1:
                print(f"Warning: {mut_label} has high energy variability (CV={stats_dict['cv']:.2f}). "
                      f"Consider increasing n_runs for more reliable results.")

        return MutationResult(
            wt_energy=e_wt,
            mutant_energies={mut_label: round(primary_energy, 4)},
            ddg={mut_label: round(ddg, 4)},
            mutant_pdbs={mut_label: pdb_mut},
            energy_terms=terms_df,
            # Statistical fields
            all_run_energies={mut_label: all_run_energies_list} if keep_statistics else None,
            mean_energy={mut_label: stats_dict["mean"]} if stats_dict else None,
            sd_energy={mut_label: stats_dict["sd"]} if stats_dict else None,
            min_energy={mut_label: stats_dict["min"]} if stats_dict else None,
            max_energy={mut_label: stats_dict["max"]} if stats_dict else None,
            ddg_mean={mut_label: stats_dict["ddg_mean"]} if stats_dict else None,
            ddg_sd={mut_label: stats_dict["ddg_sd"]} if stats_dict else None,
            ddg_ci_95={mut_label: stats_dict["ci_95"]} if stats_dict else None,
            convergence_metric={mut_label: stats_dict["cv"]} if stats_dict else None,
        )

    # ------------------------------------------------------------------
    # Binding energy (AnalyseComplex equivalent)
    # ------------------------------------------------------------------

    def calculate_binding_energy(
        self,
        source,
        chains_a: List[str],
        chains_b: List[str],
        max_iterations: int = 2000,
    ) -> BindingResult:
        """Calculate binding energy between two groups of chains.

        Calculates binding energy for protein-protein complexes:

            E_binding = E_complex - (E_chains_a + E_chains_b)

        The complex is repaired (missing residues/atoms/hydrogens) and
        minimised once; each chain group is then extracted from the
        *minimised* complex geometry and minimised separately so all three
        energies start from the same relaxed structure.

        Args:
            source: PDB file path, PDB string, or Polars DataFrame.
            chains_a: Chain IDs for the first group (e.g. ['A']).
            chains_b: Chain IDs for the second group (e.g. ['B']).
            max_iterations: Minimisation steps.

        Returns:
            BindingResult with binding energy, component energies,
            and interface residues.
        """
        # NOTE: the previous version imported PDBFixer here but never used
        # it; structure repair goes through _pdb_string_to_fixer instead.
        pdb_text = _load_pdb(source)
        fixer = _pdb_string_to_fixer(pdb_text)
        fixer.findMissingResidues()
        fixer.findMissingAtoms()
        fixer.addMissingAtoms()
        fixer.addMissingHydrogens(7.0)

        # --- Complex energy ---
        sys_c, top_c, pos_c = self._build_system(fixer.topology, fixer.positions)
        pos_c_min, e_complex_kj = self._minimise(sys_c, top_c, pos_c, max_iterations)
        e_complex = e_complex_kj * self.KJ_TO_KCAL

        # Use minimised complex positions to extract chain subsets
        top_a, pos_a = _extract_chains(top_c, pos_c_min, chains_a)
        top_b, pos_b = _extract_chains(top_c, pos_c_min, chains_b)

        sys_a, top_a2, pos_a2 = self._build_system(top_a, pos_a)
        _, e_a_kj = self._minimise(sys_a, top_a2, pos_a2, max_iterations)
        e_a = e_a_kj * self.KJ_TO_KCAL

        sys_b, top_b2, pos_b2 = self._build_system(top_b, pos_b)
        _, e_b_kj = self._minimise(sys_b, top_b2, pos_b2, max_iterations)
        e_b = e_b_kj * self.KJ_TO_KCAL

        e_binding = e_complex - (e_a + e_b)

        # Interface residues are reported from the minimised WT complex.
        interface_df = _find_interface_residues(
            top_c, pos_c_min, chains_a, chains_b, cutoff_nm=0.5
        )

        return BindingResult(
            binding_energy=round(e_binding, 4),
            complex_energy=round(e_complex, 4),
            chain_a_energy=round(e_a, 4),
            chain_b_energy=round(e_b, 4),
            interface_residues=interface_df,
        )

    # ------------------------------------------------------------------
    # Interface mutagenesis (mutation-to-binding pipeline)
    # ------------------------------------------------------------------

    def mutate_interface(
        self,
        source,
        mutations: Dict[str, List[Union[Mutation, str]]],
        chains_a: List[str],
        chains_b: List[str],
        max_iterations: int = 2000,
        n_runs: int = 3,
        constrain_backbone: bool = True,
    ) -> InterfaceMutationResult:
        """Apply mutations to protein-protein interface and compute ΔΔG_binding.

        This is a pipeline that combines mutate() and calculate_binding_energy()
        to automatically compute how mutations affect binding affinity.

        Args:
            source: PDB file path, PDB string, or Polars DataFrame (complex).
            mutations: Dict mapping chain ID to list of mutations.
                      E.g., {"A": ["F13A", "W14L"], "B": ["Y25F"]}
            chains_a: Chain IDs for the first binding partner (e.g. ['A']).
            chains_b: Chain IDs for the second binding partner (e.g. ['B']).
            max_iterations: Minimisation steps.
            n_runs: Number of independent minimisation runs.
            constrain_backbone: Restrain backbone during mutant minimisation.

        Returns:
            InterfaceMutationResult with ΔΔG_binding, ΔΔG_stability per chain,
            and component energies.
        """
        # NOTE: the previous version imported PDBFixer here but never used it.
        pdb_text = _load_pdb(source)

        # --- Step 1: Calculate WT binding energy ---
        print("Calculating wild-type binding energy...")
        wt_binding = self.calculate_binding_energy(
            pdb_text, chains_a, chains_b, max_iterations=max_iterations
        )

        # --- Step 2: Apply all mutations to create mutant complex ---
        print("Applying mutations to complex...")
        fixer_mut = _pdb_string_to_fixer(pdb_text)
        fixer_mut.findMissingResidues()

        # Parse mutations: strings like 'F13A' become Mutation objects bound
        # to the chain whose dict entry they came from.
        mutations_by_chain = {}
        for chain_id, mut_list in mutations.items():
            parsed = []
            for m in mut_list:
                if isinstance(m, str):
                    parsed.append(Mutation.from_str(m, chain=chain_id))
                else:
                    parsed.append(m)
            mutations_by_chain[chain_id] = parsed

        # Apply mutations per chain using PDBFixer (each chain only gets its
        # own mutation list).
        for chain_id, mut_objs in mutations_by_chain.items():
            pdbfixer_mutations = []
            for m in mut_objs:
                pdbfixer_mutations.append(f"{m.wt_residue}-{m.position}-{m.mut_residue}")
            fixer_mut.applyMutations(pdbfixer_mutations, chain_id)

        fixer_mut.findMissingAtoms()
        fixer_mut.addMissingAtoms()
        fixer_mut.addMissingHydrogens(7.0)

        # Minimize mutant complex (keep the best of n_runs to reduce
        # sensitivity to the stochastic minimisation start).
        best_e_complex = None
        best_pos_complex = None
        best_top_complex = None

        for run in range(n_runs):
            sys_c, top_c, pos_c = self._build_system(
                fixer_mut.topology, fixer_mut.positions,
                constrain_backbone=constrain_backbone
            )
            pos_c_min, e_c_kj = self._minimise(sys_c, top_c, pos_c, max_iterations)
            e_c = e_c_kj * self.KJ_TO_KCAL

            if best_e_complex is None or e_c < best_e_complex:
                best_e_complex = e_c
                best_pos_complex = pos_c_min
                best_top_complex = top_c

        # --- Step 3: Extract and minimize mutant chains separately ---
        print("Calculating mutant component energies...")
        top_a_mut, pos_a_mut = _extract_chains(best_top_complex, best_pos_complex, chains_a)
        top_b_mut, pos_b_mut = _extract_chains(best_top_complex, best_pos_complex, chains_b)

        sys_a_mut, top_a_mut2, pos_a_mut2 = self._build_system(top_a_mut, pos_a_mut)
        _, e_a_mut_kj = self._minimise(sys_a_mut, top_a_mut2, pos_a_mut2, max_iterations)
        e_a_mut = e_a_mut_kj * self.KJ_TO_KCAL

        sys_b_mut, top_b_mut2, pos_b_mut2 = self._build_system(top_b_mut, pos_b_mut)
        _, e_b_mut_kj = self._minimise(sys_b_mut, top_b_mut2, pos_b_mut2, max_iterations)
        e_b_mut = e_b_mut_kj * self.KJ_TO_KCAL

        # --- Step 4: Calculate mutant binding energy ---
        e_binding_mut = best_e_complex - (e_a_mut + e_b_mut)

        # --- Step 5: Compute ΔΔG values ---
        ddg_binding = e_binding_mut - wt_binding.binding_energy
        ddg_stability_a = e_a_mut - wt_binding.chain_a_energy
        ddg_stability_b = e_b_mut - wt_binding.chain_b_energy

        # --- Step 6: Get mutant PDB ---
        mutant_pdb = _topology_to_pdb(best_top_complex, best_pos_complex)

        print(f"ΔΔG_binding: {ddg_binding:+.2f} kcal/mol")
        print(f"ΔΔG_stability (chain A): {ddg_stability_a:+.2f} kcal/mol")
        print(f"ΔΔG_stability (chain B): {ddg_stability_b:+.2f} kcal/mol")

        return InterfaceMutationResult(
            wt_binding_energy=wt_binding.binding_energy,
            mutant_binding_energy=round(e_binding_mut, 4),
            ddg_binding=round(ddg_binding, 4),
            wt_complex_energy=wt_binding.complex_energy,
            mutant_complex_energy=round(best_e_complex, 4),
            wt_chain_a_energy=wt_binding.chain_a_energy,
            wt_chain_b_energy=wt_binding.chain_b_energy,
            mutant_chain_a_energy=round(e_a_mut, 4),
            mutant_chain_b_energy=round(e_b_mut, 4),
            ddg_stability_a=round(ddg_stability_a, 4),
            ddg_stability_b=round(ddg_stability_b, 4),
            interface_residues=wt_binding.interface_residues,
            mutations_by_chain=mutations_by_chain,
            mutant_pdb=mutant_pdb,
        )

    # ------------------------------------------------------------------
    # Alanine scan (AlaScan equivalent)
    # ------------------------------------------------------------------

    def alanine_scan(
        self,
        source,
        chain: str,
        positions: Optional[List[int]] = None,
        max_iterations: int = 2000,
        constrain_backbone: bool = True,
    ) -> pl.DataFrame:
        """Perform alanine scanning on a chain.

        Performs systematic alanine scanning mutagenesis.  Each non-Ala/Gly
        position is mutated to alanine and the ddG is reported.  Positions
        whose mutation fails are reported with NaN rather than aborting the
        whole scan.

        Args:
            source: PDB file path, PDB string, or Polars DataFrame.
            chain: Chain ID to scan (e.g. 'A').
            positions: Specific residue numbers to scan.  If None, scans all
                       non-Ala/Gly standard residues.
            max_iterations: Minimisation steps per mutant.
            constrain_backbone: Freeze backbone atoms during minimisation.

        Returns:
            Polars DataFrame with columns:
            [chain, position, wt_residue, ddg_kcal_mol].
        """
        # NOTE: the previous version imported PDBFixer here but never used it.
        pdb_text = _load_pdb(source)
        fixer = _pdb_string_to_fixer(pdb_text)

        scan_positions = self._get_scannable_positions(
            fixer.topology, chain, positions, skip_residues={"ALA", "GLY"}
        )

        if not scan_positions:
            # Return an empty frame with the documented schema so callers can
            # still chain DataFrame operations.
            print(f"No scannable positions found on chain {chain}.")
            return pl.DataFrame(schema={
                "chain": pl.Utf8, "position": pl.Int64,
                "wt_residue": pl.Utf8, "ddg_kcal_mol": pl.Float64,
            })

        print(f"Alanine scan: {len(scan_positions)} positions on chain {chain}")

        rows = []
        for pos_num, wt_res in scan_positions:
            mut = Mutation(chain=chain, position=pos_num,
                          wt_residue=wt_res, mut_residue="ALA")
            try:
                result = self.mutate(
                    pdb_text, [mut],
                    max_iterations=max_iterations,
                    constrain_backbone=constrain_backbone,
                )
                ddg_val = list(result.ddg.values())[0]
            except Exception as e:
                # Best-effort: record NaN for this position and keep scanning.
                print(f"  {mut.label}: FAILED ({e})")
                ddg_val = float("nan")

            rows.append({
                "chain": chain,
                "position": pos_num,
                "wt_residue": wt_res,
                "ddg_kcal_mol": round(ddg_val, 4),
            })
            print(f"  {mut.label}: ddG = {ddg_val:+.2f} kcal/mol")

        return pl.DataFrame(rows)

    # ------------------------------------------------------------------
    # Position scan / PSSM
    # ------------------------------------------------------------------

    def position_scan(
        self,
        source,
        chain: str,
        positions: List[int],
        max_iterations: int = 2000,
        constrain_backbone: bool = True,
    ) -> pl.DataFrame:
        """Scan all 20 amino acids at specified positions.

        Generates a position-specific scoring matrix (PSSM-like table) by
        mutating each listed position to every standard amino acid.  The
        wild-type residue scores 0.0 by definition; failed mutants score NaN.

        Args:
            source: PDB file path, PDB string, or Polars DataFrame.
            chain: Chain ID.
            positions: List of residue numbers to scan.
            max_iterations: Minimisation steps per mutant.
            constrain_backbone: Freeze backbone atoms during minimisation.

        Returns:
            Polars DataFrame with columns:
            [chain, position, wt_residue, mut_residue, ddg_kcal_mol].
        """
        # NOTE: the previous version imported PDBFixer here but never used it.
        pdb_text = _load_pdb(source)
        fixer = _pdb_string_to_fixer(pdb_text)

        # Map residue number -> residue name for the requested chain.
        resmap = {}
        for chain_obj in fixer.topology.chains():
            if chain_obj.id == chain:
                for res in chain_obj.residues():
                    resmap[int(res.id)] = res.name

        rows = []
        total = len(positions) * 20
        done = 0
        for pos_num in positions:
            wt_res = resmap.get(pos_num)
            # Silently skip positions that do not exist on the chain or are
            # non-standard residues (ligands, waters, modified residues).
            if wt_res is None or wt_res not in STANDARD_RESIDUES:
                continue

            for mut_res in ALL_AMINO_ACIDS:
                done += 1
                if mut_res == wt_res:
                    # Self-mutation: ddG is 0 by definition, skip the expensive run.
                    rows.append({
                        "chain": chain, "position": pos_num,
                        "wt_residue": wt_res, "mut_residue": mut_res,
                        "ddg_kcal_mol": 0.0,
                    })
                    continue

                mut = Mutation(chain=chain, position=pos_num,
                              wt_residue=wt_res, mut_residue=mut_res)
                try:
                    result = self.mutate(
                        pdb_text, [mut],
                        max_iterations=max_iterations,
                        constrain_backbone=constrain_backbone,
                    )
                    ddg_val = list(result.ddg.values())[0]
                except Exception:
                    # Best-effort: record NaN and continue the scan.
                    ddg_val = float("nan")

                rows.append({
                    "chain": chain, "position": pos_num,
                    "wt_residue": wt_res, "mut_residue": mut_res,
                    "ddg_kcal_mol": round(ddg_val, 4),
                })

            print(f"  Position {pos_num} ({wt_res}) complete [{done}/{total}]")

        return pl.DataFrame(rows)

    # ------------------------------------------------------------------
    # Per-residue energy (SequenceDetail equivalent)
    # ------------------------------------------------------------------

    def per_residue_energy(
        self,
        source,
        max_iterations: int = 2000,
    ) -> pl.DataFrame:
        """Approximate per-residue energy contribution.

        Computes per-residue energy decomposition.

        Uses an alanine-subtraction approach: for each residue, the ddG of
        the Ala-mutant (which already includes the WT reference energy)
        estimates that residue's energetic contribution with its sign
        flipped (positive = stabilising, negative = destabilising).

        Args:
            source: PDB file path, PDB string, or Polars DataFrame.
            max_iterations: Minimisation steps.

        Returns:
            Polars DataFrame with columns:
            [chain, residue_number, residue_name, energy_contribution_kcal_mol].
        """
        # NOTE: the previous version imported PDBFixer and ran a full
        # calculate_stability() baseline whose result was never used; both
        # have been removed (mutate() already computes the WT reference
        # internally, so the extra minimisation was pure wasted work).
        pdb_text = _load_pdb(source)
        fixer = _pdb_string_to_fixer(pdb_text)

        # Every standard residue except Ala/Gly (which have no sidechain to
        # subtract) is a candidate.
        all_positions = []
        for chain_obj in fixer.topology.chains():
            for res in chain_obj.residues():
                if res.name in STANDARD_RESIDUES and res.name not in ("ALA", "GLY"):
                    all_positions.append((chain_obj.id, int(res.id), res.name))

        if not all_positions:
            return pl.DataFrame(schema={
                "chain": pl.Utf8, "residue_number": pl.Int64,
                "residue_name": pl.Utf8, "energy_contribution_kcal_mol": pl.Float64,
            })

        print(f"Per-residue energy: {len(all_positions)} residues via Ala-subtraction")

        rows = []
        for chain_id, pos_num, res_name in all_positions:
            mut = Mutation(chain=chain_id, position=pos_num,
                          wt_residue=res_name, mut_residue="ALA")
            try:
                result = self.mutate(
                    pdb_text, [mut],
                    max_iterations=max_iterations,
                    constrain_backbone=True,
                )
                ddg_val = list(result.ddg.values())[0]
                # Destabilising mutation (ddG > 0) means the WT residue was
                # contributing favourably, hence the sign flip.
                contribution = -ddg_val
            except Exception:
                contribution = float("nan")

            rows.append({
                "chain": chain_id,
                "residue_number": pos_num,
                "residue_name": res_name,
                "energy_contribution_kcal_mol": round(contribution, 4),
            })

        # Ala and Gly get 0.0 (self-reference)
        for chain_obj in fixer.topology.chains():
            for res in chain_obj.residues():
                if res.name in ("ALA", "GLY"):
                    rows.append({
                        "chain": chain_obj.id,
                        "residue_number": int(res.id),
                        "residue_name": res.name,
                        "energy_contribution_kcal_mol": 0.0,
                    })

        df = pl.DataFrame(rows).sort(["chain", "residue_number"])
        return df

    # ------------------------------------------------------------------
    # CSV-based batch mutations
    # ------------------------------------------------------------------

    @staticmethod
    def load_mutations(csv_path: str) -> pl.DataFrame:
        """Load a mutation list from a CSV file.

        The CSV must contain a ``mutation`` column with strings like ``G13L``.
        Optional columns:

        - ``chain`` — chain identifier (defaults to ``'A'`` if absent).
        - Any other columns (e.g. ``score``, ``source``, ``notes``) are
          preserved as metadata and carried through to the results.

        Args:
            csv_path: Path to a CSV file.

        Returns:
            Polars DataFrame with at least ``[mutation, chain]`` plus any
            extra columns from the CSV.
        """
        table = pl.read_csv(csv_path)

        if "mutation" not in table.columns:
            raise ValueError(
                f"CSV must contain a 'mutation' column. "
                f"Found columns: {table.columns}"
            )

        # Fill in the default chain only when the caller did not supply one.
        if "chain" in table.columns:
            return table
        return table.with_columns(pl.lit("A").alias("chain"))

    def mutate_batch(
        self,
        source,
        mutations_df: pl.DataFrame,
        max_iterations: int = 2000,
        n_runs: int = 3,
        constrain_backbone: bool = True,
        _repair_cache: Optional[_RepairCache] = None,
    ) -> pl.DataFrame:
        """Run every mutation in a DataFrame and return results.

        Each row is treated as an independent single-point mutation.
        Any extra columns in the input DataFrame are preserved in the output.

        The wild-type structure is prepared *once* and reused for every
        mutation, giving deterministic WT energies and eliminating
        hydrogen-placement noise.

        Args:
            source: PDB file path, PDB string, or Polars DataFrame.
            mutations_df: DataFrame with ``mutation`` and ``chain`` columns
                          (as returned by :meth:`load_mutations`).
            max_iterations: Minimisation steps per mutation (default 2000).
            n_runs: Independent minimisation runs per mutation (default 3).
            constrain_backbone: Restrain backbone during minimisation.
            _repair_cache: Optional pre-processed WT from :meth:`prepare`.
                           If not provided, one is created automatically.

        Returns:
            Polars DataFrame with the input columns plus
            ``[wt_energy, mutant_energy, ddg_kcal_mol]``.
        """
        if "mutation" not in mutations_df.columns:
            raise ValueError("DataFrame must contain a 'mutation' column.")
        if "chain" not in mutations_df.columns:
            mutations_df = mutations_df.with_columns(pl.lit("A").alias("chain"))

        if _repair_cache is None:
            print("Preparing wild-type structure (repair-once)...")
            _repair_cache = self._prepare_structure(
                source, max_iterations=max_iterations)

        result_rows = []
        total = mutations_df.height
        for i, row in enumerate(mutations_df.iter_rows(named=True)):
            mut_str = row["mutation"]
            chain_id = row["chain"]

            try:
                mut = Mutation.from_str(mut_str, chain=chain_id)
                result = self.mutate(
                    source, [mut],
                    chain=chain_id,
                    n_runs=n_runs,
                    max_iterations=max_iterations,
                    constrain_backbone=constrain_backbone,
                    _repair_cache=_repair_cache,
                )
                ddg_val = list(result.ddg.values())[0]
                wt_e = result.wt_energy
                mut_e = list(result.mutant_energies.values())[0]
            except Exception as e:
                # Best-effort batch: a failed row records NaN and processing
                # continues with the next mutation.
                print(f"  [{i+1}/{total}] {mut_str} chain {chain_id}: FAILED ({e})")
                ddg_val = float("nan")
                wt_e = float("nan")
                mut_e = float("nan")
            else:
                # Only print the success line on success — previously this ran
                # unconditionally and emitted 'ddG = +nan' after FAILED rows.
                print(f"  [{i+1}/{total}] {mut_str} chain {chain_id}: ddG = {ddg_val:+.2f} kcal/mol")

            out = {**row, "wt_energy": wt_e, "mutant_energy": mut_e, "ddg_kcal_mol": ddg_val}
            result_rows.append(out)

        return pl.DataFrame(result_rows)

    # ------------------------------------------------------------------
    # Disulfide Bond Analysis
    # ------------------------------------------------------------------

    def detect_disulfides(
        self,
        source,
        distance_cutoff: float = 2.5,
    ) -> pl.DataFrame:
        """Detect disulfide bonds in a structure.

        Args:
            source: PDB file path, PDB string, or Polars DataFrame.
            distance_cutoff: Maximum S-S distance for disulfide bond (Å, default 2.5).

        Returns:
            DataFrame with columns:
            [chain1, residue1, resname1, chain2, residue2, resname2, distance].
        """
        # NOTE: the previous version imported PDBFixer here but never used it;
        # the docstring's mojibake 'Ã…' has also been fixed to 'Å'.
        pdb_text = _load_pdb(source)
        fixer = _pdb_string_to_fixer(pdb_text)

        # Delegate the geometric S-S search to the module-level helper.
        return _detect_disulfide_bonds(fixer.topology, fixer.positions, distance_cutoff)

    def analyze_mutation_disulfide_impact(
        self,
        source,
        mutations: List[Union[Mutation, str]],
        chain: str = "A",
        distance_cutoff: float = 2.5,
    ) -> Dict[str, object]:
        """Analyze how mutations affect disulfide bonds.

        Args:
            source: PDB file path, PDB string, or Polars DataFrame.
            mutations: List of Mutation objects or strings (e.g. ``'C42A'``).
            chain: Default chain ID (default ``'A'``).
            distance_cutoff: Maximum S-S distance for disulfide bond (Å).

        Returns:
            Dict with:
            - wt_disulfides: DataFrame of WT disulfide bonds
            - mutant_disulfides: DataFrame of mutant disulfide bonds
            - broken_bonds: List of broken disulfide bonds
            - new_bonds: List of new disulfide bonds formed
            - affected_cysteines: List of mutated cysteine positions
        """
        # Parse mutations (strings inherit the default `chain`).
        parsed = []
        for m in mutations:
            if isinstance(m, str):
                parsed.append(Mutation.from_str(m, chain=chain))
            else:
                parsed.append(m)

        # Detect WT disulfides
        wt_disulfides = self.detect_disulfides(source, distance_cutoff)

        # Check if any cysteines are being mutated
        affected_cysteines = []
        for m in parsed:
            if m.wt_residue == "CYS":
                affected_cysteines.append((m.chain, m.position))

        # Build mutant structure (simplified - just apply mutations)
        pdb_text = _load_pdb(source)
        fixer_mut = _pdb_string_to_fixer(pdb_text)
        fixer_mut.findMissingResidues()

        # BUGFIX: previously the *entire* mutation list was applied to every
        # chain in use, which is wrong (and fails) when mutations span more
        # than one chain.  Group mutations by chain and apply each chain's
        # own list only — the same pattern mutate_interface uses.
        muts_by_chain: Dict[str, List[str]] = {}
        for m in parsed:
            muts_by_chain.setdefault(m.chain, []).append(
                f"{m.wt_residue}-{m.position}-{m.mut_residue}"
            )
        for chain_id, chain_muts in muts_by_chain.items():
            fixer_mut.applyMutations(chain_muts, chain_id)

        fixer_mut.findMissingAtoms()
        fixer_mut.addMissingAtoms()

        # Detect mutant disulfides
        mutant_disulfides = _detect_disulfide_bonds(
            fixer_mut.topology, fixer_mut.positions, distance_cutoff
        )

        # Canonicalise each bond as a sorted (chain, residue) pair so that
        # A-B and B-A compare equal in the set arithmetic below.
        wt_bonds = set()
        if wt_disulfides.height > 0:
            for row in wt_disulfides.iter_rows(named=True):
                bond = tuple(sorted([
                    (row["chain1"], row["residue1"]),
                    (row["chain2"], row["residue2"])
                ]))
                wt_bonds.add(bond)

        mutant_bonds = set()
        if mutant_disulfides.height > 0:
            for row in mutant_disulfides.iter_rows(named=True):
                bond = tuple(sorted([
                    (row["chain1"], row["residue1"]),
                    (row["chain2"], row["residue2"])
                ]))
                mutant_bonds.add(bond)

        broken_bonds = list(wt_bonds - mutant_bonds)
        new_bonds = list(mutant_bonds - wt_bonds)

        return {
            "wt_disulfides": wt_disulfides,
            "mutant_disulfides": mutant_disulfides,
            "broken_bonds": broken_bonds,
            "new_bonds": new_bonds,
            "affected_cysteines": affected_cysteines,
        }

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _get_scannable_positions(
        self, topology, chain_id: str,
        positions: Optional[List[int]],
        skip_residues: Optional[set] = None,
    ) -> List[Tuple[int, str]]:
        """Collect (residue_number, residue_name) pairs eligible for scanning.

        A residue qualifies when it sits on ``chain_id``, is a standard amino
        acid, is not listed in ``skip_residues``, and — when ``positions`` is
        given — its number appears in ``positions``.
        """
        excluded = skip_residues or set()
        eligible: List[Tuple[int, str]] = []
        for candidate_chain in topology.chains():
            if candidate_chain.id != chain_id:
                continue
            for residue in candidate_chain.residues():
                number = int(residue.id)
                if residue.name not in STANDARD_RESIDUES:
                    continue
                if residue.name in excluded:
                    continue
                if positions is not None and number not in positions:
                    continue
                eligible.append((number, residue.name))
        return eligible

repair(source, pH=7.0, max_iterations=2000)

Repair a structure: fix missing atoms/residues, add hydrogens, minimise.

Repairs protein structure by fixing clashes and adding missing atoms.

Parameters:

Name Type Description Default
source

PDB file path, PDB string, or Polars DataFrame of atoms.

required
pH float

pH for protonation (default 7.0).

7.0
max_iterations int

Maximum minimisation steps.

2000

Returns:

Type Description
RepairResult

RepairResult with repaired PDB and energy change.

Source code in src/sicifus/mutate.py
def repair(
    self,
    source,
    pH: float = 7.0,
    max_iterations: int = 2000,
) -> RepairResult:
    """Repair a structure: fix missing atoms/residues, add hydrogens, minimise.

    Repairs protein structure by fixing clashes and adding missing atoms.

    Args:
        source: PDB file path, PDB string, or Polars DataFrame of atoms.
        pH: pH for protonation (default 7.0).
        max_iterations: Maximum minimisation steps.

    Returns:
        RepairResult with repaired PDB and energy change.
    """
    from pdbfixer import PDBFixer
    import openmm.unit as unit

    pdb_text = _load_pdb(source)
    fixer = _pdb_string_to_fixer(pdb_text)

    # Complete the structure before scoring: residues, heavy atoms, then
    # hydrogens placed for the requested pH.
    fixer.findMissingResidues()
    fixer.findMissingAtoms()
    fixer.addMissingAtoms()
    fixer.addMissingHydrogens(pH)

    topology = fixer.topology
    positions = fixer.positions

    # max_iterations=0 evaluates the energy of the repaired-but-unminimised
    # structure, giving the "before" reference.
    system_pre, top_pre, pos_pre = self._build_system(topology, positions)
    _, e_before = self._minimise(system_pre, top_pre, pos_pre, max_iterations=0)

    # A fresh system is built for the real minimisation pass.
    system, top, pos = self._build_system(topology, positions)
    pos_min, e_after = self._minimise(system, top, pos, max_iterations)

    pdb_out = _topology_to_pdb(top, pos_min)

    # Energies are converted from kJ/mol (OpenMM native) to kcal/mol.
    return RepairResult(
        topology=top,
        positions=pos_min,
        energy_before=round(e_before * self.KJ_TO_KCAL, 4),
        energy_after=round(e_after * self.KJ_TO_KCAL, 4),
        pdb_string=pdb_out,
    )

calculate_stability(source, max_iterations=2000)

Calculate total potential energy with per-term decomposition.

Calculates protein stability using energy minimization.

Parameters:

Name Type Description Default
source

PDB file path, PDB string, or Polars DataFrame.

required
max_iterations int

Minimisation steps before scoring.

2000

Returns:

Type Description
StabilityResult

StabilityResult with total energy and per-force-term breakdown (kcal/mol).

Source code in src/sicifus/mutate.py
def calculate_stability(
    self,
    source,
    max_iterations: int = 2000,
) -> StabilityResult:
    """Calculate total potential energy with per-term decomposition.

    Calculates protein stability using energy minimization.

    Args:
        source: PDB file path, PDB string, or Polars DataFrame.
        max_iterations: Minimisation steps before scoring.

    Returns:
        StabilityResult with total energy and per-force-term breakdown (kcal/mol).
    """
    fixer = _pdb_string_to_fixer(_load_pdb(source))

    # Complete the model: residues, heavy atoms, then hydrogens at pH 7.
    fixer.findMissingResidues()
    fixer.findMissingAtoms()
    fixer.addMissingAtoms()
    fixer.addMissingHydrogens(7.0)

    system, topology, start_positions = self._build_system(
        fixer.topology, fixer.positions)
    minimised_positions, _ = self._minimise(
        system, topology, start_positions, max_iterations)

    # Per-force-term energies from the minimised geometry.
    energy_terms = self._decompose_energy(system, topology, minimised_positions)

    return StabilityResult(
        total_energy=energy_terms.get("total", 0.0),
        energy_terms=energy_terms,
        pdb_string=_topology_to_pdb(topology, minimised_positions),
    )

mutate(source, mutations, chain='A', n_runs=3, max_iterations=2000, constrain_backbone=True, keep_statistics=True, use_mean=False, _repair_cache=None)

Apply one or more point mutations, minimise, and compute ddG.

Mutations can be Mutation objects or short strings like 'G13L'. Multiple mutations in the same call are applied simultaneously.

Parameters:

Name Type Description Default
source

PDB file path, PDB string, or Polars DataFrame. Ignored when _repair_cache is provided.

required
mutations List[Union[Mutation, str]]

List of Mutation objects or strings (e.g. 'G13L').

required
chain str

Default chain ID applied when parsing mutation strings (default 'A').

'A'
n_runs int

Number of independent minimisation runs for the mutant (default 3).

3
max_iterations int

Minimisation steps per run (default 2000).

2000
constrain_backbone bool

If True, restrain backbone atoms during mutant minimisation, allowing only sidechain flexibility.

True
keep_statistics bool

If True, collect and return statistical summary (mean, SD, CI) from all runs (default True).

True
use_mean bool

If True, use mean energy for ddG calculation (industry-standard). If False, use best (minimum) energy (default False).

False
_repair_cache Optional[_RepairCache]

Pre-processed WT from :meth:prepare. When provided the WT energy comes from the cache, eliminating redundant hydrogen placement and minimisation.

None

Returns:

Type Description
MutationResult

MutationResult with wild-type energy, mutant energy, ddG,

MutationResult

mutant PDB strings, and a full energy-term DataFrame. If

MutationResult

keep_statistics=True, also includes mean, SD, CI, and convergence

MutationResult

metrics.

Source code in src/sicifus/mutate.py
def mutate(
    self,
    source,
    mutations: List[Union[Mutation, str]],
    chain: str = "A",
    n_runs: int = 3,
    max_iterations: int = 2000,
    constrain_backbone: bool = True,
    keep_statistics: bool = True,
    use_mean: bool = False,
    _repair_cache: Optional[_RepairCache] = None,
) -> MutationResult:
    """Apply one or more point mutations, minimise, and compute ddG.

    Mutations can be ``Mutation`` objects or short strings like ``'G13L'``.
    Multiple mutations in the same call are applied simultaneously.

    Args:
        source: PDB file path, PDB string, or Polars DataFrame.
                Ignored when ``_repair_cache`` is provided.
        mutations: List of Mutation objects or strings (e.g. ``'G13L'``).
        chain: Default chain ID applied when parsing mutation strings
               (default ``'A'``).
        n_runs: Number of independent minimisation runs for the mutant
                (default 3).
        max_iterations: Minimisation steps per run (default 2000).
        constrain_backbone: If True, restrain backbone atoms during mutant
                            minimisation, allowing only sidechain flexibility.
        keep_statistics: If True, collect and return statistical summary
                        (mean, SD, CI) from all runs (default True).
        use_mean: If True, use mean energy for ddG calculation (industry-standard).
                 If False, use best (minimum) energy (default False).
                 The mean is only available when ``keep_statistics`` is True;
                 otherwise the best energy is used regardless.
        _repair_cache: Pre-processed WT from :meth:`prepare`.  When
                       provided the WT energy comes from the cache,
                       eliminating redundant hydrogen placement and
                       minimisation.

    Returns:
        MutationResult with wild-type energy, mutant energy, ddG,
        mutant PDB strings, and a full energy-term DataFrame. If
        keep_statistics=True, also includes mean, SD, CI, and convergence
        metrics.
    """
    # (Removed unused local imports: `PDBFixer` and `PDBFile` were never
    # referenced in this method.)
    parsed: List[Mutation] = []
    for m in mutations:
        if isinstance(m, str):
            parsed.append(Mutation.from_str(m, chain=chain))
        else:
            parsed.append(m)

    # --- Wild-type: use cache or build from scratch ---
    if _repair_cache is not None:
        pdb_text = _repair_cache.pdb_string
        e_wt = _repair_cache.energy_kcal
        pos_wt_min = _repair_cache.positions
        top_wt = _repair_cache.topology
        sys_wt = _repair_cache.system
    else:
        pdb_text = _load_pdb(source)
        fixer_wt = _pdb_string_to_fixer(pdb_text)
        fixer_wt.findMissingResidues()
        fixer_wt.findMissingAtoms()
        fixer_wt.addMissingAtoms()
        fixer_wt.addMissingHydrogens(7.0)

        sys_wt, top_wt, pos_wt = self._build_system(
            fixer_wt.topology, fixer_wt.positions)
        pos_wt_min, e_wt_kj = self._minimise(
            sys_wt, top_wt, pos_wt, max_iterations)
        e_wt = round(e_wt_kj * self.KJ_TO_KCAL, 4)
        pdb_text = _topology_to_pdb(top_wt, pos_wt_min)

    # --- Build mutant from the (repaired) WT PDB ---
    fixer_mut = _pdb_string_to_fixer(pdb_text)
    fixer_mut.findMissingResidues()

    # BUG FIX: applyMutations applies the given spec list to ONE chain.
    # The previous code applied the FULL mutation list to EVERY chain that
    # appeared in `parsed`, which is wrong whenever mutations span more
    # than one chain.  Group the specs per chain instead.
    specs_by_chain: dict = {}
    for m in parsed:
        specs_by_chain.setdefault(m.chain, []).append(
            f"{m.wt_residue}-{m.position}-{m.mut_residue}")
    for chain_id, chain_specs in specs_by_chain.items():
        fixer_mut.applyMutations(chain_specs, chain_id)

    fixer_mut.findMissingAtoms()
    fixer_mut.addMissingAtoms()
    fixer_mut.addMissingHydrogens(7.0)

    mut_label = "+".join(m.label for m in parsed)

    best_e_mut = None
    best_pos_mut = None
    best_top_mut = None
    best_sys_mut = None
    all_run_energies_list = []

    for run in range(n_runs):
        sys_m, top_m, pos_m = self._build_system(
            fixer_mut.topology, fixer_mut.positions,
            constrain_backbone=constrain_backbone,
            # Hydrogens already placed consistently when a cache was used.
            skip_hydrogens=(_repair_cache is not None),
        )
        pos_m_min, e_m_kj = self._minimise(
            sys_m, top_m, pos_m, max_iterations)
        e_m = e_m_kj * self.KJ_TO_KCAL

        # Store all run energies
        all_run_energies_list.append(e_m)

        if best_e_mut is None or e_m < best_e_mut:
            best_e_mut = e_m
            best_pos_mut = pos_m_min
            best_top_mut = top_m
            best_sys_mut = sys_m

    # Compute primary ddG (based on use_mean flag)
    if use_mean and keep_statistics:
        primary_energy = float(np.mean(all_run_energies_list))
        ddg = primary_energy - e_wt
    else:
        primary_energy = best_e_mut
        ddg = best_e_mut - e_wt

    # Decompose using the exact system/topology that produced the
    # minimised positions (no redundant _build_system call).
    mut_terms = self._decompose_energy(best_sys_mut, best_top_mut, best_pos_mut)
    wt_terms = self._decompose_energy(sys_wt, top_wt, pos_wt_min)

    term_rows = []
    for key in sorted(set(wt_terms.keys()) | set(mut_terms.keys())):
        wt_val = wt_terms.get(key, 0.0)
        mt_val = mut_terms.get(key, 0.0)
        term_rows.append({
            "term": key,
            "wt_energy": wt_val,
            "mutant_energy": mt_val,
            "delta": round(mt_val - wt_val, 4),
        })
    terms_df = pl.DataFrame(term_rows)

    pdb_mut = _topology_to_pdb(best_top_mut, best_pos_mut)

    # Compute statistics if requested
    stats_dict = None
    if keep_statistics and n_runs > 1:
        stats_dict = _compute_energy_statistics(all_run_energies_list, e_wt, n_runs)

        # Warn if convergence is poor
        if stats_dict["cv"] > 0.1:
            print(f"Warning: {mut_label} has high energy variability (CV={stats_dict['cv']:.2f}). "
                  f"Consider increasing n_runs for more reliable results.")

    return MutationResult(
        wt_energy=e_wt,
        mutant_energies={mut_label: round(primary_energy, 4)},
        ddg={mut_label: round(ddg, 4)},
        mutant_pdbs={mut_label: pdb_mut},
        energy_terms=terms_df,
        # Statistical fields
        all_run_energies={mut_label: all_run_energies_list} if keep_statistics else None,
        mean_energy={mut_label: stats_dict["mean"]} if stats_dict else None,
        sd_energy={mut_label: stats_dict["sd"]} if stats_dict else None,
        min_energy={mut_label: stats_dict["min"]} if stats_dict else None,
        max_energy={mut_label: stats_dict["max"]} if stats_dict else None,
        ddg_mean={mut_label: stats_dict["ddg_mean"]} if stats_dict else None,
        ddg_sd={mut_label: stats_dict["ddg_sd"]} if stats_dict else None,
        ddg_ci_95={mut_label: stats_dict["ci_95"]} if stats_dict else None,
        convergence_metric={mut_label: stats_dict["cv"]} if stats_dict else None,
    )

load_mutations(csv_path) staticmethod

Load a mutation list from a CSV file.

The CSV must contain a mutation column with strings like G13L. Optional columns:

  • chain — chain identifier (defaults to 'A' if absent).
  • Any other columns (e.g. score, source, notes) are preserved as metadata and carried through to the results.

Parameters:

Name Type Description Default
csv_path str

Path to a CSV file.

required

Returns:

Type Description
DataFrame

Polars DataFrame with at least [mutation, chain] plus any

DataFrame

extra columns from the CSV.

Source code in src/sicifus/mutate.py
@staticmethod
def load_mutations(csv_path: str) -> pl.DataFrame:
    """Load a mutation list from a CSV file.

    The CSV must contain a ``mutation`` column with strings like ``G13L``.
    Optional columns:

    - ``chain`` — chain identifier (defaults to ``'A'`` if absent).
    - Any other columns (e.g. ``score``, ``source``, ``notes``) are
      preserved as metadata and carried through to the results.

    Args:
        csv_path: Path to a CSV file.

    Returns:
        Polars DataFrame with at least ``[mutation, chain]`` plus any
        extra columns from the CSV.
    """
    table = pl.read_csv(csv_path)
    present = set(table.columns)

    if "mutation" not in present:
        raise ValueError(
            f"CSV must contain a 'mutation' column. "
            f"Found columns: {table.columns}"
        )

    # Default every row to chain 'A' when no chain column was supplied.
    if "chain" in present:
        return table
    return table.with_columns(pl.lit("A").alias("chain"))

mutate_batch(source, mutations_df, max_iterations=2000, n_runs=3, constrain_backbone=True, _repair_cache=None)

Run every mutation in a DataFrame and return results.

Each row is treated as an independent single-point mutation. Any extra columns in the input DataFrame are preserved in the output.

The wild-type structure is prepared once and reused for every mutation, giving deterministic WT energies and eliminating hydrogen-placement noise.

Parameters:

Name Type Description Default
source

PDB file path, PDB string, or Polars DataFrame.

required
mutations_df DataFrame

DataFrame with mutation and chain columns (as returned by :meth:load_mutations).

required
max_iterations int

Minimisation steps per mutation (default 2000).

2000
n_runs int

Independent minimisation runs per mutation (default 3).

3
constrain_backbone bool

Restrain backbone during minimisation.

True
_repair_cache Optional[_RepairCache]

Optional pre-processed WT from :meth:prepare. If not provided, one is created automatically.

None

Returns:

Type Description
DataFrame

Polars DataFrame with the input columns plus

DataFrame

[wt_energy, mutant_energy, ddg_kcal_mol].

Source code in src/sicifus/mutate.py
def mutate_batch(
    self,
    source,
    mutations_df: pl.DataFrame,
    max_iterations: int = 2000,
    n_runs: int = 3,
    constrain_backbone: bool = True,
    _repair_cache: Optional[_RepairCache] = None,
) -> pl.DataFrame:
    """Run every mutation in a DataFrame and return results.

    Each row is treated as an independent single-point mutation.
    Any extra columns in the input DataFrame are preserved in the output.

    The wild-type structure is prepared *once* and reused for every
    mutation, giving deterministic WT energies and eliminating
    hydrogen-placement noise.

    Args:
        source: PDB file path, PDB string, or Polars DataFrame.
        mutations_df: DataFrame with ``mutation`` and ``chain`` columns
                      (as returned by :meth:`load_mutations`).
        max_iterations: Minimisation steps per mutation (default 2000).
        n_runs: Independent minimisation runs per mutation (default 3).
        constrain_backbone: Restrain backbone during minimisation.
        _repair_cache: Optional pre-processed WT from :meth:`prepare`.
                       If not provided, one is created automatically.

    Returns:
        Polars DataFrame with the input columns plus
        ``[wt_energy, mutant_energy, ddg_kcal_mol]``.
    """
    if "mutation" not in mutations_df.columns:
        raise ValueError("DataFrame must contain a 'mutation' column.")
    if "chain" not in mutations_df.columns:
        mutations_df = mutations_df.with_columns(pl.lit("A").alias("chain"))

    if _repair_cache is None:
        print("Preparing wild-type structure (repair-once)...")
        _repair_cache = self._prepare_structure(
            source, max_iterations=max_iterations)

    result_rows = []
    total = mutations_df.height
    for i, row in enumerate(mutations_df.iter_rows(named=True)):
        mut_str = row["mutation"]
        chain_id = row["chain"]

        try:
            mut = Mutation.from_str(mut_str, chain=chain_id)
            result = self.mutate(
                source, [mut],
                chain=chain_id,
                n_runs=n_runs,
                max_iterations=max_iterations,
                constrain_backbone=constrain_backbone,
                _repair_cache=_repair_cache,
            )
            ddg_val = list(result.ddg.values())[0]
            wt_e = result.wt_energy
            mut_e = list(result.mutant_energies.values())[0]
        except Exception as e:
            print(f"  [{i+1}/{total}] {mut_str} chain {chain_id}: FAILED ({e})")
            ddg_val = float("nan")
            wt_e = float("nan")
            mut_e = float("nan")
        else:
            # BUG FIX: the success line used to print unconditionally, so a
            # failed row reported both "FAILED" and a bogus "ddG = +nan".
            print(f"  [{i+1}/{total}] {mut_str} chain {chain_id}: ddG = {ddg_val:+.2f} kcal/mol")

        out = {**row, "wt_energy": wt_e, "mutant_energy": mut_e, "ddg_kcal_mol": ddg_val}
        result_rows.append(out)

    return pl.DataFrame(result_rows)

calculate_binding_energy(source, chains_a, chains_b, max_iterations=2000)

Calculate binding energy between two groups of chains.

Calculates binding energy for protein-protein complexes.

E_binding = E_complex - (E_chains_a + E_chains_b)

Parameters:

Name Type Description Default
source

PDB file path, PDB string, or Polars DataFrame.

required
chains_a List[str]

Chain IDs for the first group (e.g. ['A']).

required
chains_b List[str]

Chain IDs for the second group (e.g. ['B']).

required
max_iterations int

Minimisation steps.

2000

Returns:

Type Description
BindingResult

BindingResult with binding energy, component energies,

BindingResult

and interface residues.

Source code in src/sicifus/mutate.py
def calculate_binding_energy(
    self,
    source,
    chains_a: List[str],
    chains_b: List[str],
    max_iterations: int = 2000,
) -> BindingResult:
    """Calculate binding energy between two groups of chains.

    Calculates binding energy for protein-protein complexes.

    E_binding = E_complex - (E_chains_a + E_chains_b)

    Args:
        source: PDB file path, PDB string, or Polars DataFrame.
        chains_a: Chain IDs for the first group (e.g. ['A']).
        chains_b: Chain IDs for the second group (e.g. ['B']).
        max_iterations: Minimisation steps.

    Returns:
        BindingResult with binding energy, component energies,
        and interface residues.
    """
    # (Removed unused local import: `PDBFixer` was never referenced here.)
    pdb_text = _load_pdb(source)
    fixer = _pdb_string_to_fixer(pdb_text)
    fixer.findMissingResidues()
    fixer.findMissingAtoms()
    fixer.addMissingAtoms()
    fixer.addMissingHydrogens(7.0)

    # --- Complex energy ---
    sys_c, top_c, pos_c = self._build_system(fixer.topology, fixer.positions)
    pos_c_min, e_complex_kj = self._minimise(sys_c, top_c, pos_c, max_iterations)
    e_complex = e_complex_kj * self.KJ_TO_KCAL

    # Use minimised complex positions to extract chain subsets
    top_a, pos_a = _extract_chains(top_c, pos_c_min, chains_a)
    top_b, pos_b = _extract_chains(top_c, pos_c_min, chains_b)

    # Each isolated group is minimised separately before scoring.
    sys_a, top_a2, pos_a2 = self._build_system(top_a, pos_a)
    _, e_a_kj = self._minimise(sys_a, top_a2, pos_a2, max_iterations)
    e_a = e_a_kj * self.KJ_TO_KCAL

    sys_b, top_b2, pos_b2 = self._build_system(top_b, pos_b)
    _, e_b_kj = self._minimise(sys_b, top_b2, pos_b2, max_iterations)
    e_b = e_b_kj * self.KJ_TO_KCAL

    e_binding = e_complex - (e_a + e_b)

    # Interface residues from the minimised complex, 0.5 nm contact cutoff.
    interface_df = _find_interface_residues(
        top_c, pos_c_min, chains_a, chains_b, cutoff_nm=0.5
    )

    return BindingResult(
        binding_energy=round(e_binding, 4),
        complex_energy=round(e_complex, 4),
        chain_a_energy=round(e_a, 4),
        chain_b_energy=round(e_b, 4),
        interface_residues=interface_df,
    )

alanine_scan(source, chain, positions=None, max_iterations=2000, constrain_backbone=True)

Perform alanine scanning on a chain.

Performs systematic alanine scanning mutagenesis. Each non-Ala/Gly position is mutated to alanine and the ddG is reported.

Parameters:

Name Type Description Default
source

PDB file path, PDB string, or Polars DataFrame.

required
chain str

Chain ID to scan (e.g. 'A').

required
positions Optional[List[int]]

Specific residue numbers to scan. If None, scans all non-Ala/Gly standard residues.

None
max_iterations int

Minimisation steps per mutant.

2000
constrain_backbone bool

Freeze backbone atoms during minimisation.

True

Returns:

Type Description
DataFrame

Polars DataFrame with columns:

DataFrame

[chain, position, wt_residue, ddg_kcal_mol].

Source code in src/sicifus/mutate.py
def alanine_scan(
    self,
    source,
    chain: str,
    positions: Optional[List[int]] = None,
    max_iterations: int = 2000,
    constrain_backbone: bool = True,
) -> pl.DataFrame:
    """Perform alanine scanning on a chain.

    Performs systematic alanine scanning mutagenesis.  Each non-Ala/Gly position is mutated
    to alanine and the ddG is reported.

    Args:
        source: PDB file path, PDB string, or Polars DataFrame.
        chain: Chain ID to scan (e.g. 'A').
        positions: Specific residue numbers to scan.  If None, scans all
                   non-Ala/Gly standard residues.
        max_iterations: Minimisation steps per mutant.
        constrain_backbone: Freeze backbone atoms during minimisation.

    Returns:
        Polars DataFrame with columns:
        [chain, position, wt_residue, ddg_kcal_mol].
    """
    # (Removed unused local import: `PDBFixer` was never referenced here.)
    pdb_text = _load_pdb(source)
    fixer = _pdb_string_to_fixer(pdb_text)

    scan_positions = self._get_scannable_positions(
        fixer.topology, chain, positions, skip_residues={"ALA", "GLY"}
    )

    if not scan_positions:
        print(f"No scannable positions found on chain {chain}.")
        return pl.DataFrame(schema={
            "chain": pl.Utf8, "position": pl.Int64,
            "wt_residue": pl.Utf8, "ddg_kcal_mol": pl.Float64,
        })

    print(f"Alanine scan: {len(scan_positions)} positions on chain {chain}")

    rows = []
    for pos_num, wt_res in scan_positions:
        mut = Mutation(chain=chain, position=pos_num,
                      wt_residue=wt_res, mut_residue="ALA")
        try:
            result = self.mutate(
                pdb_text, [mut],
                max_iterations=max_iterations,
                constrain_backbone=constrain_backbone,
            )
            ddg_val = list(result.ddg.values())[0]
        except Exception as e:
            print(f"  {mut.label}: FAILED ({e})")
            ddg_val = float("nan")
        else:
            # BUG FIX: the success line used to print unconditionally, so a
            # failed position reported both "FAILED" and "ddG = +nan".
            print(f"  {mut.label}: ddG = {ddg_val:+.2f} kcal/mol")

        rows.append({
            "chain": chain,
            "position": pos_num,
            "wt_residue": wt_res,
            "ddg_kcal_mol": round(ddg_val, 4),
        })

    return pl.DataFrame(rows)

position_scan(source, chain, positions, max_iterations=2000, constrain_backbone=True)

Scan all 20 amino acids at specified positions.

Generates position-specific scoring matrix by scanning all amino acids.

Parameters:

Name Type Description Default
source

PDB file path, PDB string, or Polars DataFrame.

required
chain str

Chain ID.

required
positions List[int]

List of residue numbers to scan.

required
max_iterations int

Minimisation steps per mutant.

2000
constrain_backbone bool

Freeze backbone atoms during minimisation.

True

Returns:

Type Description
DataFrame

Polars DataFrame with columns:

DataFrame

[chain, position, wt_residue, mut_residue, ddg_kcal_mol].

Source code in src/sicifus/mutate.py
def position_scan(
    self,
    source,
    chain: str,
    positions: List[int],
    max_iterations: int = 2000,
    constrain_backbone: bool = True,
) -> pl.DataFrame:
    """Scan all 20 amino acids at specified positions.

    Generates position-specific scoring matrix by scanning all amino acids.

    Args:
        source: PDB file path, PDB string, or Polars DataFrame.
        chain: Chain ID.
        positions: List of residue numbers to scan.
        max_iterations: Minimisation steps per mutant.
        constrain_backbone: Freeze backbone atoms during minimisation.

    Returns:
        Polars DataFrame with columns:
        [chain, position, wt_residue, mut_residue, ddg_kcal_mol].
    """
    # (Removed unused local import: `PDBFixer` was never referenced here.)
    pdb_text = _load_pdb(source)
    fixer = _pdb_string_to_fixer(pdb_text)

    # Map residue number -> residue name for the requested chain.
    resmap = {}
    for chain_obj in fixer.topology.chains():
        if chain_obj.id == chain:
            for res in chain_obj.residues():
                resmap[int(res.id)] = res.name

    rows = []
    total = len(positions) * 20
    done = 0
    for pos_num in positions:
        wt_res = resmap.get(pos_num)
        if wt_res is None or wt_res not in STANDARD_RESIDUES:
            continue

        for mut_res in ALL_AMINO_ACIDS:
            done += 1
            # Self-mutation is the reference point: ddG is 0 by definition.
            if mut_res == wt_res:
                rows.append({
                    "chain": chain, "position": pos_num,
                    "wt_residue": wt_res, "mut_residue": mut_res,
                    "ddg_kcal_mol": 0.0,
                })
                continue

            mut = Mutation(chain=chain, position=pos_num,
                          wt_residue=wt_res, mut_residue=mut_res)
            try:
                result = self.mutate(
                    pdb_text, [mut],
                    max_iterations=max_iterations,
                    constrain_backbone=constrain_backbone,
                )
                ddg_val = list(result.ddg.values())[0]
            except Exception as e:
                # FIX: failures were silently swallowed (bound `e` unused);
                # report them like alanine_scan does so NaNs are traceable.
                print(f"  {mut.label}: FAILED ({e})")
                ddg_val = float("nan")

            rows.append({
                "chain": chain, "position": pos_num,
                "wt_residue": wt_res, "mut_residue": mut_res,
                "ddg_kcal_mol": round(ddg_val, 4),
            })

        print(f"  Position {pos_num} ({wt_res}) complete [{done}/{total}]")

    return pl.DataFrame(rows)

per_residue_energy(source, max_iterations=2000)

Approximate per-residue energy contribution.

Computes per-residue energy decomposition.

Uses an alanine-subtraction approach: for each residue, the energy difference between the full structure and the Ala-mutant estimates that residue's energetic contribution (positive = stabilising, negative = destabilising).

Parameters:

Name Type Description Default
source

PDB file path, PDB string, or Polars DataFrame.

required
max_iterations int

Minimisation steps.

2000

Returns:

Type Description
DataFrame

Polars DataFrame with columns:

DataFrame

[chain, residue_number, residue_name, energy_contribution_kcal_mol].

Source code in src/sicifus/mutate.py
def per_residue_energy(
    self,
    source,
    max_iterations: int = 2000,
) -> pl.DataFrame:
    """Approximate per-residue energy contribution.

    Computes per-residue energy decomposition.

    Uses an alanine-subtraction approach: for each residue, the energy
    difference between the full structure and the Ala-mutant estimates
    that residue's energetic contribution (positive = stabilising,
    negative = destabilising).

    Args:
        source: PDB file path, PDB string, or Polars DataFrame.
        max_iterations: Minimisation steps.

    Returns:
        Polars DataFrame with columns:
        [chain, residue_number, residue_name, energy_contribution_kcal_mol].
    """
    # (Removed unused local import: `PDBFixer` was never referenced here.)
    pdb_text = _load_pdb(source)
    fixer = _pdb_string_to_fixer(pdb_text)

    # Every standard residue except Ala/Gly is a candidate for
    # Ala-subtraction (Ala is the reference, Gly has no sidechain to cut).
    all_positions = []
    for chain_obj in fixer.topology.chains():
        for res in chain_obj.residues():
            if res.name in STANDARD_RESIDUES and res.name not in ("ALA", "GLY"):
                all_positions.append((chain_obj.id, int(res.id), res.name))

    if not all_positions:
        return pl.DataFrame(schema={
            "chain": pl.Utf8, "residue_number": pl.Int64,
            "residue_name": pl.Utf8, "energy_contribution_kcal_mol": pl.Float64,
        })

    print(f"Per-residue energy: {len(all_positions)} residues via Ala-subtraction")

    # Baseline WT energy
    stab = self.calculate_stability(pdb_text, max_iterations=max_iterations)
    e_wt = stab.total_energy

    rows = []
    for chain_id, pos_num, res_name in all_positions:
        mut = Mutation(chain=chain_id, position=pos_num,
                      wt_residue=res_name, mut_residue="ALA")
        try:
            result = self.mutate(
                pdb_text, [mut],
                max_iterations=max_iterations,
                constrain_backbone=True,
            )
            ddg_val = list(result.ddg.values())[0]
            # Sign flip: a destabilising Ala mutation (positive ddG) means
            # the original residue was contributing stability.
            contribution = -ddg_val
        except Exception as e:
            # FIX: failures were silently swallowed; report them so NaN
            # contributions are traceable in the log.
            print(f"  {mut.label}: FAILED ({e})")
            contribution = float("nan")

        rows.append({
            "chain": chain_id,
            "residue_number": pos_num,
            "residue_name": res_name,
            "energy_contribution_kcal_mol": round(contribution, 4),
        })

    # Ala and Gly get 0.0 (self-reference)
    for chain_obj in fixer.topology.chains():
        for res in chain_obj.residues():
            if res.name in ("ALA", "GLY"):
                rows.append({
                    "chain": chain_obj.id,
                    "residue_number": int(res.id),
                    "residue_name": res.name,
                    "energy_contribution_kcal_mol": 0.0,
                })

    df = pl.DataFrame(rows).sort(["chain", "residue_number"])
    return df

Mutation

sicifus.Mutation dataclass

Describes a single point mutation.

Parameters:

Name Type Description Default
position int

Residue number in the structure.

required
wt_residue str

Wild-type residue (1-letter or 3-letter code).

required
mut_residue str

Mutant residue (1-letter or 3-letter code).

required
chain str

Chain identifier (default "A").

'A'
Source code in src/sicifus/mutate.py
@dataclass
class Mutation:
    """Describes a single point mutation.

    Args:
        position: Residue number in the structure.
        wt_residue: Wild-type residue (1-letter or 3-letter code).
        mut_residue: Mutant residue (1-letter or 3-letter code).
        chain: Chain identifier (default ``"A"``).
    """
    position: int
    wt_residue: str
    mut_residue: str
    chain: str = "A"

    def __post_init__(self):
        # Canonicalise both residue codes to 3-letter form on construction.
        self.wt_residue = self._normalise(self.wt_residue)
        self.mut_residue = self._normalise(self.mut_residue)

    @staticmethod
    def _normalise(code: str) -> str:
        """Return the canonical 3-letter code for a 1- or 3-letter residue."""
        token = code.strip().upper()
        if len(token) == 1 and token in ONE_TO_THREE:
            return ONE_TO_THREE[token]
        if len(token) == 3 and token in STANDARD_RESIDUES:
            return token
        raise ValueError(f"Unknown residue code: {token!r}")

    @property
    def label(self) -> str:
        """Short label like ``G13L``."""
        return (THREE_TO_ONE[self.wt_residue]
                + str(self.position)
                + THREE_TO_ONE[self.mut_residue])

    @classmethod
    def from_str(cls, notation: str, chain: str = "A") -> "Mutation":
        """Parse a mutation string like ``'G13L'`` (Gly at position 13 to Leu).

        Args:
            notation: String in the format ``WtPositionMut``
                      (e.g. ``'G13L'``, ``'F42W'``, ``'A100V'``).
            chain: Chain identifier (default ``'A'``).
        """
        cleaned = notation.strip().upper()
        match = re.match(r"^([A-Z])(\d+)([A-Z])$", cleaned)
        if match is None:
            raise ValueError(
                f"Invalid mutation string: {notation!r}. "
                "Expected format: WtPositionMut (e.g. G13L)"
            )
        wt_one, pos, mut_one = match.groups()
        return cls(
            position=int(pos),
            wt_residue=ONE_TO_THREE[wt_one],
            mut_residue=ONE_TO_THREE[mut_one],
            chain=chain.upper(),
        )

    def __repr__(self):
        suffix = f", chain={self.chain}" if self.chain != "A" else ""
        return f"Mutation({self.label}{suffix})"

label property

Short label like G13L.

from_str(notation, chain='A') classmethod

Parse a mutation string like 'G13L' (Gly at position 13 to Leu).

Parameters:

Name Type Description Default
notation str

String in the format WtPositionMut (e.g. 'G13L', 'F42W', 'A100V').

required
chain str

Chain identifier (default 'A').

'A'
Source code in src/sicifus/mutate.py
@classmethod
def from_str(cls, notation: str, chain: str = "A") -> "Mutation":
    """Parse a mutation string like ``'G13L'`` (Gly at position 13 to Leu).

    Args:
        notation: String in the format ``WtPositionMut``
                  (e.g. ``'G13L'``, ``'F42W'``, ``'A100V'``).
        chain: Chain identifier (default ``'A'``).

    Returns:
        Mutation with both residue codes expanded via ONE_TO_THREE and the
        chain identifier upper-cased.

    Raises:
        ValueError: If ``notation`` does not match the WtPositionMut pattern.
    """
    # Case-insensitive parse: normalise before matching the 1-letter codes.
    m = re.match(r"^([A-Z])(\d+)([A-Z])$", notation.strip().upper())
    if not m:
        raise ValueError(
            f"Invalid mutation string: {notation!r}. "
            "Expected format: WtPositionMut (e.g. G13L)"
        )
    wt_one, pos, mut_one = m.groups()
    # NOTE(review): a 1-letter code that matches [A-Z] but is absent from
    # ONE_TO_THREE (e.g. 'B', 'J') would raise KeyError here rather than
    # ValueError — confirm whether that should be caught and re-raised.
    return cls(
        position=int(pos),
        wt_residue=ONE_TO_THREE[wt_one],
        mut_residue=ONE_TO_THREE[mut_one],
        chain=chain.upper(),
    )

Analysis Toolkit

sicifus.analysis.AnalysisToolkit

Tools for analyzing structural dataframes.

Source code in src/sicifus/analysis.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
class AnalysisToolkit:
    """
    Tools for analyzing structural dataframes.
    """

    def __init__(self):
        # StructuralAligner (project class) supplies CA-coordinate extraction
        # and structural-sequence encoding used by the RMSD/clustering methods.
        self.aligner = StructuralAligner()

    def compute_rmsd_matrix(self, structures: Dict[str, pl.DataFrame], n_jobs: int = -1,
                            pruning_threshold: Optional[float] = None,
                            prefilter: bool = True) -> Tuple[np.ndarray, List[str]]:
        """
        Computes all-vs-all RMSD matrix for a dictionary of structures.
        Returns (matrix, labels).

        Args:
            structures: Dictionary of structure_id -> DataFrame
            n_jobs: Number of parallel jobs (-1 for all CPUs)
            pruning_threshold: If set (0.0-1.0), skip alignment if sequence length ratio < threshold.
                               Skipped pairs get a high RMSD value (e.g., 99.9).
            prefilter: If True (default), use 3Di k-mer prefiltering to skip
                       dissimilar pairs. Much faster for large N.
        """
        ids = list(structures.keys())
        n = len(ids)

        print(f"Pre-processing {n} structures...")

        # Extract C-alpha coordinates once per structure, in id order.
        coords_list = [self.aligner.get_ca_coords(structures[sid]) for sid in ids]
        lengths = [len(c) for c in coords_list]

        num_pairs = n * (n - 1) // 2
        print(f"Computing RMSD matrix for {n} structures ({num_pairs} pairs)...")

        if prefilter:
            print("  Using 3Di k-mer prefilter.")
            matrix = self._rmsd_matrix_prefiltered(coords_list, ids, n, n_jobs, pruning_threshold)
        elif len(set(lengths)) == 1:
            # Equal-length structures admit the fully vectorized Kabsch path.
            print(f"  All structures have {lengths[0]} residues — using vectorized Kabsch.")
            matrix = self._rmsd_matrix_vectorized(coords_list, n)
        else:
            print("  Variable lengths detected — using threaded alignment path.")
            matrix = self._rmsd_matrix_variable_length(coords_list, ids, n, n_jobs, pruning_threshold)

        return matrix, ids

    def _rmsd_matrix_vectorized(self, coords_list: List[np.ndarray], n: int) -> np.ndarray:
        """
        Fully vectorized RMSD matrix for same-length structures.
        Uses batched numpy einsum + SVD — no Python loops over pairs.
        """
        L = len(coords_list[0])
        num_pairs = n * (n - 1) // 2

        # Stack all structures: (n, L, 3)
        coords_array = np.array(coords_list)

        # Center all structures at their centroids
        centroids = coords_array.mean(axis=1, keepdims=True)  # (n, 1, 3)
        centered = coords_array - centroids  # (n, L, 3)

        # Precompute squared norms per structure: sum of all squared coords
        sq_norms = np.sum(centered ** 2, axis=(1, 2))  # (n,)

        # Compute ALL pairwise covariance matrices in one shot
        H_all = np.einsum('ila,jlb->ijab', centered, centered)  # (n, n, 3, 3)

        # Extract upper triangle pairs only
        i_idx, j_idx = np.triu_indices(n, k=1)
        H_pairs = H_all[i_idx, j_idx]  # (num_pairs, 3, 3)
        del H_all

        # Batched SVD on all pairs at once
        U, S, Vt = np.linalg.svd(H_pairs)  # S: (num_pairs, 3)

        # Handle reflections + compute RMSD
        R = np.einsum('...ji,...kj->...ik', Vt, U)  # (num_pairs, 3, 3)
        dets = np.linalg.det(R)  # (num_pairs,)
        S[dets < 0, -1] *= -1

        sum_S = np.sum(S, axis=1)  # (num_pairs,)
        rmsd_sq = (sq_norms[i_idx] + sq_norms[j_idx] - 2.0 * sum_S) / L
        rmsd = np.sqrt(np.maximum(rmsd_sq, 0.0))

        # Fill symmetric matrix
        matrix = np.zeros((n, n))
        matrix[i_idx, j_idx] = rmsd
        matrix[j_idx, i_idx] = rmsd

        return matrix

    def _rmsd_matrix_variable_length(self, coords_list: List[np.ndarray], ids: List[str],
                                      n: int, n_jobs: int, pruning_threshold: Optional[float]) -> np.ndarray:
        """
        RMSD matrix for variable-length structures using alignment + threading.

        Equal-length pairs are superimposed directly; unequal pairs are first
        aligned on their encoded structural sequences and superimposed on the
        matched positions only.  Rows of the upper triangle are distributed
        across worker threads.

        Args:
            coords_list: Per-structure CA coordinate arrays.
            ids: Structure identifiers, parallel to coords_list.
            n: Number of structures.
            n_jobs: joblib worker count (-1 = all CPUs).
            pruning_threshold: Optional (0-1) length-ratio cutoff; pairs whose
                               shorter/longer ratio is below it receive the
                               sentinel value 99.9 instead of a real RMSD.

        Returns:
            Symmetric (n, n) RMSD matrix with a zero diagonal.
        """
        matrix = np.zeros((n, n))

        # Pre-compute structural sequences as int32 character-code arrays —
        # the form the numba aligner consumes.
        seq_list = []
        lengths = []
        for coords in coords_list:
            lengths.append(len(coords))
            seq_str = self.aligner.encode_structure(coords)
            seq_list.append(np.array([ord(c) for c in seq_str], dtype=np.int32))

        from .align import _align_sequences_numba, _superimpose_numba

        # Warm up Numba JIT (first call compiles, would skew timing)
        if len(coords_list) >= 2:
            _superimpose_numba(coords_list[0][:3], coords_list[1][:3])
            _align_sequences_numba(seq_list[0][:3], seq_list[1][:3])

        def compute_row(i):
            # RMSDs of structure i against every j > i (upper triangle only).
            row_results = np.zeros(n)
            c1, s1, l1 = coords_list[i], seq_list[i], lengths[i]

            for j in range(i + 1, n):
                c2, s2, l2 = coords_list[j], seq_list[j], lengths[j]

                if pruning_threshold is not None:
                    if l1 > 0 and l2 > 0:
                        if min(l1, l2) / max(l1, l2) < pruning_threshold:
                            row_results[j] = 99.9
                            continue
                    else:
                        # Empty structure: nothing to align.
                        row_results[j] = 99.9
                        continue

                if l1 == l2:
                    # Same length: superimpose positions one-to-one.
                    rmsd, _, _, _ = _superimpose_numba(c1, c2)
                else:
                    idx1, idx2 = _align_sequences_numba(s1, s2)
                    if len(idx1) < 3:
                        # Too few aligned positions for a meaningful fit.
                        rmsd = 99.9
                    else:
                        rmsd, _, _, _ = _superimpose_numba(c1[idx1], c2[idx2])

                row_results[j] = rmsd
            return i, row_results

        # Threads rather than processes so the coordinate arrays are shared
        # without pickling (presumably the numba kernels release the GIL —
        # TODO confirm against the align module's @njit(nogil=...) settings).
        results = Parallel(n_jobs=n_jobs, prefer="threads")(
            delayed(compute_row)(i) for i in range(n)
        )

        for i, row_data in results:
            for j in range(i + 1, n):
                val = row_data[j]
                # Zero entries were either never written or are genuinely
                # 0-RMSD pairs, which the zero-initialised matrix already holds.
                if val != 0.0:
                    matrix[i, j] = val
                    matrix[j, i] = val

        return matrix

    def _rmsd_matrix_prefiltered(self, coords_list: List[np.ndarray], ids: List[str],
                                    n: int, n_jobs: int,
                                    pruning_threshold: Optional[float]) -> np.ndarray:
        """RMSD matrix using 3Di k-mer prefilter to skip dissimilar pairs.

        Pairs rejected by the prefilter are never aligned and keep the
        sentinel value 99.9 in the returned matrix.

        Args:
            coords_list: Per-structure CA coordinate arrays.
            ids: Structure identifiers, parallel to coords_list.
            n: Number of structures.
            n_jobs: joblib worker count (-1 = all CPUs).
            pruning_threshold: Optional (0-1) length-ratio cutoff; failing
                               pairs also receive the 99.9 sentinel.

        Returns:
            Symmetric (n, n) matrix; skipped pairs hold 99.9, diagonal is 0.
        """
        from .align import _align_sequences_numba, _superimpose_numba, _encode_3di_numba
        from .kmer_index import build_kmer_index, prefilter_pairs

        # 3Di structural alphabet encoding feeds the k-mer prefilter.
        sequences_3di = [
            _encode_3di_numba(np.ascontiguousarray(c, dtype=np.float64))
            for c in coords_list
        ]
        candidate_pairs = prefilter_pairs(sequences_3di, k=6, alphabet_size=20, min_score=0.1)

        total_pairs = n * (n - 1) // 2
        pct = 100.0 * len(candidate_pairs) / max(total_pairs, 1)
        print(f"  Prefilter kept {len(candidate_pairs)}/{total_pairs} pairs ({pct:.1f}%)")

        # Encoded structural sequences for the unequal-length alignment path.
        seq_list = []
        lengths = []
        for coords in coords_list:
            lengths.append(len(coords))
            seq_str = self.aligner.encode_structure(coords)
            seq_list.append(np.array([ord(c) for c in seq_str], dtype=np.int32))

        # Warm up Numba JIT so compile time isn't paid inside a worker.
        if n >= 2:
            _superimpose_numba(coords_list[0][:3], coords_list[1][:3])
            _align_sequences_numba(seq_list[0][:3], seq_list[1][:3])

        # Default every off-diagonal cell to the sentinel; only surviving
        # candidate pairs are overwritten with a computed RMSD below.
        matrix = np.full((n, n), 99.9)
        np.fill_diagonal(matrix, 0.0)

        pairs_list = list(candidate_pairs)

        def compute_pair(pair):
            # RMSD for one candidate (i, j) pair; returns the sentinel when
            # pruned or when too few positions align.
            i, j = pair
            c1, s1, l1 = coords_list[i], seq_list[i], lengths[i]
            c2, s2, l2 = coords_list[j], seq_list[j], lengths[j]

            if pruning_threshold is not None:
                if l1 > 0 and l2 > 0:
                    if min(l1, l2) / max(l1, l2) < pruning_threshold:
                        return i, j, 99.9
                else:
                    return i, j, 99.9

            if l1 == l2:
                rmsd, _, _, _ = _superimpose_numba(c1, c2)
            else:
                idx1, idx2 = _align_sequences_numba(s1, s2)
                if len(idx1) < 3:
                    return i, j, 99.9
                rmsd, _, _, _ = _superimpose_numba(c1[idx1], c2[idx2])

            return i, j, rmsd

        results = Parallel(n_jobs=n_jobs, prefer="threads")(
            delayed(compute_pair)(p) for p in pairs_list
        )

        for i, j, rmsd in results:
            matrix[i, j] = rmsd
            matrix[j, i] = rmsd

        return matrix

    # ------------------------------------------------------------------
    # Fast greedy clustering (no full distance matrix)
    # ------------------------------------------------------------------

    def cluster_fast(self, structures: Dict[str, pl.DataFrame],
                     distance_threshold: float = 2.0,
                     coverage_threshold: float = 0.8) -> pl.DataFrame:
        """Greedy centroid-based structural clustering (linclust-inspired).

        Uses the 3Di k-mer index to quickly identify candidate centroids for
        each structure, then only computes RMSD to those candidates.  No full
        N×N distance matrix is needed.

        Args:
            structures: Dictionary of structure_id -> DataFrame.
            distance_threshold: Maximum RMSD to assign a structure to an
                                existing cluster centroid (Å).
            coverage_threshold: Minimum length-ratio between a structure and a
                                centroid (0-1) for them to be compared.

        Returns:
            Polars DataFrame with columns
            ``[structure_id, cluster, centroid_id, rmsd_to_centroid]``.
        """
        from .align import _encode_3di_numba, _superimpose_numba, _align_sequences_numba
        from .kmer_index import build_kmer_index, _extract_kmer_hashes

        ids = list(structures.keys())
        n = len(ids)

        print(f"Fast clustering {n} structures (threshold={distance_threshold} Ã…)...")

        coords_list = [self.aligner.get_ca_coords(structures[sid]) for sid in ids]
        lengths = [len(c) for c in coords_list]

        # 3Di encoding for the k-mer index; int32 char codes for the aligner.
        sequences_3di = [
            _encode_3di_numba(np.ascontiguousarray(c, dtype=np.float64))
            for c in coords_list
        ]
        seq_list = [
            np.array([ord(c) for c in self.aligner.encode_structure(c)], dtype=np.int32)
            for c in coords_list
        ]

        index = build_kmer_index(sequences_3di, k=6, alphabet_size=20)

        # Warm up the Numba JIT before the main loop.
        if n >= 2:
            _superimpose_numba(coords_list[0][:3], coords_list[1][:3])

        # Visit longest structures first so early centroids are the most
        # complete representatives of their clusters.
        order = sorted(range(n), key=lambda i: lengths[i], reverse=True)

        centroid_indices: List[int] = []          # structure index of each centroid
        centroid_set: set = set()                 # same, as a set for O(1) membership
        centroid_to_cluster: Dict[int, int] = {}  # centroid index -> cluster id
        cluster_of: Dict[int, int] = {}           # structure index -> cluster id
        rmsd_of: Dict[int, float] = {}            # structure index -> RMSD to its centroid

        for idx in order:
            if not centroid_indices:
                # First (longest) structure seeds cluster 0.
                centroid_indices.append(idx)
                centroid_set.add(idx)
                centroid_to_cluster[idx] = 0
                cluster_of[idx] = 0
                rmsd_of[idx] = 0.0
                continue

            # Score existing centroids by how many of the query's unique
            # 6-mers they share, via the inverted k-mer index.
            hashes = _extract_kmer_hashes(sequences_3di[idx], 6, 20)
            unique_hashes = set(int(h) for h in hashes)
            n_query = len(unique_hashes)

            candidate_centroids: List[int] = []
            if n_query > 0:
                # Require at least 5% of the query's unique k-mers in common.
                threshold = max(int(0.05 * n_query), 1)
                scores: Dict[int, int] = {}
                for h in unique_hashes:
                    if h in index:
                        for j in index[h]:
                            if j in centroid_set:
                                scores[j] = scores.get(j, 0) + 1
                candidate_centroids = [
                    c for c, s in scores.items() if s >= threshold
                ]

            # RMSD only against the surviving candidate centroids.
            best_cluster = -1
            best_rmsd = float("inf")

            for c_idx in candidate_centroids:
                l1, l2 = lengths[idx], lengths[c_idx]
                if l1 > 0 and l2 > 0:
                    if min(l1, l2) / max(l1, l2) < coverage_threshold:
                        continue
                if l1 == l2:
                    rmsd, _, _, _ = _superimpose_numba(coords_list[idx], coords_list[c_idx])
                else:
                    i1, i2 = _align_sequences_numba(seq_list[idx], seq_list[c_idx])
                    if len(i1) < 3:
                        continue
                    rmsd, _, _, _ = _superimpose_numba(
                        coords_list[idx][i1], coords_list[c_idx][i2]
                    )
                if rmsd < best_rmsd:
                    best_rmsd = rmsd
                    best_cluster = centroid_to_cluster[c_idx]

            if best_cluster >= 0 and best_rmsd <= distance_threshold:
                # Close enough to an existing centroid: join that cluster.
                cluster_of[idx] = best_cluster
                rmsd_of[idx] = best_rmsd
            else:
                # No centroid within threshold: this structure founds a new cluster.
                new_cid = len(centroid_indices)
                centroid_indices.append(idx)
                centroid_set.add(idx)
                centroid_to_cluster[idx] = new_cid
                cluster_of[idx] = new_cid
                rmsd_of[idx] = 0.0

        # Map internal cluster ids back to centroid structure ids.
        cluster_to_centroid = {v: ids[k] for k, v in centroid_to_cluster.items()}

        rows = []
        for idx in range(n):
            cid = cluster_of[idx]
            rows.append({
                "structure_id": ids[idx],
                "cluster": cid + 1,  # 1-based cluster numbering for output
                "centroid_id": cluster_to_centroid[cid],
                "rmsd_to_centroid": round(rmsd_of[idx], 4),
            })

        df = pl.DataFrame(rows)
        n_clusters = df["cluster"].n_unique()
        print(f"  {n_clusters} clusters found")
        return df

    def build_tree(self, rmsd_matrix: np.ndarray, labels: List[str], method: str = "average"):
        """
        Builds a hierarchical clustering tree from RMSD matrix using Scipy.
        Returns the linkage matrix.
        """
        # Convert to condensed distance matrix
        condensed_dist = squareform(rmsd_matrix)
        Z = linkage(condensed_dist, method=method)
        return Z

    def build_phylo_tree(self, rmsd_matrix: np.ndarray, labels: List[str], root_id: Optional[str] = None):
        """
        Builds a phylogenetic tree from RMSD matrix.
        Uses Scipy's fast C-based linkage, then converts to Biopython Tree for Newick export.
        Tree is unrooted by default. If root_id is provided, roots at that structure.
        """
        Z = linkage(squareform(rmsd_matrix), method="average")
        tree = self._linkage_to_biopython_tree(Z, labels)

        # Unrooted tree unless a root structure was requested.
        if not root_id:
            return tree

        if root_id not in labels:
            raise ValueError(f"Root ID {root_id} not found in labels.")
        tree.root_with_outgroup({"name": root_id})
        tree.rooted = True
        return tree

    def _linkage_to_biopython_tree(self, Z: np.ndarray, labels: List[str]):
        """
        Converts a Scipy linkage matrix to a Biopython Tree object.

        Args:
            Z: scipy linkage matrix.
            labels: Leaf names, indexed by the original observation ids.

        Returns:
            Bio.Phylo Tree (unrooted) whose branch lengths are differences
            of merge heights, so root-to-leaf path lengths reproduce the
            linkage heights.
        """
        from Bio.Phylo.BaseTree import Tree, Clade

        root_node = to_tree(Z)

        def _build_clade(node, parent_height=None):
            """
            Recursively build Biopython Clade from Scipy ClusterNode.
            node.dist = the height (cumulative RMSD) at which this node's cluster formed.
            Leaves have dist=0. Branch length = parent_height - this_height.
            """
            height = node.dist

            if parent_height is not None:
                # Clamp at 0 to guard against float round-off producing a
                # (slightly) negative branch length.
                branch_length = max(parent_height - height, 0.0)
            else:
                branch_length = 0.0  # Root has no branch length

            if node.is_leaf():
                return Clade(name=labels[node.id], branch_length=branch_length)
            else:
                left = _build_clade(node.get_left(), parent_height=height)
                right = _build_clade(node.get_right(), parent_height=height)
                clade = Clade(branch_length=branch_length)
                clade.clades = [left, right]
                return clade

        root_clade = _build_clade(root_node)

        return Tree(root=root_clade, rooted=False)

    def cluster_from_tree(self, tree, distance_threshold: float) -> pl.DataFrame:
        """
        Derives clusters directly from the phylogenetic tree by cutting branches
        whose length (RMSD) exceeds the threshold. Each resulting subtree's
        leaves form a cluster.

        This uses the actual tree topology and RMSD branch lengths — the clusters
        are defined by the tree itself, not by an external algorithm.

        Args:
            tree: Biopython Tree object with RMSD branch lengths.
            distance_threshold: Cut any branch longer than this RMSD value.
                               e.g. 1.0 means groups separated by > 1 Å RMSD are different clusters.

        Returns:
            Polars DataFrame with columns: structure_id, cluster
        """
        assignments: Dict[str, int] = {}
        next_cluster = 1

        def _walk(clade, cluster_id):
            """Depth-first walk; branches over the threshold start new clusters."""
            nonlocal next_cluster
            if clade.is_terminal():
                assignments[clade.name] = cluster_id
                return
            for child in clade.clades:
                length = child.branch_length if child.branch_length is not None else 0.0
                if length > distance_threshold:
                    # Branch too long — this subtree becomes its own cluster.
                    next_cluster += 1
                    _walk(child, next_cluster)
                else:
                    _walk(child, cluster_id)

        _walk(tree.root, 1)

        return pl.DataFrame({
            "structure_id": list(assignments.keys()),
            "cluster": list(assignments.values())
        })

    def plot_tree(self, tree_obj: Union[np.ndarray, Phylo.BaseTree.Tree], labels: Optional[List[str]] = None, output_file: Optional[str] = None):
        """
        Plots the tree. Handles both Scipy linkage matrix and Biopython Tree.

        Args:
            tree_obj: scipy linkage matrix (ndarray) or Biopython Tree.
            labels: Leaf labels — used only for the linkage-matrix case.
            output_file: If given, the figure is saved there; otherwise shown.
        """
        plt.figure(figsize=(10, 7))

        if isinstance(tree_obj, np.ndarray):
            # Scipy Linkage Matrix
            dendrogram(tree_obj, labels=labels, leaf_rotation=90)
            plt.title("Structural Phylogenetic Tree (RMSD-based)")
            plt.xlabel("Structure ID")
            plt.ylabel("RMSD")
        elif isinstance(tree_obj, Phylo.BaseTree.Tree):
            # Biopython Tree — clear the pre-sized figure before Phylo.draw
            # renders into it.
            plt.clf()
            Phylo.draw(tree_obj, do_show=False)
            plt.title("Structural Phylogenetic Tree (Rooted)" if tree_obj.rooted else "Structural Phylogenetic Tree")

        # NOTE(review): an unrecognized tree_obj type falls through and
        # saves/shows a blank figure — confirm whether an error is preferred.
        if output_file:
            plt.savefig(output_file)
        else:
            plt.show()

    def plot_circular_tree(self, Z: np.ndarray, labels: List[str], 
                           cluster_df: Optional[pl.DataFrame] = None,
                           output_file: Optional[str] = None,
                           show_labels: Optional[bool] = None,
                           figsize: Tuple[int, int] = (14, 14),
                           linewidth: float = 0.4):
        """
        Plots an unrooted circular/radial dendrogram from a linkage matrix.

        Args:
            Z: Scipy linkage matrix.
            labels: Structure IDs (leaf labels).
            cluster_df: Optional cluster assignments (from cluster_structures) to color branches.
            output_file: If provided, saves to file instead of showing.
            show_labels: Whether to show leaf labels. Defaults to True if <=100 leaves, else False.
            figsize: Figure size.
            linewidth: Line width for branches.
        """
        n_leaves = len(labels)

        if show_labels is None:
            show_labels = n_leaves <= 100

        # Get dendrogram coordinate data (without plotting)
        dn = dendrogram(Z, no_plot=True, count_sort=True, labels=labels)
        plt.close()  # close the blank figure dendrogram might have created

        icoord = np.array(dn['icoord'])  # x-coords of each U-shape: (n_merges, 4)
        dcoord = np.array(dn['dcoord'])  # y-coords (heights): (n_merges, 4)
        leaf_label_order = dn['ivl']     # ordered leaf labels

        # Build a leaf_label -> cluster color mapping
        cluster_colors = None
        if cluster_df is not None:
            n_clusters = cluster_df["cluster"].n_unique()
            # NOTE(review): plt.cm.get_cmap is deprecated since matplotlib 3.7
            # and removed in 3.9 — may need matplotlib.colormaps; confirm the
            # project's pinned matplotlib version.
            cmap = plt.cm.get_cmap("tab20" if n_clusters <= 20 else "hsv", n_clusters)

            cluster_color_map = {}
            for row in cluster_df.iter_rows(named=True):
                cluster_color_map[row["structure_id"]] = cmap((row["cluster"] - 1) / max(n_clusters - 1, 1))
            cluster_colors = cluster_color_map

        # Map leaf x-positions to angles (0 to 2Ï€)
        # Dendrogram leaf x-positions are at 5, 15, 25, ... (spacing=10)
        x_min = 5.0
        x_max = 5.0 + (n_leaves - 1) * 10.0
        x_range = x_max - x_min if x_max > x_min else 1.0

        max_height = np.max(dcoord) if np.max(dcoord) > 0 else 1.0

        def x_to_angle(x):
            # Linear map of dendrogram x-position onto the full circle.
            return (x - x_min) / x_range * 2.0 * np.pi

        def y_to_radius(y):
            # Height 0 (leaves) → outer ring; max height (root) → center
            return 1.0 - (y / max_height) * 0.85

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, polar=True)

        # Draw each U-shape link
        for xs, ys in zip(icoord, dcoord):
            # xs: [left_x, left_x, right_x, right_x]
            # ys: [bottom_left_y, top_y, top_y, bottom_right_y]

            a = [x_to_angle(x) for x in xs]
            r = [y_to_radius(y) for y in ys]

            # Branch color: per-cluster branch coloring is not implemented —
            # all branches use neutral gray; leaves get cluster colors below.
            color = '#555555'

            # Left vertical: (a[0], r[0]) → (a[1], r[1])
            ax.plot([a[0], a[1]], [r[0], r[1]], color=color, linewidth=linewidth, solid_capstyle='round')

            # Right vertical: (a[2], r[2]) → (a[3], r[3])
            ax.plot([a[2], a[3]], [r[2], r[3]], color=color, linewidth=linewidth, solid_capstyle='round')

            # Top arc: from a[1] to a[2] at radius r[1] (= r[2]), sampled with
            # enough points to appear smooth at this angular span.
            n_arc = max(int(abs(a[2] - a[1]) / (2 * np.pi) * 100), 2)
            arc_angles = np.linspace(a[1], a[2], n_arc)
            arc_radii = np.full_like(arc_angles, r[1])
            ax.plot(arc_angles, arc_radii, color=color, linewidth=linewidth, solid_capstyle='round')

        # Draw colored dots and labels at leaf positions
        leaf_angles = [x_to_angle(5.0 + i * 10.0) for i in range(n_leaves)]
        leaf_radius = y_to_radius(0)

        for i, (angle, label) in enumerate(zip(leaf_angles, leaf_label_order)):
            dot_color = cluster_colors.get(label, '#888888') if cluster_colors else '#333333'
            ax.plot(angle, leaf_radius, 'o', color=dot_color, markersize=2.5, zorder=5)

            if show_labels:
                # Rotate label to read outward
                angle_deg = np.degrees(angle)
                ha = 'left' if angle < np.pi else 'right'
                rotation = angle_deg if angle < np.pi else angle_deg - 180
                ax.text(angle, leaf_radius + 0.04, label, fontsize=5, rotation=rotation,
                        ha=ha, va='center', rotation_mode='anchor',
                        color=dot_color if cluster_colors else '#333333')

        # Draw cluster legend if clusters provided
        if cluster_df is not None:
            n_clusters = cluster_df["cluster"].n_unique()
            cmap = plt.cm.get_cmap("tab20" if n_clusters <= 20 else "hsv", n_clusters)
            cluster_ids = sorted(cluster_df["cluster"].unique().to_list())

            legend_handles = []
            for cid in cluster_ids:
                count = cluster_df.filter(pl.col("cluster") == cid).height
                color = cmap((cid - 1) / max(n_clusters - 1, 1))
                patch = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color,
                                   markersize=8, label=f"Cluster {cid} ({count})")
                legend_handles.append(patch)

            ax.legend(handles=legend_handles, loc='upper right', bbox_to_anchor=(1.3, 1.05),
                      fontsize=8, title="Clusters", title_fontsize=9)

        # Clean up polar plot
        ax.set_yticklabels([])
        ax.set_xticklabels([])
        ax.spines['polar'].set_visible(False)
        ax.grid(False)

        plt.tight_layout()

        if output_file:
            plt.savefig(output_file, dpi=200, bbox_inches='tight')
            plt.close()
        else:
            plt.show()

    def build_similarity_network(self, rmsd_matrix: np.ndarray, labels: List[str], threshold: float) -> nx.Graph:
        """
        Builds a network where edges exist if RMSD < threshold.

        Args:
            rmsd_matrix: Symmetric (n, n) RMSD matrix.
            labels: Node names, parallel to the matrix rows.
            threshold: Strict upper bound on RMSD for an edge to be added.

        Returns:
            networkx Graph; every label is a node, edges carry the RMSD
            as their ``weight`` attribute.
        """
        G = nx.Graph()
        G.add_nodes_from(labels)

        size = len(labels)
        for row in range(size):
            for col in range(row + 1, size):
                dist = rmsd_matrix[row, col]
                if dist < threshold:
                    G.add_edge(labels[row], labels[col], weight=dist)

        return G

    def calculate_relative_energy(self, df: pl.DataFrame, group_by: Optional[str] = None) -> pl.DataFrame:
        """
        Converts total energies (Hartree) to relative energies (kcal/mol).

        If group_by is provided (e.g. "ligand_name"), relative energies are calculated
        separately for each group. Otherwise, the global minimum is used.

        Args:
            df: DataFrame with an "energy" column (Hartree).
            group_by: Column to group by before finding minimum (optional).

        Returns:
            DataFrame with new "relative_energy_kcal" column.
        """
        if "energy" not in df.columns:
            raise ValueError("DataFrame must have an 'energy' column (Hartree).")

        HARTREE_TO_KCAL = 627.509

        # Baseline is either a per-group windowed minimum or the global scalar minimum.
        if group_by:
            baseline = pl.col("energy").min().over(group_by)
        else:
            baseline = df["energy"].min()

        return df.with_columns(
            ((pl.col("energy") - baseline) * HARTREE_TO_KCAL).alias("relative_energy_kcal")
        )

    def compute_residue_interaction_network(
        self,
        structure_df: pl.DataFrame,
        distance_cutoff: float = 5.0,
        interaction_types: Optional[List[str]] = None,
    ) -> nx.Graph:
        """Compute residue interaction network based on spatial proximity.

        Creates a graph where nodes are residues and edges represent
        interactions (contacts within distance_cutoff).

        Args:
            structure_df: DataFrame with columns [chain, residue_number, residue_name, x, y, z]
            distance_cutoff: Maximum distance (Å) for residue contact (default 5.0)
            interaction_types: Optional filter for specific residues (e.g., ['PHE', 'TYR', 'TRP'])

        Returns:
            NetworkX graph with residue nodes and interaction edges.
            Node attributes: chain, residue_number, residue_name, pos
            Edge attributes: distance, atom_contacts
        """
        # Hoisted out of the O(n^2) pair loop below (was re-imported per pair).
        from scipy.spatial.distance import cdist

        # Collapse atoms to one representative point per residue
        # (geometric center of the residue's atoms).
        residues = (
            structure_df
            .group_by(["chain", "residue_number", "residue_name"])
            .agg([
                pl.col("x").mean().alias("center_x"),
                pl.col("y").mean().alias("center_y"),
                pl.col("z").mean().alias("center_z"),
            ])
        )

        # Optionally restrict to a subset of residue types
        if interaction_types:
            residues = residues.filter(pl.col("residue_name").is_in(interaction_types))

        # Build NetworkX graph
        G = nx.Graph()

        # One node per residue, keyed by (chain, residue_number)
        for row in residues.iter_rows(named=True):
            node_id = (row["chain"], row["residue_number"])
            G.add_node(
                node_id,
                chain=row["chain"],
                residue_number=row["residue_number"],
                residue_name=row["residue_name"],
                pos=(row["center_x"], row["center_y"], row["center_z"])
            )

        residue_list = list(residues.iter_rows(named=True))
        # NOTE(review): the Å cutoff is divided by 10 ("Å to nm") but the
        # x/y/z coordinates are used unconverted, and find_binding_residues
        # compares an Å cutoff directly against the same kind of coordinates.
        # Confirm the unit of structure_df coordinates — if they are Å, this
        # effectively shrinks the cutoff tenfold.
        cutoff_nm = distance_cutoff / 10.0  # Å to nm

        # Memoize per-residue atom coordinates so each residue's atoms are
        # extracted from structure_df once, not once per contacting pair.
        atom_coords_cache: Dict[Tuple, np.ndarray] = {}

        def _atom_coords(chain, residue_number) -> np.ndarray:
            """Return (and cache) the Nx3 atom coordinates of one residue."""
            key = (chain, residue_number)
            if key not in atom_coords_cache:
                atom_coords_cache[key] = (
                    structure_df.filter(
                        (pl.col("chain") == chain) &
                        (pl.col("residue_number") == residue_number)
                    )
                    .select(["x", "y", "z"])
                    .to_numpy()
                )
            return atom_coords_cache[key]

        for i in range(len(residue_list)):
            for j in range(i + 1, len(residue_list)):
                res1 = residue_list[i]
                res2 = residue_list[j]

                # Skip same residue (defensive; i < j already excludes identical rows)
                if res1["chain"] == res2["chain"] and res1["residue_number"] == res2["residue_number"]:
                    continue

                # Calculate center-of-mass distance
                center1 = np.array([res1["center_x"], res1["center_y"], res1["center_z"]])
                center2 = np.array([res2["center_x"], res2["center_y"], res2["center_z"]])
                dist = np.linalg.norm(center1 - center2)

                if dist < cutoff_nm:
                    node1 = (res1["chain"], res1["residue_number"])
                    node2 = (res2["chain"], res2["residue_number"])

                    # Count atom-level contacts between the two residues
                    coords1 = _atom_coords(res1["chain"], res1["residue_number"])
                    coords2 = _atom_coords(res2["chain"], res2["residue_number"])

                    atom_distances = cdist(coords1, coords2)
                    num_contacts = int(np.sum(atom_distances < cutoff_nm))

                    G.add_edge(
                        node1,
                        node2,
                        distance=round(dist * 10.0, 3),  # Convert back to Å
                        atom_contacts=num_contacts
                    )

        return G

    def analyze_network_centrality(
        self,
        G: nx.Graph,
        top_n: int = 10,
    ) -> pl.DataFrame:
        """Analyze network centrality metrics to identify key residues.

        Args:
            G: NetworkX graph from compute_residue_interaction_network()
            top_n: Number of top residues to return (default 10)

        Returns:
            DataFrame with columns: [chain, residue_number, residue_name,
            degree_centrality, betweenness_centrality, closeness_centrality]
        """
        # Per-node centrality lookups.
        degree = nx.degree_centrality(G)
        betweenness = nx.betweenness_centrality(G)
        closeness = nx.closeness_centrality(G)

        # Each node is a (chain, residue_number) tuple; residue_name lives
        # in the node attributes.
        records = [
            {
                "chain": node[0],
                "residue_number": node[1],
                "residue_name": G.nodes[node]["residue_name"],
                "degree_centrality": round(degree[node], 4),
                "betweenness_centrality": round(betweenness[node], 4),
                "closeness_centrality": round(closeness[node], 4),
            }
            for node in G.nodes()
        ]

        # Betweenness is usually the most informative metric for hub residues.
        return (
            pl.DataFrame(records)
            .sort("betweenness_centrality", descending=True)
            .head(top_n)
        )

    def plot_interaction_network(
        self,
        G: nx.Graph,
        output_file: Optional[str] = None,
        node_color_by: str = "chain",
        figsize: Tuple[int, int] = (12, 12),
    ):
        """Visualize residue interaction network.

        Args:
            G: NetworkX graph from compute_residue_interaction_network()
            output_file: Path to save figure (if None, shows interactively)
            node_color_by: Color nodes by "chain" or "residue_name" (default "chain")
            figsize: Figure size
        """
        fig, ax = plt.subplots(figsize=figsize)

        # Deterministic spring layout (fixed seed).
        pos = nx.spring_layout(G, seed=42, k=0.5, iterations=50)

        # Choose which node attribute drives coloring, and a matching colormap.
        if node_color_by == "chain":
            attr, cmap_name = "chain", "tab10"
        else:
            attr, cmap_name = "residue_name", "tab20"

        categories = sorted({G.nodes[n][attr] for n in G.nodes()})
        cmap = plt.cm.get_cmap(cmap_name, len(categories))
        color_of = {cat: cmap(idx) for idx, cat in enumerate(categories)}
        node_colors = [color_of[G.nodes[n][attr]] for n in G.nodes()]

        # Draw nodes first, then edges.
        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=300,
                              alpha=0.8, ax=ax)
        nx.draw_networkx_edges(G, pos, alpha=0.3, width=1, ax=ax)

        # Label with the residue number only, for readability.
        nx.draw_networkx_labels(G, pos, {n: str(n[1]) for n in G.nodes()},
                                font_size=8, ax=ax)

        ax.set_title("Residue Interaction Network", fontsize=14, fontweight="bold")
        ax.axis("off")

        plt.tight_layout()

        if output_file:
            plt.savefig(output_file, dpi=150, bbox_inches='tight')
            plt.close()
        else:
            plt.show()

compute_rmsd_matrix(structures, n_jobs=-1, pruning_threshold=None, prefilter=True)

Computes all-vs-all RMSD matrix for a dictionary of structures. Returns (matrix, labels).

Parameters:

Name Type Description Default
structures Dict[str, DataFrame]

Dictionary of structure_id -> DataFrame

required
n_jobs int

Number of parallel jobs (-1 for all CPUs)

-1
pruning_threshold Optional[float]

If set (0.0-1.0), skip alignment if sequence length ratio < threshold. Skipped pairs get a high RMSD value (e.g., 99.9).

None
prefilter bool

If True (default), use 3Di k-mer prefiltering to skip dissimilar pairs. Much faster for large N.

True
Source code in src/sicifus/analysis.py
def compute_rmsd_matrix(self, structures: Dict[str, pl.DataFrame], n_jobs: int = -1,
                        pruning_threshold: Optional[float] = None,
                        prefilter: bool = True) -> Tuple[np.ndarray, List[str]]:
    """
    Computes all-vs-all RMSD matrix for a dictionary of structures.
    Returns (matrix, labels).

    Args:
        structures: Dictionary of structure_id -> DataFrame
        n_jobs: Number of parallel jobs (-1 for all CPUs)
        pruning_threshold: If set (0.0-1.0), skip alignment if sequence length ratio < threshold.
                           Skipped pairs get a high RMSD value (e.g., 99.9).
        prefilter: If True (default), use 3Di k-mer prefiltering to skip
                   dissimilar pairs. Much faster for large N.
    """
    ids = list(structures.keys())
    n = len(ids)

    print(f"Pre-processing {n} structures...")

    # CA coordinates per structure, in the same order as `ids`.
    coords_list = [self.aligner.get_ca_coords(structures[sid]) for sid in ids]
    lengths = [len(c) for c in coords_list]

    num_pairs = n * (n - 1) // 2
    print(f"Computing RMSD matrix for {n} structures ({num_pairs} pairs)...")

    if prefilter:
        # k-mer prefilter skips aligning clearly dissimilar pairs.
        print(f"  Using 3Di k-mer prefilter.")
        matrix = self._rmsd_matrix_prefiltered(coords_list, ids, n, n_jobs, pruning_threshold)
    elif len(set(lengths)) == 1:
        # Equal-length structures can be superimposed without alignment.
        print(f"  All structures have {lengths[0]} residues — using vectorized Kabsch.")
        matrix = self._rmsd_matrix_vectorized(coords_list, n)
    else:
        print(f"  Variable lengths detected — using threaded alignment path.")
        matrix = self._rmsd_matrix_variable_length(coords_list, ids, n, n_jobs, pruning_threshold)

    return matrix, ids

cluster_fast(structures, distance_threshold=2.0, coverage_threshold=0.8)

Greedy centroid-based structural clustering (linclust-inspired).

Uses the 3Di k-mer index to quickly identify candidate centroids for each structure, then only computes RMSD to those candidates. No full N×N distance matrix is needed.

Parameters:

Name Type Description Default
structures Dict[str, DataFrame]

Dictionary of structure_id -> DataFrame.

required
distance_threshold float

Maximum RMSD to assign a structure to an existing cluster centroid (Å).

2.0
coverage_threshold float

Minimum length-ratio between a structure and a centroid (0-1) for them to be compared.

0.8

Returns:

Type Description
DataFrame

Polars DataFrame with columns

DataFrame

[structure_id, cluster, centroid_id, rmsd_to_centroid].

Source code in src/sicifus/analysis.py
def cluster_fast(self, structures: Dict[str, pl.DataFrame],
                 distance_threshold: float = 2.0,
                 coverage_threshold: float = 0.8) -> pl.DataFrame:
    """Greedy centroid-based structural clustering (linclust-inspired).

    Uses the 3Di k-mer index to quickly identify candidate centroids for
    each structure, then only computes RMSD to those candidates.  No full
    N×N distance matrix is needed.

    Args:
        structures: Dictionary of structure_id -> DataFrame.
        distance_threshold: Maximum RMSD to assign a structure to an
                            existing cluster centroid (Å).
        coverage_threshold: Minimum length-ratio between a structure and a
                            centroid (0-1) for them to be compared.

    Returns:
        Polars DataFrame with columns
        ``[structure_id, cluster, centroid_id, rmsd_to_centroid]``.
    """
    from .align import _encode_3di_numba, _superimpose_numba, _align_sequences_numba
    from .kmer_index import build_kmer_index, _extract_kmer_hashes

    ids = list(structures.keys())
    n = len(ids)

    print(f"Fast clustering {n} structures (threshold={distance_threshold} Ã…)...")

    # CA coordinates and residue counts per structure, in `ids` order.
    coords_list = [self.aligner.get_ca_coords(structures[sid]) for sid in ids]
    lengths = [len(c) for c in coords_list]

    # Numeric 3Di encodings feed the k-mer index below.
    sequences_3di = [
        _encode_3di_numba(np.ascontiguousarray(c, dtype=np.float64))
        for c in coords_list
    ]
    # Integer character codes of the aligner's structure encoding, used for
    # the sequence-alignment fallback on unequal-length pairs.
    seq_list = [
        np.array([ord(c) for c in self.aligner.encode_structure(c)], dtype=np.int32)
        for c in coords_list
    ]

    index = build_kmer_index(sequences_3di, k=6, alphabet_size=20)

    # Result is discarded — presumably a numba JIT warm-up so compilation
    # cost is paid once up front. TODO confirm.
    if n >= 2:
        _superimpose_numba(coords_list[0][:3], coords_list[1][:3])

    # Process longest structures first so they tend to become centroids.
    order = sorted(range(n), key=lambda i: lengths[i], reverse=True)

    centroid_indices: List[int] = []
    centroid_set: set = set()
    centroid_to_cluster: Dict[int, int] = {}
    cluster_of: Dict[int, int] = {}
    rmsd_of: Dict[int, float] = {}

    for idx in order:
        # The first structure seen founds the first cluster (id 0 internally).
        if not centroid_indices:
            centroid_indices.append(idx)
            centroid_set.add(idx)
            centroid_to_cluster[idx] = 0
            cluster_of[idx] = 0
            rmsd_of[idx] = 0.0
            continue

        # Candidate centroids: existing centroids that share at least 5% of
        # this structure's unique 6-mers (minimum one shared k-mer).
        hashes = _extract_kmer_hashes(sequences_3di[idx], 6, 20)
        unique_hashes = set(int(h) for h in hashes)
        n_query = len(unique_hashes)

        candidate_centroids: List[int] = []
        if n_query > 0:
            threshold = max(int(0.05 * n_query), 1)
            scores: Dict[int, int] = {}
            for h in unique_hashes:
                if h in index:
                    for j in index[h]:
                        if j in centroid_set:
                            scores[j] = scores.get(j, 0) + 1
            candidate_centroids = [
                c for c, s in scores.items() if s >= threshold
            ]

        best_cluster = -1
        best_rmsd = float("inf")

        for c_idx in candidate_centroids:
            # Skip pairs whose length ratio is below the coverage threshold.
            l1, l2 = lengths[idx], lengths[c_idx]
            if l1 > 0 and l2 > 0:
                if min(l1, l2) / max(l1, l2) < coverage_threshold:
                    continue
            if l1 == l2:
                # Equal lengths: superimpose directly.
                rmsd, _, _, _ = _superimpose_numba(coords_list[idx], coords_list[c_idx])
            else:
                # Unequal lengths: align sequences first, then superimpose
                # only the aligned residue pairs (needs at least 3 points).
                i1, i2 = _align_sequences_numba(seq_list[idx], seq_list[c_idx])
                if len(i1) < 3:
                    continue
                rmsd, _, _, _ = _superimpose_numba(
                    coords_list[idx][i1], coords_list[c_idx][i2]
                )
            if rmsd < best_rmsd:
                best_rmsd = rmsd
                best_cluster = centroid_to_cluster[c_idx]

        if best_cluster >= 0 and best_rmsd <= distance_threshold:
            # Close enough to an existing centroid: join its cluster.
            cluster_of[idx] = best_cluster
            rmsd_of[idx] = best_rmsd
        else:
            # No acceptable centroid: this structure founds a new cluster.
            new_cid = len(centroid_indices)
            centroid_indices.append(idx)
            centroid_set.add(idx)
            centroid_to_cluster[idx] = new_cid
            cluster_of[idx] = new_cid
            rmsd_of[idx] = 0.0

    # Map internal cluster id -> centroid structure_id for reporting.
    cluster_to_centroid = {v: ids[k] for k, v in centroid_to_cluster.items()}

    rows = []
    for idx in range(n):
        cid = cluster_of[idx]
        rows.append({
            "structure_id": ids[idx],
            "cluster": cid + 1,  # 1-based cluster ids in the output
            "centroid_id": cluster_to_centroid[cid],
            "rmsd_to_centroid": round(rmsd_of[idx], 4),
        })

    df = pl.DataFrame(rows)
    n_clusters = df["cluster"].n_unique()
    print(f"  {n_clusters} clusters found")
    return df

build_tree(rmsd_matrix, labels, method='average')

Builds a hierarchical clustering tree from RMSD matrix using Scipy. Returns the linkage matrix.

Source code in src/sicifus/analysis.py
def build_tree(self, rmsd_matrix: np.ndarray, labels: List[str], method: str = "average"):
    """
    Builds a hierarchical clustering tree from RMSD matrix using Scipy.
    Returns the linkage matrix.
    """
    # Scipy's linkage expects the condensed (upper-triangular) form of the
    # symmetric RMSD matrix.
    return linkage(squareform(rmsd_matrix), method=method)

build_phylo_tree(rmsd_matrix, labels, root_id=None)

Builds a phylogenetic tree from RMSD matrix. Uses Scipy's fast C-based linkage, then converts to Biopython Tree for Newick export. Tree is unrooted by default. If root_id is provided, roots at that structure.

Source code in src/sicifus/analysis.py
def build_phylo_tree(self, rmsd_matrix: np.ndarray, labels: List[str], root_id: Optional[str] = None):
    """
    Builds a phylogenetic tree from RMSD matrix.
    Uses Scipy's fast C-based linkage, then converts to Biopython Tree for Newick export.
    Tree is unrooted by default. If root_id is provided, roots at that structure.
    """
    # Average-linkage clustering over the condensed RMSD distances.
    Z = linkage(squareform(rmsd_matrix), method="average")
    tree = self._linkage_to_biopython_tree(Z, labels)

    # Optionally root the (otherwise unrooted) tree at a named structure.
    if root_id:
        if root_id not in labels:
            raise ValueError(f"Root ID {root_id} not found in labels.")
        tree.root_with_outgroup({"name": root_id})
        tree.rooted = True

    return tree

cluster_from_tree(tree, distance_threshold)

Derives clusters directly from the phylogenetic tree by cutting branches whose length (RMSD) exceeds the threshold. Each resulting subtree's leaves form a cluster.

This uses the actual tree topology and RMSD branch lengths — the clusters are defined by the tree itself, not by an external algorithm.

Parameters:

Name Type Description Default
tree

Biopython Tree object with RMSD branch lengths.

required
distance_threshold float

Cut any branch longer than this RMSD value. e.g. 1.0 means groups separated by > 1 Å RMSD are different clusters.

required

Returns:

Type Description
DataFrame

Polars DataFrame with columns: structure_id, cluster

Source code in src/sicifus/analysis.py
def cluster_from_tree(self, tree, distance_threshold: float) -> pl.DataFrame:
    """
    Derives clusters directly from the phylogenetic tree by cutting branches
    whose length (RMSD) exceeds the threshold. Each resulting subtree's
    leaves form a cluster.

    This uses the actual tree topology and RMSD branch lengths — the clusters
    are defined by the tree itself, not by an external algorithm.

    Args:
        tree: Biopython Tree object with RMSD branch lengths.
        distance_threshold: Cut any branch longer than this RMSD value.
                           e.g. 1.0 means groups separated by > 1 Å RMSD are different clusters.

    Returns:
        Polars DataFrame with columns: structure_id, cluster
    """
    assignments: Dict[str, int] = {}
    next_cluster = 1

    def _walk(clade, current_cluster):
        """Depth-first walk; a branch longer than the threshold starts a new cluster."""
        nonlocal next_cluster
        if clade.is_terminal():
            assignments[clade.name] = current_cluster
            return
        for child in clade.clades:
            # Missing branch lengths count as zero (never cut).
            length = child.branch_length if child.branch_length is not None else 0.0
            if length > distance_threshold:
                # Branch too long — the subtree below it forms a new cluster.
                next_cluster += 1
                _walk(child, next_cluster)
            else:
                _walk(child, current_cluster)

    _walk(tree.root, 1)

    return pl.DataFrame({
        "structure_id": list(assignments.keys()),
        "cluster": list(assignments.values()),
    })

plot_tree(tree_obj, labels=None, output_file=None)

Plots the tree. Handles both Scipy linkage matrix and Biopython Tree.

Source code in src/sicifus/analysis.py
def plot_tree(self, tree_obj: Union[np.ndarray, Phylo.BaseTree.Tree], labels: Optional[List[str]] = None, output_file: Optional[str] = None):
    """
    Plots the tree. Handles both Scipy linkage matrix and Biopython Tree.
    """
    plt.figure(figsize=(10, 7))

    if isinstance(tree_obj, np.ndarray):
        # Scipy linkage matrix -> classic dendrogram.
        dendrogram(tree_obj, labels=labels, leaf_rotation=90)
        plt.title("Structural Phylogenetic Tree (RMSD-based)")
        plt.xlabel("Structure ID")
        plt.ylabel("RMSD")
    elif isinstance(tree_obj, Phylo.BaseTree.Tree):
        # Biopython tree -> Phylo's own drawing, on a cleared canvas.
        plt.clf()
        Phylo.draw(tree_obj, do_show=False)
        if tree_obj.rooted:
            plt.title("Structural Phylogenetic Tree (Rooted)")
        else:
            plt.title("Structural Phylogenetic Tree")

    if output_file:
        plt.savefig(output_file)
    else:
        plt.show()

plot_circular_tree(Z, labels, cluster_df=None, output_file=None, show_labels=None, figsize=(14, 14), linewidth=0.4)

Plots an unrooted circular/radial dendrogram from a linkage matrix.

Parameters:

Name Type Description Default
Z ndarray

Scipy linkage matrix.

required
labels List[str]

Structure IDs (leaf labels).

required
cluster_df Optional[DataFrame]

Optional cluster assignments (from cluster_structures) to color branches.

None
output_file Optional[str]

If provided, saves to file instead of showing.

None
show_labels Optional[bool]

Whether to show leaf labels. Defaults to True if <=100 leaves, else False.

None
figsize Tuple[int, int]

Figure size.

(14, 14)
linewidth float

Line width for branches.

0.4
Source code in src/sicifus/analysis.py
def plot_circular_tree(self, Z: np.ndarray, labels: List[str], 
                       cluster_df: Optional[pl.DataFrame] = None,
                       output_file: Optional[str] = None,
                       show_labels: Optional[bool] = None,
                       figsize: Tuple[int, int] = (14, 14),
                       linewidth: float = 0.4):
    """
    Plots an unrooted circular/radial dendrogram from a linkage matrix.

    Args:
        Z: Scipy linkage matrix.
        labels: Structure IDs (leaf labels).
        cluster_df: Optional cluster assignments (from cluster_structures) to color branches.
        output_file: If provided, saves to file instead of showing.
        show_labels: Whether to show leaf labels. Defaults to True if <=100 leaves, else False.
        figsize: Figure size.
        linewidth: Line width for branches.
    """
    n_leaves = len(labels)

    # Labels get unreadable past ~100 leaves, so default them off there.
    if show_labels is None:
        show_labels = n_leaves <= 100

    # Get dendrogram coordinate data (without plotting)
    dn = dendrogram(Z, no_plot=True, count_sort=True, labels=labels)
    plt.close()  # close the blank figure dendrogram might have created

    icoord = np.array(dn['icoord'])  # x-coords of each U-shape: (n_merges, 4)
    dcoord = np.array(dn['dcoord'])  # y-coords (heights): (n_merges, 4)
    leaf_label_order = dn['ivl']     # ordered leaf labels

    # Build a leaf_label -> cluster color mapping
    cluster_colors = None
    if cluster_df is not None:
        # NOTE(review): plt.cm.get_cmap is deprecated and removed in newer
        # matplotlib (>= 3.9); plt.get_cmap is the drop-in replacement.
        n_clusters = cluster_df["cluster"].n_unique()
        cmap = plt.cm.get_cmap("tab20" if n_clusters <= 20 else "hsv", n_clusters)

        cluster_color_map = {}
        for row in cluster_df.iter_rows(named=True):
            cluster_color_map[row["structure_id"]] = cmap((row["cluster"] - 1) / max(n_clusters - 1, 1))
        cluster_colors = cluster_color_map

    # Map leaf x-positions to angles (0 to 2π)
    # Dendrogram leaf x-positions are at 5, 15, 25, ... (spacing=10)
    x_min = 5.0
    x_max = 5.0 + (n_leaves - 1) * 10.0
    x_range = x_max - x_min if x_max > x_min else 1.0

    max_height = np.max(dcoord) if np.max(dcoord) > 0 else 1.0

    def x_to_angle(x):
        # Linear map of dendrogram x-position to [0, 2π).
        return (x - x_min) / x_range * 2.0 * np.pi

    def y_to_radius(y):
        # Height 0 (leaves) → outer ring; max height (root) → center
        return 1.0 - (y / max_height) * 0.85

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, polar=True)

    # Draw each U-shape link
    for xs, ys in zip(icoord, dcoord):
        # xs: [left_x, left_x, right_x, right_x]
        # ys: [bottom_left_y, top_y, top_y, bottom_right_y]

        a = [x_to_angle(x) for x in xs]
        r = [y_to_radius(y) for y in ys]

        # Determine branch color: if cluster_colors, use the color of the left child leaf
        # NOTE(review): branch-level cluster coloring is not actually
        # implemented — every branch is drawn grey; only the leaf dots and
        # labels below use cluster colors.
        color = '#555555'

        # Left vertical: (a[0], r[0]) → (a[1], r[1])
        ax.plot([a[0], a[1]], [r[0], r[1]], color=color, linewidth=linewidth, solid_capstyle='round')

        # Right vertical: (a[2], r[2]) → (a[3], r[3])
        ax.plot([a[2], a[3]], [r[2], r[3]], color=color, linewidth=linewidth, solid_capstyle='round')

        # Top arc: from a[1] to a[2] at radius r[1] (= r[2])
        # Arc resolution scales with the angular span (minimum two points).
        n_arc = max(int(abs(a[2] - a[1]) / (2 * np.pi) * 100), 2)
        arc_angles = np.linspace(a[1], a[2], n_arc)
        arc_radii = np.full_like(arc_angles, r[1])
        ax.plot(arc_angles, arc_radii, color=color, linewidth=linewidth, solid_capstyle='round')

    # Draw colored dots and labels at leaf positions
    leaf_angles = [x_to_angle(5.0 + i * 10.0) for i in range(n_leaves)]
    leaf_radius = y_to_radius(0)

    for i, (angle, label) in enumerate(zip(leaf_angles, leaf_label_order)):
        dot_color = cluster_colors.get(label, '#888888') if cluster_colors else '#333333'
        ax.plot(angle, leaf_radius, 'o', color=dot_color, markersize=2.5, zorder=5)

        if show_labels:
            # Rotate label to read outward
            angle_deg = np.degrees(angle)
            ha = 'left' if angle < np.pi else 'right'
            rotation = angle_deg if angle < np.pi else angle_deg - 180
            ax.text(angle, leaf_radius + 0.04, label, fontsize=5, rotation=rotation,
                    ha=ha, va='center', rotation_mode='anchor',
                    color=dot_color if cluster_colors else '#333333')

    # Draw cluster legend if clusters provided
    if cluster_df is not None:
        n_clusters = cluster_df["cluster"].n_unique()
        cmap = plt.cm.get_cmap("tab20" if n_clusters <= 20 else "hsv", n_clusters)
        cluster_ids = sorted(cluster_df["cluster"].unique().to_list())

        legend_handles = []
        for cid in cluster_ids:
            count = cluster_df.filter(pl.col("cluster") == cid).height
            color = cmap((cid - 1) / max(n_clusters - 1, 1))
            patch = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color,
                               markersize=8, label=f"Cluster {cid} ({count})")
            legend_handles.append(patch)

        ax.legend(handles=legend_handles, loc='upper right', bbox_to_anchor=(1.3, 1.05),
                  fontsize=8, title="Clusters", title_fontsize=9)

    # Clean up polar plot
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.spines['polar'].set_visible(False)
    ax.grid(False)

    plt.tight_layout()

    if output_file:
        plt.savefig(output_file, dpi=200, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

build_similarity_network(rmsd_matrix, labels, threshold)

Builds a network where edges exist if RMSD < threshold.

Source code in src/sicifus/analysis.py
def build_similarity_network(self, rmsd_matrix: np.ndarray, labels: List[str], threshold: float) -> nx.Graph:
    """
    Builds a network where edges exist if RMSD < threshold.
    """
    G = nx.Graph()
    G.add_nodes_from(labels)

    # Connect every pair of structures whose RMSD is below the threshold,
    # keeping the RMSD as the edge weight.
    n = len(labels)
    for i in range(n):
        for j in range(i + 1, n):
            rmsd = rmsd_matrix[i, j]
            if rmsd < threshold:
                G.add_edge(labels[i], labels[j], weight=rmsd)

    return G

Ligand Analyzer

sicifus.analysis.LigandAnalyzer

Tools for analyzing ligand binding sites, pi-stacking interactions, and atom-level protein-ligand contacts.

Source code in src/sicifus/analysis.py
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
class LigandAnalyzer:
    """
    Tools for analyzing ligand binding sites, pi-stacking interactions,
    and atom-level protein-ligand contacts.
    """

    # ── Aromatic ring definitions for standard amino acids ────────────
    # Each entry maps residue name -> list of rings, each ring is a tuple
    # of atom names that form the aromatic ring.
    PROTEIN_AROMATIC_RINGS: Dict[str, List[Tuple[str, ...]]] = {
        "PHE": [("CG", "CD1", "CE1", "CZ", "CE2", "CD2")],
        "TYR": [("CG", "CD1", "CE1", "CZ", "CE2", "CD2")],
        "TRP": [
            ("CG", "CD1", "NE1", "CE2", "CD2"),           # 5-membered indole ring
            ("CD2", "CE2", "CE3", "CZ3", "CH2", "CZ2"),   # 6-membered indole ring
        ],
        "HIS": [("CG", "ND1", "CE1", "NE2", "CD2")],
    }

    def __init__(self):
        pass

    # ── Binding residue detection (existing, uses backbone CA) ────────

    def find_binding_residues(self, backbone_df: pl.DataFrame, ligand_df: pl.DataFrame, 
                              ligand_name: str, distance_cutoff: float = 5.0) -> pl.DataFrame:
        """
        Identifies residues within a cutoff distance of a specific ligand.
        Uses CA atoms from backbone for fast residue-level proximity.

        Args:
            backbone_df: Backbone atoms (expected one CA row per residue)
                with x, y, z, chain, residue_number, residue_name columns.
            ligand_df: All ligand/heteroatom rows.
            ligand_name: Residue name of the target ligand (e.g. "ATP").
            distance_cutoff: Max backbone-atom-to-ligand-atom distance in Å.

        Returns:
            DataFrame of unique (chain, residue_number, residue_name) rows
            within the cutoff; empty DataFrame if the ligand is absent.
        """
        target_ligand = ligand_df.filter(pl.col("residue_name") == ligand_name)
        if target_ligand.height == 0:
            return pl.DataFrame()

        from scipy.spatial.distance import cdist
        backbone_coords = backbone_df.select(["x", "y", "z"]).to_numpy()
        ligand_coords = target_ligand.select(["x", "y", "z"]).to_numpy()

        # (B, L) pairwise distances; a backbone atom "binds" if it is close
        # to ANY ligand atom.
        dists = cdist(backbone_coords, ligand_coords)
        min_dists = np.min(dists, axis=1)
        mask = min_dists < distance_cutoff
        binding_residues = backbone_df.filter(mask)
        return binding_residues.unique(subset=["chain", "residue_number", "residue_name"])

    def get_pocket_residues(self, all_atom_df: pl.DataFrame, ligand_df: pl.DataFrame,
                            ligand_name: str, distance_cutoff: float = 8.0) -> List[str]:
        """
        Identifies all unique residues within the specified distance cutoff
        of any atom in the ligand. Uses all-atom coordinates for accuracy.

        Returns a list of residue names (e.g. ["ALA", "HIS", ...]) found in the pocket.
        """
        from scipy.spatial.distance import cdist

        target_ligand = ligand_df.filter(pl.col("residue_name") == ligand_name)
        if target_ligand.height == 0 or all_atom_df.height == 0:
            return []

        prot_coords = all_atom_df.select(["x", "y", "z"]).to_numpy()
        lig_coords = target_ligand.select(["x", "y", "z"]).to_numpy()

        # Calculate distances between all protein atoms and all ligand atoms
        dists = cdist(prot_coords, lig_coords)

        # Find protein atoms that are close to ANY ligand atom
        min_dists = np.min(dists, axis=1)
        mask = min_dists < distance_cutoff

        # Filter protein atoms
        pocket_atoms = all_atom_df.filter(mask)

        if pocket_atoms.height == 0:
            return []

        # Get unique residues (chain + number + name)
        unique_residues = pocket_atoms.unique(subset=["chain", "residue_number", "residue_name"])

        # Return just the residue names
        return unique_residues.get_column("residue_name").to_list()

    def plot_binding_pocket_composition(self, residue_counts: Dict[str, int],
                                        title: str = "Binding Pocket Composition",
                                        output_file: Optional[str] = None):
        """
        Plots a histogram of residue types found in the binding pocket.
        Ensures all 20 standard amino acids are represented on the X-axis.

        Args:
            residue_counts: Mapping of residue name -> occurrence count.
            title: Figure title.
            output_file: If given, save PNG there; otherwise show interactively.
        """
        # Standard 20 amino acids
        standard_aa = [
            "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
            "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"
        ]

        # Separate standard vs non-standard counts
        standard_counts = {aa: residue_counts.get(aa, 0) for aa in standard_aa}
        non_standard_counts = {k: v for k, v in residue_counts.items() if k not in standard_aa}

        # Combine for plotting: standard first (fixed order), then non-standard (sorted by count)
        plot_labels = standard_aa + sorted(non_standard_counts.keys(), key=lambda k: non_standard_counts[k], reverse=True)
        plot_values = [standard_counts.get(label, non_standard_counts.get(label, 0)) for label in plot_labels]

        # Filter out non-standard with 0 counts (shouldn't happen based on logic but good safety)
        # Keep all standard even if 0
        final_labels = []
        final_values = []
        for label, value in zip(plot_labels, plot_values):
            if label in standard_aa or value > 0:
                final_labels.append(label)
                final_values.append(value)

        fig, ax = plt.subplots(figsize=(max(10, len(final_labels) * 0.4), 6))
        bars = ax.bar(final_labels, final_values, edgecolor="black", alpha=0.8, color="#4CAF50")

        ax.set_title(title, fontsize=14)
        ax.set_xlabel("Residue Type", fontsize=12)
        ax.set_ylabel("Frequency (Count)", fontsize=12)
        plt.xticks(rotation=45, ha="right")

        # Add value labels on top of bars
        for bar in bars:
            height = bar.get_height()
            if height > 0:
                ax.text(bar.get_x() + bar.get_width()/2., height,
                        f'{int(height)}',
                        ha='center', va='bottom', fontsize=9)

        plt.tight_layout()

        if output_file:
            plt.savefig(output_file, dpi=150, bbox_inches="tight")
            plt.close()
        else:
            plt.show()

    def plot_binding_histogram(self, residues_list: List[str], title: str = "Ligand Binding Residue Distribution",
                               output_file: Optional[str] = None):
        """Plots a histogram of binding residue types."""
        from collections import Counter
        counts = Counter(residues_list)

        if not counts:
            print("No binding residues to plot.")
            return

        labels, values = zip(*counts.most_common())

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.bar(labels, values, edgecolor="black", alpha=0.8)
        ax.set_title(title, fontsize=13)
        ax.set_xlabel("Residue Type", fontsize=11)
        ax.set_ylabel("Frequency", fontsize=11)
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()

        if output_file:
            plt.savefig(output_file, dpi=150, bbox_inches="tight")
            plt.close()
        else:
            plt.show()

    # ── Pi-stacking detection ─────────────────────────────────────────

    @staticmethod
    def _ring_geometry(coords: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Computes the centroid and unit normal vector for a ring 
        defined by a set of 3D coordinates (N, 3).
        """
        centroid = coords.mean(axis=0)
        centered = coords - centroid
        _, _, Vt = np.linalg.svd(centered)
        # Right singular vector of the smallest singular value = direction
        # of least variance = plane normal for near-planar points.
        normal = Vt[2]
        normal = normal / np.linalg.norm(normal)
        return centroid, normal

    def _get_protein_rings(self, all_atom_df: pl.DataFrame) -> List[Dict]:
        """
        Extracts aromatic ring definitions from protein all-atom data.
        Returns list of dicts with keys: centroid, normal, residue_name,
        residue_number, chain, ring_size, source.
        """
        rings = []
        # Group by residue
        aromatic_res = all_atom_df.filter(
            pl.col("residue_name").is_in(list(self.PROTEIN_AROMATIC_RINGS.keys()))
        )
        if aromatic_res.height == 0:
            return rings

        grouped = aromatic_res.group_by(["chain", "residue_number", "residue_name"])

        for (chain, resnum, resname), group_df in grouped:
            ring_defs = self.PROTEIN_AROMATIC_RINGS.get(resname, [])
            for ring_atoms in ring_defs:
                ring_df = group_df.filter(pl.col("atom_name").is_in(ring_atoms))
                # Tolerate at most one missing ring atom; otherwise the
                # fitted plane would be unreliable.
                if ring_df.height < len(ring_atoms) - 1:
                    continue  # missing atoms, skip
                coords = ring_df.select(["x", "y", "z"]).to_numpy()
                if coords.shape[0] < 3:
                    continue  # need >= 3 points to define a plane
                centroid, normal = self._ring_geometry(coords)
                rings.append({
                    "centroid": centroid,
                    "normal": normal,
                    "residue_name": resname,
                    "residue_number": resnum,
                    "chain": chain,
                    "ring_size": len(ring_atoms),
                    "source": "protein",
                })
        return rings

    def _detect_ligand_rings(self, ligand_atoms: pl.DataFrame) -> List[Dict]:
        """
        Detects aromatic rings in a ligand using distance-based connectivity
        and planarity analysis. No external chemistry toolkit needed.
        """
        import networkx as nx_graph
        from scipy.spatial.distance import cdist as cdist_fn

        rings = []
        if ligand_atoms.height < 5:
            return rings  # too few atoms to contain a 5/6-membered ring

        coords = ligand_atoms.select(["x", "y", "z"]).to_numpy()
        atom_names = ligand_atoms.get_column("atom_name").to_list()
        elements = ligand_atoms.get_column("element").to_list()

        # Build connectivity graph from interatomic distances (covalent bonds ~1.0–1.8 Å)
        dists = cdist_fn(coords, coords)
        G = nx_graph.Graph()
        for i in range(len(coords)):
            G.add_node(i)
        for i in range(len(coords)):
            for j in range(i + 1, len(coords)):
                if 0.8 < dists[i, j] < 1.85:
                    G.add_edge(i, j)

        # Find all cycles of size 5 or 6
        try:
            cycles = nx_graph.cycle_basis(G)
        except Exception:
            return rings

        for cycle in cycles:
            if len(cycle) not in (5, 6):
                continue
            ring_coords = coords[cycle]
            ring_elements = [elements[idx] for idx in cycle]
            # Aromatic rings are made of C, N, O, S
            if not all(e in ("C", "N", "O", "S") for e in ring_elements):
                continue
            # Check planarity: smallest singular value should be near zero
            centroid = ring_coords.mean(axis=0)
            centered = ring_coords - centroid
            _, s, Vt = np.linalg.svd(centered)
            planarity = s[2] / (s[0] + 1e-10)
            if planarity > 0.15:
                continue  # not planar enough
            normal = Vt[2]
            normal = normal / np.linalg.norm(normal)
            rings.append({
                "centroid": centroid,
                "normal": normal,
                "atom_indices": cycle,
                "atom_names": [atom_names[idx] for idx in cycle],
                "elements": ring_elements,
                "ring_size": len(cycle),
                "source": "ligand",
            })
        return rings

    @staticmethod
    def _classify_pi_interaction(centroid1: np.ndarray, normal1: np.ndarray,
                                  centroid2: np.ndarray, normal2: np.ndarray) -> Optional[str]:
        """
        Classifies the pi-stacking geometry between two aromatic rings.
        Returns one of: "sandwich", "parallel_displaced", "t_shaped", or None.

        Criteria (standard computational chemistry definitions):
          - Sandwich:            distance < 4.0 Å, angle < 30°, offset < 1.5 Å
          - Parallel displaced:  distance < 5.5 Å, angle < 30°
          - T-shaped:            distance < 5.5 Å, angle 60–90°
        """
        vec = centroid2 - centroid1
        d = np.linalg.norm(vec)

        # Too far to interact, or too close to be two distinct rings.
        if d > 7.0 or d < 2.0:
            return None

        # abs() because ring normals have arbitrary sign; clip guards arccos.
        cos_angle = abs(np.dot(normal1, normal2))
        cos_angle = np.clip(cos_angle, 0.0, 1.0)
        angle = np.degrees(np.arccos(cos_angle))

        # Perpendicular offset: projection of centroid-centroid vector onto ring normal
        projection = abs(np.dot(vec, normal1))
        offset = np.sqrt(max(d**2 - projection**2, 0.0))

        if angle < 30:  # roughly parallel normals
            if d < 4.0 and offset < 1.5:
                return "sandwich"
            elif d < 5.5:
                return "parallel_displaced"
        elif angle > 60:  # roughly perpendicular normals
            if d < 5.5:
                return "t_shaped"
        return None

    def detect_pi_stacking(self, all_atom_df: pl.DataFrame, ligand_df: pl.DataFrame,
                           ligand_name: str) -> pl.DataFrame:
        """
        Detects pi-stacking interactions between protein aromatic residues
        and aromatic rings in the specified ligand.

        Returns a DataFrame with columns:
          protein_chain, protein_residue, protein_resname, 
          ligand_ring_atoms, interaction_type, distance, angle
        """
        target_ligand = ligand_df.filter(pl.col("residue_name") == ligand_name)
        if target_ligand.height == 0:
            return pl.DataFrame()

        protein_rings = self._get_protein_rings(all_atom_df)
        ligand_rings = self._detect_ligand_rings(target_ligand)

        if not protein_rings or not ligand_rings:
            return pl.DataFrame()

        interactions = []
        for pr in protein_rings:
            for lr in ligand_rings:
                interaction = self._classify_pi_interaction(
                    pr["centroid"], pr["normal"],
                    lr["centroid"], lr["normal"],
                )
                if interaction is not None:
                    # Recompute distance/angle for reporting.
                    vec = lr["centroid"] - pr["centroid"]
                    d = float(np.linalg.norm(vec))
                    cos_a = abs(np.dot(pr["normal"], lr["normal"]))
                    angle = float(np.degrees(np.arccos(np.clip(cos_a, 0.0, 1.0))))
                    interactions.append({
                        "protein_chain": pr["chain"],
                        "protein_residue": pr["residue_number"],
                        "protein_resname": pr["residue_name"],
                        "ligand_ring_atoms": ",".join(lr["atom_names"]),
                        "interaction_type": interaction,
                        "distance": round(d, 2),
                        "angle": round(angle, 1),
                    })

        if not interactions:
            return pl.DataFrame()
        return pl.DataFrame(interactions)

    def plot_pi_stacking(self, interactions_list: List[Dict],
                         title: str = "Pi-Stacking Interactions",
                         output_file: Optional[str] = None):
        """
        Plots a grouped bar chart of pi-stacking interaction types
        broken down by interaction type and protein residue type.
        """
        from collections import Counter

        if not interactions_list:
            print("No pi-stacking interactions to plot.")
            return

        type_counts = Counter()
        residue_type_counts: Dict[str, Counter] = {
            "sandwich": Counter(), "parallel_displaced": Counter(), "t_shaped": Counter()
        }

        for ix in interactions_list:
            itype = ix["interaction_type"]
            resname = ix["protein_resname"]
            type_counts[itype] += 1
            if itype in residue_type_counts:
                residue_type_counts[itype][resname] += 1

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Left: overall counts by interaction type
        type_labels = ["sandwich", "parallel_displaced", "t_shaped"]
        type_colors = {"sandwich": "#2196F3", "parallel_displaced": "#FF9800", "t_shaped": "#4CAF50"}
        counts = [type_counts.get(t, 0) for t in type_labels]
        display_labels = ["Sandwich", "Parallel\nDisplaced", "T-Shaped"]
        bars = axes[0].bar(display_labels, counts, 
                           color=[type_colors[t] for t in type_labels], edgecolor="black", alpha=0.85)
        axes[0].set_ylabel("Count", fontsize=11)
        axes[0].set_title("By Interaction Type", fontsize=12)
        for bar, c in zip(bars, counts):
            if c > 0:
                axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.3, 
                           str(c), ha="center", fontsize=10)

        # Right: breakdown by residue type
        all_resnames = sorted(set(r for c in residue_type_counts.values() for r in c))
        x = np.arange(len(all_resnames))
        width = 0.25
        for i, (itype, label) in enumerate(zip(type_labels, display_labels)):
            vals = [residue_type_counts[itype].get(r, 0) for r in all_resnames]
            axes[1].bar(x + i * width, vals, width, label=label.replace("\n", " "),
                       color=type_colors[itype], edgecolor="black", alpha=0.85)
        axes[1].set_xticks(x + width)
        axes[1].set_xticklabels(all_resnames)
        axes[1].set_ylabel("Count", fontsize=11)
        axes[1].set_title("By Residue Type", fontsize=12)
        axes[1].legend(fontsize=9)

        fig.suptitle(title, fontsize=14, fontweight="bold")
        plt.tight_layout()

        if output_file:
            plt.savefig(output_file, dpi=150, bbox_inches="tight")
            plt.close()
        else:
            plt.show()

    # ── Ligand atom-level contacts (hydrogen bonding etc.) ────────────

    def find_ligand_atom_contacts(self, all_atom_df: pl.DataFrame, ligand_df: pl.DataFrame,
                                   ligand_name: str, distance_cutoff: float = 3.3) -> pl.DataFrame:
        """
        Identifies atom-level contacts between protein atoms and individual
        ligand atoms. Default cutoff of 3.3 Å targets hydrogen-bond-like
        interactions (user can adjust).

        Returns a DataFrame with columns:
          ligand_atom, ligand_element, protein_chain, protein_residue,
          protein_resname, protein_atom, protein_element, distance
        """
        from scipy.spatial.distance import cdist as cdist_fn

        target_ligand = ligand_df.filter(pl.col("residue_name") == ligand_name)
        if target_ligand.height == 0 or all_atom_df.height == 0:
            return pl.DataFrame()

        prot_coords = all_atom_df.select(["x", "y", "z"]).to_numpy()
        lig_coords = target_ligand.select(["x", "y", "z"]).to_numpy()

        dists = cdist_fn(prot_coords, lig_coords)  # (P, L)

        contacts = []
        # For each ligand atom, find protein atoms within cutoff
        for lig_idx in range(lig_coords.shape[0]):
            close_mask = dists[:, lig_idx] < distance_cutoff
            close_indices = np.where(close_mask)[0]

            lig_row = target_ligand.row(lig_idx, named=True)

            for prot_idx in close_indices:
                prot_row = all_atom_df.row(int(prot_idx), named=True)
                contacts.append({
                    "ligand_atom": lig_row["atom_name"],
                    "ligand_element": lig_row["element"],
                    "protein_chain": prot_row["chain"],
                    "protein_residue": prot_row["residue_number"],
                    "protein_resname": prot_row["residue_name"],
                    "protein_atom": prot_row["atom_name"],
                    "protein_element": prot_row["element"],
                    "distance": round(float(dists[prot_idx, lig_idx]), 2),
                })

        if not contacts:
            return pl.DataFrame()
        return pl.DataFrame(contacts)

    def plot_ligand_contacts(self, contacts_df: pl.DataFrame,
                             title: str = "Ligand Atom Contacts",
                             output_file: Optional[str] = None):
        """
        Plots a bar chart showing which ligand atoms form the most contacts,
        colored by the protein residue type they interact with.

        Uses canonical atom labels (e.g. C1, O2) if available, otherwise
        falls back to PDB atom names.
        """
        if contacts_df.height == 0:
            print("No contacts to plot.")
            return

        # Use canonical_atom column if available (consistent across predictors)
        label_col = "canonical_atom" if "canonical_atom" in contacts_df.columns else "ligand_atom"

        # Aggregate: for each ligand atom, count contacts by protein residue type
        agg = (
            contacts_df
            .group_by([label_col, "protein_resname"])
            .agg(pl.len().alias("count"))
            .sort("count", descending=True)
        )

        # Get all unique ligand atoms, sorted by total contacts
        atom_totals = (
            agg.group_by(label_col)
            .agg(pl.col("count").sum().alias("total"))
            .sort("total", descending=True)
        )
        lig_atoms = atom_totals.get_column(label_col).to_list()
        all_resnames = sorted(agg.get_column("protein_resname").unique().to_list())

        fig, ax = plt.subplots(figsize=(max(10, len(lig_atoms) * 0.8), 6))

        # NOTE: plt.cm.get_cmap (matplotlib.cm.get_cmap) was deprecated in
        # matplotlib 3.7 and removed in 3.9 — use the pyplot accessor instead.
        cmap = plt.get_cmap("tab20")
        x = np.arange(len(lig_atoms))
        bottom = np.zeros(len(lig_atoms))

        # Build a (ligand atom, residue name) -> count lookup ONCE rather
        # than filtering the aggregate DataFrame inside the nested loop.
        # group_by guarantees one row per pair, so a plain dict is exact.
        pair_counts = {
            (row[label_col], row["protein_resname"]): row["count"]
            for row in agg.iter_rows(named=True)
        }

        for i, resname in enumerate(all_resnames):
            vals = np.array(
                [pair_counts.get((atom, resname), 0) for atom in lig_atoms],
                dtype=float,
            )
            # Cycle through the discrete tab20 palette per residue type.
            color = cmap(i % cmap.N)
            ax.bar(x, vals, bottom=bottom, label=resname, color=color, edgecolor="black", alpha=0.85)
            bottom += vals

        ax.set_xticks(x)
        ax.set_xticklabels(lig_atoms, rotation=45, ha="right", fontsize=9)
        ax.set_ylabel("Number of Contacts", fontsize=11)
        ax.set_xlabel("Ligand Atom (canonical)", fontsize=11)
        ax.set_title(title, fontsize=13)
        ax.legend(fontsize=7, ncol=3, title="Protein Residue", title_fontsize=8,
                  bbox_to_anchor=(1.02, 1), loc="upper left")
        plt.tight_layout()

        if output_file:
            plt.savefig(output_file, dpi=150, bbox_inches="tight")
            plt.close()
        else:
            plt.show()

    # ── 2D Ligand Depiction (RDKit) ──────────────────────────────────

    @staticmethod
    def _rdkit_available() -> bool:
        # Soft-dependency probe: RDKit is optional at install time.
        try:
            from rdkit import Chem  # noqa: F401
            return True
        except ImportError:
            return False

    def build_ligand_mol(self, ligand_atoms: pl.DataFrame, charge: Optional[int] = None,
                         infer_bond_orders: bool = True):
        """
        Builds an RDKit molecule from ligand atom 3D coordinates.

        Uses canonical SMILES ordering to assign consistent atom labels
        that are independent of the input file's atom naming convention.
        This ensures the same atom always gets the same label regardless
        of which structure predictor generated the file.

        Args:
            ligand_atoms: DataFrame with x, y, z, element, atom_name columns.
            charge: Total formal charge of the ligand. Helps RDKit infer
                    correct bond orders / protonation state.  For example,
                    citrate at pH 7 is typically -3. If None, RDKit guesses.
            infer_bond_orders: If True, attempt to determine double/aromatic
                    bonds from 3D geometry.  If False, only connectivity
                    (single bonds) is determined — safer when the protonation
                    state is unknown.

        Returns:
            (mol, pdb_atom_names, canonical_labels, canonical_smiles)
            or (None, None, None, None) if RDKit is not installed.

            canonical_labels: list of str like ["C1", "C2", "O1", "O2", ...]
                numbered per-element in canonical SMILES traversal order.
        """
        if not self._rdkit_available():
            return None, None, None, None

        from rdkit import Chem
        from rdkit.Chem import AllChem
        from rdkit.Geometry import Point3D

        coords = ligand_atoms.select(["x", "y", "z"]).to_numpy()
        elements = ligand_atoms.get_column("element").to_list()
        atom_names = ligand_atoms.get_column("atom_name").to_list()

        if len(coords) == 0:
            return None, None, None, None

        # Build editable molecule with 3D conformer
        mol = Chem.RWMol()
        conf = Chem.Conformer(len(coords))

        for i, (elem, coord) in enumerate(zip(elements, coords)):
            atom = Chem.Atom(elem)
            idx = mol.AddAtom(atom)
            conf.SetAtomPosition(idx, Point3D(float(coord[0]), float(coord[1]), float(coord[2])))

        mol.AddConformer(conf, assignId=True)

        # Step 1: Determine connectivity (which atoms are bonded)
        # Step 2 (optional): Determine bond orders (single / double / aromatic)
        connectivity_ok = False
        try:
            from rdkit.Chem import rdDetermineBonds
            rdDetermineBonds.DetermineConnectivity(mol)
            connectivity_ok = True

            if infer_bond_orders:
                try:
                    if charge is not None:
                        rdDetermineBonds.DetermineBondOrders(mol, charge=charge)
                    else:
                        rdDetermineBonds.DetermineBondOrders(mol)
                except Exception:
                    # Bond-order perception is best-effort: fall back to
                    # connectivity-only (all single bonds).
                    pass
        except Exception:
            # rdDetermineBonds missing (old RDKit) or perception failed;
            # fall through to the distance-based bonding below.
            pass

        if not connectivity_ok:
            # Distance-based fallback: bond any atom pair at covalent range.
            from scipy.spatial.distance import cdist as cdist_fn
            dists = cdist_fn(coords, coords)
            for i in range(len(coords)):
                for j in range(i + 1, len(coords)):
                    if 0.8 < dists[i, j] < 1.85:
                        mol.AddBond(i, j, Chem.BondType.SINGLE)

        try:
            Chem.SanitizeMol(mol)
        except Exception:
            # Sanitization can fail on exotic protonation states; keep the
            # unsanitized mol rather than losing the depiction entirely.
            pass

        mol_final = mol.GetMol()

        # ── Canonical atom labelling ─────────────────────────────────
        # CanonicalRankAtoms gives each atom a unique rank based on the
        # canonical SMILES traversal — identical molecular graphs always
        # produce the same ranking, regardless of input atom order/names.
        canonical_smiles = None
        canonical_labels = list(atom_names)  # fallback to PDB names
        try:
            canonical_smiles = Chem.MolToSmiles(mol_final)
            ranks = Chem.CanonicalRankAtoms(mol_final)
            # Build per-element numbering in canonical order:
            #   rank 0 gets assigned first, then rank 1, etc.
            #   e.g. C1, C2, C3, O1, O2, N1, ...
            n = mol_final.GetNumAtoms()
            # Sort atom indices by their canonical rank
            sorted_indices = sorted(range(n), key=lambda i: ranks[i])
            element_counters: Dict[str, int] = {}
            label_by_idx: Dict[int, str] = {}
            for idx in sorted_indices:
                elem = mol_final.GetAtomWithIdx(idx).GetSymbol()
                element_counters[elem] = element_counters.get(elem, 0) + 1
                label_by_idx[idx] = f"{elem}{element_counters[elem]}"
            canonical_labels = [label_by_idx.get(i, atom_names[i]) for i in range(n)]
        except Exception:
            pass

        return mol_final, atom_names, canonical_labels, canonical_smiles

    def plot_ligand_2d(self, ligand_atoms: pl.DataFrame,
                       contacts_df: Optional[pl.DataFrame] = None,
                       title: str = "Ligand 2D Structure",
                       output_file: Optional[str] = None,
                       size: Tuple[int, int] = (700, 500),
                       charge: Optional[int] = None,
                       infer_bond_orders: bool = True,
                       prebuilt_mol=None,
                       prebuilt_canonical_labels: Optional[List[str]] = None):
        """
        Generates a 2D depiction of the ligand using RDKit.  Atoms are
        labelled with **canonical SMILES-derived names** (e.g. C1, O2, N1)
        so they match the contacts bar chart and are consistent across
        different structure predictors.

        If contacts_df is provided, atoms are color-coded by the number of
        protein contacts (red = many, blue = few, gray = none).

        Args:
            ligand_atoms: DataFrame of ligand atoms (from one structure).
            contacts_df: Optional contacts DataFrame to color-code atoms.
                         Must contain a "canonical_atom" column.
            title: Plot title.
            output_file: Save image to file. If None, displays inline.
            size: Image dimensions (width, height) in pixels.
            charge: Total formal charge of the ligand (e.g. -3 for citrate).
            infer_bond_orders: If True, attempt to determine double/aromatic
                    bonds. If False, show only connectivity (all single bonds).
            prebuilt_mol: Optional pre-built RDKit Mol from build_ligand_mol().
                    Avoids a redundant rebuild and guarantees the same
                    canonical labels used by the bar chart.
            prebuilt_canonical_labels: Optional canonical labels matching
                    prebuilt_mol atom order.  Must be provided together
                    with prebuilt_mol for consistency.

        Returns:
            PNG data as bytes (or None if RDKit unavailable).
        """
        if not self._rdkit_available():
            print("RDKit is not installed. Install with: pip install rdkit")
            print("  2D ligand depiction requires RDKit.")
            return None

        from rdkit import Chem
        from rdkit.Chem import AllChem, Draw
        from rdkit.Chem.Draw import rdMolDraw2D

        # Reuse pre-built molecule + labels when available (single source
        # of truth shared with the bar chart) to guarantee consistency.
        if prebuilt_mol is not None and prebuilt_canonical_labels is not None:
            mol = prebuilt_mol
            canonical_labels = prebuilt_canonical_labels
        else:
            mol, _pdb_names, canonical_labels, canonical_smiles = self.build_ligand_mol(
                ligand_atoms, charge=charge, infer_bond_orders=infer_bond_orders
            )
            if mol is None:
                print("Could not build ligand molecule.")
                return None
            if canonical_smiles:
                print(f"  Canonical SMILES: {canonical_smiles}")

        # Compute 2D coordinates for clean layout
        mol_2d = Chem.RWMol(mol)
        AllChem.Compute2DCoords(mol_2d)
        mol_2d = mol_2d.GetMol()

        # Label each atom with its canonical label (e.g. "C1", "O2")
        for idx, label in enumerate(canonical_labels):
            if idx < mol_2d.GetNumAtoms():
                mol_2d.GetAtomWithIdx(idx).SetProp("atomNote", label)

        # Build highlight colors if contact data is available
        highlight_atoms = {}  # idx -> color tuple (r, g, b, a)
        highlight_radii = {}

        if contacts_df is not None and contacts_df.height > 0:
            # Use canonical_atom column if present, otherwise fall back to ligand_atom
            label_col = "canonical_atom" if "canonical_atom" in contacts_df.columns else "ligand_atom"
            contact_counts = (
                contacts_df.group_by(label_col)
                .agg(pl.len().alias("count"))
            )
            count_map = dict(zip(
                contact_counts.get_column(label_col).to_list(),
                contact_counts.get_column("count").to_list(),
            ))

            max_count = max(count_map.values()) if count_map else 1

            for idx, label in enumerate(canonical_labels):
                if idx >= mol_2d.GetNumAtoms():
                    break
                c = count_map.get(label, 0)
                if c > 0:
                    # Gradient: blue (few) → red (many)
                    frac = c / max_count
                    r = frac
                    b = 1.0 - frac
                    g = 0.2
                    highlight_atoms[idx] = (r, g, b, 0.4)
                    highlight_radii[idx] = 0.35 + 0.15 * frac

        # Draw
        drawer = rdMolDraw2D.MolDraw2DCairo(size[0], size[1])
        draw_opts = drawer.drawOptions()
        draw_opts.annotationFontScale = 0.3
        draw_opts.bondLineWidth = 2.0
        # Increase font size if the attribute exists (varies by RDKit version)
        for attr in ("baseFontSize", "minFontSize"):
            if hasattr(draw_opts, attr):
                try:
                    setattr(draw_opts, attr, 8)
                except Exception:
                    pass

        if highlight_atoms:
            atom_indices = list(highlight_atoms.keys())
            atom_colors = highlight_atoms
            atom_radii = highlight_radii
            drawer.DrawMolecule(
                mol_2d,
                highlightAtoms=atom_indices,
                highlightAtomColors=atom_colors,
                highlightAtomRadii=atom_radii,
                highlightBonds=[],
            )
        else:
            drawer.DrawMolecule(mol_2d)

        drawer.FinishDrawing()
        png_data = drawer.GetDrawingText()

        if output_file:
            with open(output_file, "wb") as f:
                f.write(png_data)
            print(f"  2D ligand structure saved to {output_file}")
        else:
            # Display inline (works in Jupyter notebooks)
            try:
                from IPython.display import display, Image as IPImage
                display(IPImage(data=png_data))
            except ImportError:
                # Not in a notebook — save to temp and inform user
                import tempfile, os
                tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
                tmp.write(png_data)
                tmp.close()
                print(f"  2D ligand structure saved to {tmp.name}")

        return png_data

find_binding_residues(backbone_df, ligand_df, ligand_name, distance_cutoff=5.0)

Identifies residues within a cutoff distance of a specific ligand. Uses CA atoms from backbone for fast residue-level proximity.

Source code in src/sicifus/analysis.py
def find_binding_residues(self, backbone_df: pl.DataFrame, ligand_df: pl.DataFrame, 
                          ligand_name: str, distance_cutoff: float = 5.0) -> pl.DataFrame:
    """
    Returns the backbone residues lying within ``distance_cutoff`` of the
    named ligand.  Proximity is evaluated at the residue level using the
    CA atoms present in ``backbone_df``.
    """
    ligand_atoms = ligand_df.filter(pl.col("residue_name") == ligand_name)
    if ligand_atoms.height == 0:
        return pl.DataFrame()

    from scipy.spatial.distance import cdist

    xyz_backbone = backbone_df.select(["x", "y", "z"]).to_numpy()
    xyz_ligand = ligand_atoms.select(["x", "y", "z"]).to_numpy()

    # An atom counts as "binding" when it lies closer than the cutoff
    # to ANY atom of the target ligand.
    within_cutoff = (cdist(xyz_backbone, xyz_ligand) < distance_cutoff).any(axis=1)
    hits = backbone_df.filter(within_cutoff)
    return hits.unique(subset=["chain", "residue_number", "residue_name"])

plot_binding_histogram(residues_list, title='Ligand Binding Residue Distribution', output_file=None)

Plots a histogram of binding residue types.

Source code in src/sicifus/analysis.py
def plot_binding_histogram(self, residues_list: List[str], title: str = "Ligand Binding Residue Distribution",
                           output_file: Optional[str] = None):
    """Draws a bar chart of how often each residue type appears in ``residues_list``."""
    from collections import Counter

    tallies = Counter(residues_list)
    if not tallies:
        print("No binding residues to plot.")
        return

    # Most frequent residue types first.
    names, freqs = zip(*tallies.most_common())

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(names, freqs, edgecolor="black", alpha=0.8)
    ax.set_title(title, fontsize=13)
    ax.set_xlabel("Residue Type", fontsize=11)
    ax.set_ylabel("Frequency", fontsize=11)
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()

    # Save to disk when a path is given; otherwise show interactively.
    if output_file:
        plt.savefig(output_file, dpi=150, bbox_inches="tight")
        plt.close()
    else:
        plt.show()

detect_pi_stacking(all_atom_df, ligand_df, ligand_name)

Detects pi-stacking interactions between protein aromatic residues and aromatic rings in the specified ligand.

Returns a DataFrame with columns

protein_chain, protein_residue, protein_resname, ligand_ring_atoms, interaction_type, distance, angle

Source code in src/sicifus/analysis.py
def detect_pi_stacking(self, all_atom_df: pl.DataFrame, ligand_df: pl.DataFrame,
                       ligand_name: str) -> pl.DataFrame:
    """
    Finds pi-stacking contacts between aromatic protein residues and the
    aromatic rings of the ligand named ``ligand_name``.

    The resulting DataFrame carries:
      protein_chain, protein_residue, protein_resname,
      ligand_ring_atoms, interaction_type, distance, angle
    """
    lig = ligand_df.filter(pl.col("residue_name") == ligand_name)
    if lig.height == 0:
        return pl.DataFrame()

    rings_protein = self._get_protein_rings(all_atom_df)
    rings_ligand = self._detect_ligand_rings(lig)
    if not rings_protein or not rings_ligand:
        return pl.DataFrame()

    rows = []
    # Test every (protein ring, ligand ring) pairing.
    for p_ring in rings_protein:
        for l_ring in rings_ligand:
            kind = self._classify_pi_interaction(
                p_ring["centroid"], p_ring["normal"],
                l_ring["centroid"], l_ring["normal"],
            )
            if kind is None:
                continue
            # Centroid-to-centroid separation and inter-plane angle.
            separation = float(np.linalg.norm(l_ring["centroid"] - p_ring["centroid"]))
            cos_angle = np.clip(abs(np.dot(p_ring["normal"], l_ring["normal"])), 0.0, 1.0)
            tilt = float(np.degrees(np.arccos(cos_angle)))
            rows.append({
                "protein_chain": p_ring["chain"],
                "protein_residue": p_ring["residue_number"],
                "protein_resname": p_ring["residue_name"],
                "ligand_ring_atoms": ",".join(l_ring["atom_names"]),
                "interaction_type": kind,
                "distance": round(separation, 2),
                "angle": round(tilt, 1),
            })

    return pl.DataFrame(rows) if rows else pl.DataFrame()

plot_pi_stacking(interactions_list, title='Pi-Stacking Interactions', output_file=None)

Plots a grouped bar chart of pi-stacking interaction types broken down by interaction type and protein residue type.

Source code in src/sicifus/analysis.py
def plot_pi_stacking(self, interactions_list: List[Dict],
                     title: str = "Pi-Stacking Interactions",
                     output_file: Optional[str] = None):
    """
    Plots a grouped bar chart of pi-stacking interaction types
    broken down by interaction type and protein residue type.

    Args:
        interactions_list: Interaction dicts, each with at least
            "interaction_type" and "protein_resname" keys (as produced
            by detect_pi_stacking).
        title: Figure suptitle.
        output_file: If given, save the figure to this path; otherwise
            the figure is shown interactively.
    """
    from collections import Counter

    # Nothing to draw -> report and bail out early.
    if not interactions_list:
        print("No pi-stacking interactions to plot.")
        return

    # Overall tally per interaction type, plus a per-type breakdown of
    # which protein residue types participate.
    type_counts = Counter()
    residue_type_counts: Dict[str, Counter] = {
        "sandwich": Counter(), "parallel_displaced": Counter(), "t_shaped": Counter()
    }

    for ix in interactions_list:
        itype = ix["interaction_type"]
        resname = ix["protein_resname"]
        type_counts[itype] += 1
        # Only the three known stacking geometries get a residue breakdown.
        if itype in residue_type_counts:
            residue_type_counts[itype][resname] += 1

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: overall counts by interaction type
    type_labels = ["sandwich", "parallel_displaced", "t_shaped"]
    type_colors = {"sandwich": "#2196F3", "parallel_displaced": "#FF9800", "t_shaped": "#4CAF50"}
    counts = [type_counts.get(t, 0) for t in type_labels]
    display_labels = ["Sandwich", "Parallel\nDisplaced", "T-Shaped"]
    bars = axes[0].bar(display_labels, counts, 
                       color=[type_colors[t] for t in type_labels], edgecolor="black", alpha=0.85)
    axes[0].set_ylabel("Count", fontsize=11)
    axes[0].set_title("By Interaction Type", fontsize=12)
    # Annotate non-zero bars with their count just above the bar top.
    for bar, c in zip(bars, counts):
        if c > 0:
            axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.3, 
                       str(c), ha="center", fontsize=10)

    # Right: breakdown by residue type
    all_resnames = sorted(set(r for c in residue_type_counts.values() for r in c))
    x = np.arange(len(all_resnames))
    width = 0.25  # three side-by-side bars per residue group
    for i, (itype, label) in enumerate(zip(type_labels, display_labels)):
        vals = [residue_type_counts[itype].get(r, 0) for r in all_resnames]
        axes[1].bar(x + i * width, vals, width, label=label.replace("\n", " "),
                   color=type_colors[itype], edgecolor="black", alpha=0.85)
    # Center each x-tick under the middle bar of its 3-bar group.
    axes[1].set_xticks(x + width)
    axes[1].set_xticklabels(all_resnames)
    axes[1].set_ylabel("Count", fontsize=11)
    axes[1].set_title("By Residue Type", fontsize=12)
    axes[1].legend(fontsize=9)

    fig.suptitle(title, fontsize=14, fontweight="bold")
    plt.tight_layout()

    # Save to disk when a path is given; otherwise display interactively.
    if output_file:
        plt.savefig(output_file, dpi=150, bbox_inches="tight")
        plt.close()
    else:
        plt.show()

find_ligand_atom_contacts(all_atom_df, ligand_df, ligand_name, distance_cutoff=3.3)

Identifies atom-level contacts between protein atoms and individual ligand atoms. Default cutoff of 3.3 Å targets hydrogen-bond-like interactions (user can adjust).

Returns a DataFrame with columns

ligand_atom, ligand_element, protein_chain, protein_residue, protein_resname, protein_atom, protein_element, distance

Source code in src/sicifus/analysis.py
def find_ligand_atom_contacts(self, all_atom_df: pl.DataFrame, ligand_df: pl.DataFrame,
                               ligand_name: str, distance_cutoff: float = 3.3) -> pl.DataFrame:
    """
    Identifies atom-level contacts between protein atoms and individual
    ligand atoms. Default cutoff of 3.3 Å targets hydrogen-bond-like
    interactions (user can adjust).

    Args:
        all_atom_df: Protein atom table with x/y/z, chain, residue_number,
            residue_name, atom_name and element columns.
        ligand_df: Ligand atom table with the same columns.
        ligand_name: Residue name identifying the target ligand.
        distance_cutoff: Maximum protein-ligand atom distance (assumed Å).

    Returns a DataFrame with columns:
      ligand_atom, ligand_element, protein_chain, protein_residue,
      protein_resname, protein_atom, protein_element, distance
    """
    from scipy.spatial.distance import cdist as cdist_fn

    target_ligand = ligand_df.filter(pl.col("residue_name") == ligand_name)
    # No ligand atoms (or no protein atoms) -> nothing to measure.
    if target_ligand.height == 0 or all_atom_df.height == 0:
        return pl.DataFrame()

    prot_coords = all_atom_df.select(["x", "y", "z"]).to_numpy()
    lig_coords = target_ligand.select(["x", "y", "z"]).to_numpy()

    # Full pairwise distance matrix: rows = protein atoms, cols = ligand atoms.
    dists = cdist_fn(prot_coords, lig_coords)  # (P, L)

    contacts = []
    # For each ligand atom, find protein atoms within cutoff
    for lig_idx in range(lig_coords.shape[0]):
        close_mask = dists[:, lig_idx] < distance_cutoff
        close_indices = np.where(close_mask)[0]

        lig_row = target_ligand.row(lig_idx, named=True)

        # One record per (protein atom, ligand atom) pair inside the cutoff.
        for prot_idx in close_indices:
            prot_row = all_atom_df.row(int(prot_idx), named=True)
            contacts.append({
                "ligand_atom": lig_row["atom_name"],
                "ligand_element": lig_row["element"],
                "protein_chain": prot_row["chain"],
                "protein_residue": prot_row["residue_number"],
                "protein_resname": prot_row["residue_name"],
                "protein_atom": prot_row["atom_name"],
                "protein_element": prot_row["element"],
                "distance": round(float(dists[prot_idx, lig_idx]), 2),
            })

    if not contacts:
        return pl.DataFrame()
    return pl.DataFrame(contacts)

plot_ligand_contacts(contacts_df, title='Ligand Atom Contacts', output_file=None)

Plots a bar chart showing which ligand atoms form the most contacts, colored by the protein residue type they interact with.

Uses canonical atom labels (e.g. C1, O2) if available, otherwise falls back to PDB atom names.

Source code in src/sicifus/analysis.py
def plot_ligand_contacts(self, contacts_df: pl.DataFrame,
                         title: str = "Ligand Atom Contacts",
                         output_file: Optional[str] = None):
    """
    Plots a stacked bar chart showing which ligand atoms form the most
    contacts, colored by the protein residue type they interact with.

    Uses canonical atom labels (e.g. C1, O2) if available, otherwise
    falls back to PDB atom names.

    Args:
        contacts_df: Contacts table as produced by find_ligand_atom_contacts
            (optionally augmented with a "canonical_atom" column).
        title: Plot title.
        output_file: If given, save the figure there instead of showing it.
    """
    if contacts_df.height == 0:
        print("No contacts to plot.")
        return

    # Use canonical_atom column if available (consistent across predictors)
    label_col = "canonical_atom" if "canonical_atom" in contacts_df.columns else "ligand_atom"

    # Aggregate: for each ligand atom, count contacts by protein residue type
    agg = (
        contacts_df
        .group_by([label_col, "protein_resname"])
        .agg(pl.len().alias("count"))
        .sort("count", descending=True)
    )

    # All unique ligand atoms, sorted by total contact count (descending)
    atom_totals = (
        agg.group_by(label_col)
        .agg(pl.col("count").sum().alias("total"))
        .sort("total", descending=True)
    )
    lig_atoms = atom_totals.get_column(label_col).to_list()
    all_resnames = sorted(agg.get_column("protein_resname").unique().to_list())

    # Flatten the aggregate into a dict once instead of re-filtering the
    # DataFrame inside the plotting loop (was O(atoms x resnames) filters).
    count_lookup: Dict[Tuple[str, str], int] = {
        (row[label_col], row["protein_resname"]): row["count"]
        for row in agg.iter_rows(named=True)
    }

    fig, ax = plt.subplots(figsize=(max(10, len(lig_atoms) * 0.8), 6))

    # plt.cm.get_cmap was removed in matplotlib 3.9; use the colormap
    # registry and fall back to plt.get_cmap on older versions.
    try:
        from matplotlib import colormaps
        cmap = colormaps["tab20"].resampled(len(all_resnames))
    except ImportError:
        cmap = plt.get_cmap("tab20", len(all_resnames))

    x = np.arange(len(lig_atoms))
    bottom = np.zeros(len(lig_atoms))

    # Stack one layer per protein residue type.
    for i, resname in enumerate(all_resnames):
        vals = np.array(
            [count_lookup.get((atom, resname), 0) for atom in lig_atoms],
            dtype=float,
        )
        color = cmap(i / max(len(all_resnames) - 1, 1))
        ax.bar(x, vals, bottom=bottom, label=resname, color=color, edgecolor="black", alpha=0.85)
        bottom += vals

    ax.set_xticks(x)
    ax.set_xticklabels(lig_atoms, rotation=45, ha="right", fontsize=9)
    ax.set_ylabel("Number of Contacts", fontsize=11)
    ax.set_xlabel("Ligand Atom (canonical)", fontsize=11)
    ax.set_title(title, fontsize=13)
    ax.legend(fontsize=7, ncol=3, title="Protein Residue", title_fontsize=8,
              bbox_to_anchor=(1.02, 1), loc="upper left")
    plt.tight_layout()

    if output_file:
        plt.savefig(output_file, dpi=150, bbox_inches="tight")
        plt.close()
    else:
        plt.show()

build_ligand_mol(ligand_atoms, charge=None, infer_bond_orders=True)

Builds an RDKit molecule from ligand atom 3D coordinates.

Uses canonical SMILES ordering to assign consistent atom labels that are independent of the input file's atom naming convention. This ensures the same atom always gets the same label regardless of which structure predictor generated the file.

Parameters:

Name Type Description Default
ligand_atoms DataFrame

DataFrame with x, y, z, element, atom_name columns.

required
charge Optional[int]

Total formal charge of the ligand. Helps RDKit infer correct bond orders / protonation state. For example, citrate at pH 7 is typically -3. If None, RDKit guesses.

None
infer_bond_orders bool

If True, attempt to determine double/aromatic bonds from 3D geometry. If False, only connectivity (single bonds) is determined — safer when the protonation state is unknown.

True

Returns:

Name Type Description

(mol, pdb_atom_names, canonical_labels, canonical_smiles)

or (None, None, None, None) if RDKit is not installed.

canonical_labels

list of str like ["C1", "C2", "O1", "O2", ...] numbered per-element in canonical SMILES traversal order.

Source code in src/sicifus/analysis.py
def build_ligand_mol(self, ligand_atoms: pl.DataFrame, charge: Optional[int] = None,
                     infer_bond_orders: bool = True):
    """
    Builds an RDKit molecule from ligand atom 3D coordinates.

    Uses canonical SMILES ordering to assign consistent atom labels
    that are independent of the input file's atom naming convention.
    This ensures the same atom always gets the same label regardless
    of which structure predictor generated the file.

    Args:
        ligand_atoms: DataFrame with x, y, z, element, atom_name columns.
        charge: Total formal charge of the ligand. Helps RDKit infer
                correct bond orders / protonation state.  For example,
                citrate at pH 7 is typically -3. If None, RDKit guesses.
        infer_bond_orders: If True, attempt to determine double/aromatic
                bonds from 3D geometry.  If False, only connectivity
                (single bonds) is determined — safer when the protonation
                state is unknown.

    Returns:
        (mol, pdb_atom_names, canonical_labels, canonical_smiles)
        or (None, None, None, None) if RDKit is not installed.

        canonical_labels: list of str like ["C1", "C2", "O1", "O2", ...]
            numbered per-element in canonical SMILES traversal order.
    """
    if not self._rdkit_available():
        return None, None, None, None

    from rdkit import Chem
    from rdkit.Geometry import Point3D

    coords = ligand_atoms.select(["x", "y", "z"]).to_numpy()
    elements = ligand_atoms.get_column("element").to_list()
    atom_names = ligand_atoms.get_column("atom_name").to_list()

    if len(coords) == 0:
        return None, None, None, None

    # Build editable molecule with 3D conformer
    mol = Chem.RWMol()
    conf = Chem.Conformer(len(coords))

    for i, (elem, coord) in enumerate(zip(elements, coords)):
        atom = Chem.Atom(elem)
        idx = mol.AddAtom(atom)
        conf.SetAtomPosition(idx, Point3D(float(coord[0]), float(coord[1]), float(coord[2])))

    mol.AddConformer(conf, assignId=True)

    # Step 1: Determine connectivity (which atoms are bonded)
    # Step 2 (optional): Determine bond orders (single / double / aromatic)
    connectivity_ok = False
    try:
        from rdkit.Chem import rdDetermineBonds
        rdDetermineBonds.DetermineConnectivity(mol)
        connectivity_ok = True

        if infer_bond_orders:
            try:
                if charge is not None:
                    rdDetermineBonds.DetermineBondOrders(mol, charge=charge)
                else:
                    rdDetermineBonds.DetermineBondOrders(mol)
            except Exception:
                # Bond-order perception is best-effort; connectivity alone
                # is still usable downstream.
                pass
    except Exception:
        # rdDetermineBonds may be missing (older RDKit) or fail outright;
        # fall through to the distance-based heuristic below.
        # (Exception already covers ImportError, so a tuple is redundant.)
        pass

    if not connectivity_ok:
        # Heuristic fallback: connect any atom pair whose separation falls
        # in the 0.8-1.85 window (assumed Å) with a single bond.
        from scipy.spatial.distance import cdist as cdist_fn
        dists = cdist_fn(coords, coords)
        for i in range(len(coords)):
            for j in range(i + 1, len(coords)):
                if 0.8 < dists[i, j] < 1.85:
                    mol.AddBond(i, j, Chem.BondType.SINGLE)

    try:
        Chem.SanitizeMol(mol)
    except Exception:
        # Sanitization can fail on odd valences; keep the raw molecule.
        pass

    mol_final = mol.GetMol()

    # ── Canonical atom labelling ─────────────────────────────────
    # CanonicalRankAtoms gives each atom a unique rank based on the
    # canonical SMILES traversal — identical molecular graphs always
    # produce the same ranking, regardless of input atom order/names.
    canonical_smiles = None
    canonical_labels = list(atom_names)  # fallback to PDB names
    try:
        canonical_smiles = Chem.MolToSmiles(mol_final)
        ranks = Chem.CanonicalRankAtoms(mol_final)
        # Build per-element numbering in canonical order:
        #   rank 0 gets assigned first, then rank 1, etc.
        #   e.g. C1, C2, C3, O1, O2, N1, ...
        n = mol_final.GetNumAtoms()
        # Sort atom indices by their canonical rank
        sorted_indices = sorted(range(n), key=lambda i: ranks[i])
        element_counters: Dict[str, int] = {}
        label_by_idx: Dict[int, str] = {}
        for idx in sorted_indices:
            elem = mol_final.GetAtomWithIdx(idx).GetSymbol()
            element_counters[elem] = element_counters.get(elem, 0) + 1
            label_by_idx[idx] = f"{elem}{element_counters[elem]}"
        canonical_labels = [label_by_idx.get(i, atom_names[i]) for i in range(n)]
    except Exception:
        # Labelling is cosmetic; fall back to the original PDB names.
        pass

    return mol_final, atom_names, canonical_labels, canonical_smiles

plot_ligand_2d(ligand_atoms, contacts_df=None, title='Ligand 2D Structure', output_file=None, size=(700, 500), charge=None, infer_bond_orders=True, prebuilt_mol=None, prebuilt_canonical_labels=None)

Generates a 2D depiction of the ligand using RDKit. Atoms are labelled with canonical SMILES-derived names (e.g. C1, O2, N1) so they match the contacts bar chart and are consistent across different structure predictors.

If contacts_df is provided, atoms are color-coded by the number of protein contacts (red = many, blue = few, gray = none).

Parameters:

Name Type Description Default
ligand_atoms DataFrame

DataFrame of ligand atoms (from one structure).

required
contacts_df Optional[DataFrame]

Optional contacts DataFrame to color-code atoms. Must contain a "canonical_atom" column.

None
title str

Plot title.

'Ligand 2D Structure'
output_file Optional[str]

Save image to file. If None, displays inline.

None
size Tuple[int, int]

Image dimensions (width, height) in pixels.

(700, 500)
charge Optional[int]

Total formal charge of the ligand (e.g. -3 for citrate).

None
infer_bond_orders bool

If True, attempt to determine double/aromatic bonds. If False, show only connectivity (all single bonds).

True
prebuilt_mol

Optional pre-built RDKit Mol from build_ligand_mol(). Avoids a redundant rebuild and guarantees the same canonical labels used by the bar chart.

None
prebuilt_canonical_labels Optional[List[str]]

Optional canonical labels matching prebuilt_mol atom order. Must be provided together with prebuilt_mol for consistency.

None

Returns:

Type Description

PNG data as bytes (or None if RDKit unavailable).

Source code in src/sicifus/analysis.py
def plot_ligand_2d(self, ligand_atoms: pl.DataFrame,
                   contacts_df: Optional[pl.DataFrame] = None,
                   title: str = "Ligand 2D Structure",
                   output_file: Optional[str] = None,
                   size: Tuple[int, int] = (700, 500),
                   charge: Optional[int] = None,
                   infer_bond_orders: bool = True,
                   prebuilt_mol=None,
                   prebuilt_canonical_labels: Optional[List[str]] = None):
    """
    Generates a 2D depiction of the ligand using RDKit.  Atoms are
    labelled with **canonical SMILES-derived names** (e.g. C1, O2, N1)
    so they match the contacts bar chart and are consistent across
    different structure predictors.

    If contacts_df is provided, atoms are color-coded by the number of
    protein contacts (red = many, blue = few, gray = none).

    Args:
        ligand_atoms: DataFrame of ligand atoms (from one structure).
        contacts_df: Optional contacts DataFrame to color-code atoms.
                     Must contain a "canonical_atom" column.
        title: Plot title.
        output_file: Save image to file. If None, displays inline.
        size: Image dimensions (width, height) in pixels.
        charge: Total formal charge of the ligand (e.g. -3 for citrate).
        infer_bond_orders: If True, attempt to determine double/aromatic
                bonds. If False, show only connectivity (all single bonds).
        prebuilt_mol: Optional pre-built RDKit Mol from build_ligand_mol().
                Avoids a redundant rebuild and guarantees the same
                canonical labels used by the bar chart.
        prebuilt_canonical_labels: Optional canonical labels matching
                prebuilt_mol atom order.  Must be provided together
                with prebuilt_mol for consistency.

    Returns:
        PNG data as bytes (or None if RDKit unavailable).
    """
    # NOTE(review): the `title` parameter is currently not rendered into
    # the image — wire it into the drawing or document it as unused.
    if not self._rdkit_available():
        print("RDKit is not installed. Install with: pip install rdkit")
        print("  2D ligand depiction requires RDKit.")
        return None

    from rdkit import Chem
    from rdkit.Chem import AllChem
    from rdkit.Chem.Draw import rdMolDraw2D

    # Reuse pre-built molecule + labels when available (single source
    # of truth shared with the bar chart) to guarantee consistency.
    if prebuilt_mol is not None and prebuilt_canonical_labels is not None:
        mol = prebuilt_mol
        canonical_labels = prebuilt_canonical_labels
    else:
        mol, _pdb_names, canonical_labels, canonical_smiles = self.build_ligand_mol(
            ligand_atoms, charge=charge, infer_bond_orders=infer_bond_orders
        )
        if mol is None:
            print("Could not build ligand molecule.")
            return None
        if canonical_smiles:
            print(f"  Canonical SMILES: {canonical_smiles}")

    # Compute 2D coordinates for clean layout
    mol_2d = Chem.RWMol(mol)
    AllChem.Compute2DCoords(mol_2d)
    mol_2d = mol_2d.GetMol()

    # Label each atom with its canonical label (e.g. "C1", "O2")
    for idx, label in enumerate(canonical_labels):
        if idx < mol_2d.GetNumAtoms():
            mol_2d.GetAtomWithIdx(idx).SetProp("atomNote", label)

    # Build highlight colors if contact data is available
    highlight_atoms = {}  # idx -> color tuple (r, g, b, a)
    highlight_radii = {}

    if contacts_df is not None and contacts_df.height > 0:
        # Use canonical_atom column if present, otherwise fall back to ligand_atom
        label_col = "canonical_atom" if "canonical_atom" in contacts_df.columns else "ligand_atom"
        contact_counts = (
            contacts_df.group_by(label_col)
            .agg(pl.len().alias("count"))
        )
        count_map = dict(zip(
            contact_counts.get_column(label_col).to_list(),
            contact_counts.get_column("count").to_list(),
        ))

        max_count = max(count_map.values()) if count_map else 1

        for idx, label in enumerate(canonical_labels):
            if idx >= mol_2d.GetNumAtoms():
                break
            c = count_map.get(label, 0)
            if c > 0:
                # Gradient: blue (few) → red (many)
                frac = c / max_count
                r = frac
                b = 1.0 - frac
                g = 0.2
                highlight_atoms[idx] = (r, g, b, 0.4)
                highlight_radii[idx] = 0.35 + 0.15 * frac

    # Draw
    drawer = rdMolDraw2D.MolDraw2DCairo(size[0], size[1])
    draw_opts = drawer.drawOptions()
    draw_opts.annotationFontScale = 0.3
    draw_opts.bondLineWidth = 2.0
    # Increase font size if the attribute exists (varies by RDKit version)
    for attr in ("baseFontSize", "minFontSize"):
        if hasattr(draw_opts, attr):
            try:
                setattr(draw_opts, attr, 8)
            except Exception:
                pass

    if highlight_atoms:
        atom_indices = list(highlight_atoms.keys())
        atom_colors = highlight_atoms
        atom_radii = highlight_radii
        drawer.DrawMolecule(
            mol_2d,
            highlightAtoms=atom_indices,
            highlightAtomColors=atom_colors,
            highlightAtomRadii=atom_radii,
            highlightBonds=[],
        )
    else:
        drawer.DrawMolecule(mol_2d)

    drawer.FinishDrawing()
    png_data = drawer.GetDrawingText()

    if output_file:
        with open(output_file, "wb") as f:
            f.write(png_data)
        print(f"  2D ligand structure saved to {output_file}")
    else:
        # Display inline (works in Jupyter notebooks)
        try:
            from IPython.display import display, Image as IPImage
            display(IPImage(data=png_data))
        except ImportError:
            # Not in a notebook — save to temp and inform user
            import tempfile
            tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
            tmp.write(png_data)
            tmp.close()
            print(f"  2D ligand structure saved to {tmp.name}")

    return png_data

get_pocket_residues(all_atom_df, ligand_df, ligand_name, distance_cutoff=8.0)

Identifies all unique residues within the specified distance cutoff of any atom in the ligand. Uses all-atom coordinates for accuracy.

Returns a list of residue names (e.g. ["ALA", "HIS", ...]) found in the pocket.

Source code in src/sicifus/analysis.py
def get_pocket_residues(self, all_atom_df: pl.DataFrame, ligand_df: pl.DataFrame,
                        ligand_name: str, distance_cutoff: float = 8.0) -> List[str]:
    """
    Collects the residue names of every unique residue having at least
    one atom within ``distance_cutoff`` of any atom of the named ligand.
    All-atom coordinates are used for accuracy.

    Returns a list of residue names (e.g. ["ALA", "HIS", ...]).
    """
    from scipy.spatial.distance import cdist

    lig = ligand_df.filter(pl.col("residue_name") == ligand_name)
    if lig.height == 0 or all_atom_df.height == 0:
        return []

    protein_xyz = all_atom_df.select(["x", "y", "z"]).to_numpy()
    ligand_xyz = lig.select(["x", "y", "z"]).to_numpy()

    # A protein atom belongs to the pocket when it lies within the cutoff
    # of ANY ligand atom.
    near_ligand = (cdist(protein_xyz, ligand_xyz) < distance_cutoff).any(axis=1)
    pocket = all_atom_df.filter(near_ligand)
    if pocket.height == 0:
        return []

    # Deduplicate to one row per (chain, residue number, residue name).
    residues = pocket.unique(subset=["chain", "residue_number", "residue_name"])
    return residues.get_column("residue_name").to_list()

plot_binding_pocket_composition(residue_counts, title='Binding Pocket Composition', output_file=None)

Plots a histogram of residue types found in the binding pocket. Ensures all 20 standard amino acids are represented on the X-axis.

Source code in src/sicifus/analysis.py
def plot_binding_pocket_composition(self, residue_counts: Dict[str, int],
                                    title: str = "Binding Pocket Composition",
                                    output_file: Optional[str] = None):
    """
    Bar chart of residue-type frequencies in the binding pocket.  The 20
    standard amino acids always appear on the X-axis (even at zero count);
    any non-standard residues follow, ordered by decreasing count.
    """
    standard_aa = [
        "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
        "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"
    ]
    standard_set = set(standard_aa)

    # Non-standard residue names, largest counts first (stable order).
    extras = sorted(
        (name for name in residue_counts if name not in standard_set),
        key=lambda name: residue_counts[name],
        reverse=True,
    )

    final_labels = []
    final_values = []
    # Standard residues are always plotted; extras only when count > 0.
    for name in standard_aa:
        final_labels.append(name)
        final_values.append(residue_counts.get(name, 0))
    for name in extras:
        if residue_counts[name] > 0:
            final_labels.append(name)
            final_values.append(residue_counts[name])

    fig, ax = plt.subplots(figsize=(max(10, len(final_labels) * 0.4), 6))
    bars = ax.bar(final_labels, final_values, edgecolor="black", alpha=0.8, color="#4CAF50")

    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Residue Type", fontsize=12)
    ax.set_ylabel("Frequency (Count)", fontsize=12)
    plt.xticks(rotation=45, ha="right")

    # Annotate every non-empty bar with its integer count.
    for bar in bars:
        height = bar.get_height()
        if height > 0:
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f'{int(height)}',
                    ha='center', va='bottom', fontsize=9)

    plt.tight_layout()

    if output_file:
        plt.savefig(output_file, dpi=150, bbox_inches="tight")
        plt.close()
    else:
        plt.show()

CIF Loader

sicifus.io.CIFLoader

Handles ingestion of CIF files into Polars DataFrames.

Source code in src/sicifus/io.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
class CIFLoader:
    """
    Handles ingestion of CIF/PDB structure files into Polars DataFrames.

    Each structure is split into four tables and written as a partitioned
    Parquet dataset:
      - backbone:    CA atoms only (fast alignment/RMSD)
      - heavy_atoms: all protein heavy atoms (contacts, pi-stacking)
      - hydrogens:   protein hydrogens only (energy scoring)
      - ligands:     non-polymer, non-water atoms
    """

    # Target schema shared by all four output tables. It is applied after
    # parsing so that partitions written in different batches stay
    # concatenation-compatible (e.g. residue_number is always Utf8).
    _SCHEMA = {
        "structure_id": pl.Utf8,
        "model": pl.Utf8,
        "chain": pl.Utf8,
        "residue_name": pl.Utf8,
        "residue_number": pl.Utf8,
        "atom_name": pl.Utf8,
        "x": pl.Float64,
        "y": pl.Float64,
        "z": pl.Float64,
        "b_factor": pl.Float64,
        "element": pl.Utf8,
    }

    def __init__(self):
        pass

    def ingest_folder(self, input_folder: str, output_folder: str, batch_size: int = 100,
                      file_extension: str = "cif", protonate: bool = False):
        """
        Ingests all structure files in a folder and saves them as a partitioned Parquet dataset.

        Args:
            input_folder: Path to the folder containing structure files.
            output_folder: Path to the folder where Parquet files will be saved.
            batch_size: Number of structures to process before writing a partition.
            file_extension: Extension of files to ingest (e.g., "cif" or "pdb").
            protonate: If True, uses PDBFixer (OpenMM) to add hydrogens to the structure
                       before parsing. This is slower but ensures consistent protonation
                       for energy calculations.
        """
        input_path = Path(input_folder)
        output_path = Path(output_folder)
        output_path.mkdir(parents=True, exist_ok=True)

        # One subdirectory per output table.
        backbone_dir = output_path / "backbone"
        heavy_atom_dir = output_path / "heavy_atoms"
        hydrogens_dir = output_path / "hydrogens"
        ligands_dir = output_path / "ligands"

        backbone_dir.mkdir(exist_ok=True)
        heavy_atom_dir.mkdir(exist_ok=True)
        hydrogens_dir.mkdir(exist_ok=True)
        ligands_dir.mkdir(exist_ok=True)

        # Handle both .ext and .ext.gz (gemmi reads gzipped files transparently).
        files = list(input_path.glob(f"*.{file_extension}")) + list(input_path.glob(f"*.{file_extension}.gz"))
        print(f"Found {len(files)} {file_extension} files.")

        backbone_buffer = []
        heavy_atom_buffer = []
        hydrogens_buffer = []
        ligand_buffer = []

        batch_counter = 0

        for i, file_path in enumerate(files):
            try:
                # Parse structure (optionally protonating first).
                backbone_df, heavy_atom_df, hydrogens_df, ligand_df = self._parse_structure(file_path, protonate=protonate)

                if backbone_df is not None:
                    backbone_buffer.append(backbone_df)
                if heavy_atom_df is not None:
                    heavy_atom_buffer.append(heavy_atom_df)
                if hydrogens_df is not None:
                    hydrogens_buffer.append(hydrogens_df)
                if ligand_df is not None:
                    ligand_buffer.append(ligand_df)

                # Flush when the backbone buffer is full, or unconditionally on
                # the last file so no trailing data is lost. NOTE(review): the
                # flush trigger counts only backbone frames; a run of
                # backbone-free inputs delays flushing the other buffers.
                if (len(backbone_buffer) >= batch_size) or (i == len(files) - 1):
                    self._write_batch(backbone_buffer, backbone_dir, batch_counter)
                    self._write_batch(heavy_atom_buffer, heavy_atom_dir, batch_counter)
                    self._write_batch(hydrogens_buffer, hydrogens_dir, batch_counter)
                    self._write_batch(ligand_buffer, ligands_dir, batch_counter)

                    backbone_buffer = []
                    heavy_atom_buffer = []
                    hydrogens_buffer = []
                    ligand_buffer = []
                    print(f"Processed {i + 1}/{len(files)} files.")
                    batch_counter += 1

            except Exception as e:
                # Best-effort ingestion: one bad file must not abort the run.
                print(f"Error processing {file_path}: {e}")

    def _df_to_pdb(self, df: pl.DataFrame) -> str:
        """
        Converts a DataFrame of atoms to a PDB formatted string.

        Ensures strict column alignment for OpenBabel compatibility.
        """
        lines = []
        for i, row in enumerate(df.iter_rows(named=True)):
            atom_name = str(row.get('atom_name', 'X')).strip()
            res_name = str(row.get('residue_name', 'UNK')).strip()[:3]
            chain_id = str(row.get('chain', 'A')).strip()[:1]  # PDB allows a single chain char
            res_seq = row.get('residue_number', 1)
            try:
                res_seq = int(res_seq)
            except (TypeError, ValueError):
                # Non-numeric residue numbers (e.g. insertion codes) fall back to 1.
                res_seq = 1

            x, y, z = row['x'], row['y'], row['z']
            elem = str(row.get('element', atom_name[0])).strip().upper()

            # Atom name alignment (PDB standard): 2-letter elements start in
            # column 13, 1-letter elements in column 14.
            if len(atom_name) >= 4:
                aname_fmt = f"{atom_name[:4]}"
            elif len(elem) == 2:
                aname_fmt = f"{atom_name:<4}"
            else:
                aname_fmt = f" {atom_name:<3}"

            line = (f"ATOM  {i+1:>5} {aname_fmt:<4} {res_name:<3} {chain_id:>1}{res_seq:>4}    "
                    f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00           {elem:>2}")
            lines.append(line)
        return "\n".join(lines)

    def _add_hydrogens_to_ligand(self, ligand_df: pl.DataFrame, ligand_name: str) -> pl.DataFrame:
        """
        Adds hydrogens to a single ligand using Meeko (preferred), RDKit, or
        OpenBabel, trying each backend in turn.

        Args:
            ligand_df: Atoms of one ligand (one chain/residue) in the common schema.
            ligand_name: Residue name of the ligand (kept for API compatibility
                         and diagnostics; not required by the current backends).

        Returns:
            A new DataFrame with hydrogens added, or the original ``ligand_df``
            unchanged if every backend fails.
        """
        meeko_success = False
        new_df = None

        # Use the first row of the original DF as a metadata template for all
        # generated atoms (chain, residue, b-factor, ...).
        template = ligand_df.row(0, named=True)

        # --- Helper: build an RDKit mol from DataFrame coordinates ---
        def _build_mol_from_df(df):
            from rdkit import Chem
            from rdkit.Chem import AllChem

            elements = df["element"].to_list()
            coords = df.select(["x", "y", "z"]).to_numpy()
            names = df["atom_name"].to_list()

            mol = Chem.RWMol()
            conf = Chem.Conformer(len(elements))

            for i, (elem, coord) in enumerate(zip(elements, coords)):
                atom = Chem.Atom(elem)
                idx = mol.AddAtom(atom)
                conf.SetAtomPosition(idx, (float(coord[0]), float(coord[1]), float(coord[2])))

            mol.AddConformer(conf)

            # Strategy 1: full bond perception (connectivity + bond orders).
            bonds_ok = False
            try:
                from rdkit.Chem import rdDetermineBonds
                rdDetermineBonds.DetermineConnectivity(mol)
                rdDetermineBonds.DetermineBondOrders(mol)
                bonds_ok = mol.GetNumBonds() > 0
            except Exception:
                pass

            # Strategy 2: connectivity only, on a fresh mol (Strategy 1 may
            # have left the first mol in a partially-modified state).
            if not bonds_ok:
                try:
                    from rdkit.Chem import rdDetermineBonds
                    mol2 = Chem.RWMol()
                    conf2 = Chem.Conformer(len(elements))
                    for i, (elem, coord) in enumerate(zip(elements, coords)):
                        atom = Chem.Atom(elem)
                        idx = mol2.AddAtom(atom)
                        conf2.SetAtomPosition(idx, (float(coord[0]), float(coord[1]), float(coord[2])))
                    mol2.AddConformer(conf2)
                    rdDetermineBonds.DetermineConnectivity(mol2, useHueckel=False)
                    if mol2.GetNumBonds() > 0:
                        mol = mol2
                        bonds_ok = True
                except Exception:
                    pass

            # Strategy 3: distance-based fallback (0.8-1.85 A treated as bonded).
            if not bonds_ok:
                from scipy.spatial.distance import cdist as cdist_fn
                dists = cdist_fn(coords, coords)
                for i in range(len(coords)):
                    for j in range(i + 1, len(coords)):
                        if 0.8 < dists[i, j] < 1.85:
                            mol.AddBond(i, j, Chem.BondType.SINGLE)
                bonds_ok = mol.GetNumBonds() > 0
                if bonds_ok:
                    try:
                        Chem.SanitizeMol(mol)
                    except Exception:
                        pass

            return mol.GetMol(), names, bonds_ok

        # Strip any pre-existing hydrogens; backends re-add them consistently.
        ligand_heavy = ligand_df.filter(pl.col("element") != "H")
        if ligand_heavy.height == 0:
            return ligand_df

        # --- Helper: convert a protonated RDKit mol back to a DataFrame ---
        def _mol_to_df(mol_h, atom_names=None):
            conf_h = mol_h.GetConformer()
            new_rows = []
            heavy_idx = 0
            h_counter = 0
            for i, atom in enumerate(mol_h.GetAtoms()):
                pos = conf_h.GetAtomPosition(i)
                elem = atom.GetSymbol()
                # Preserve original names for heavy atoms; generate H1, H2, ...
                # for new hydrogens.
                if atom.GetAtomicNum() > 1 and atom_names and heavy_idx < len(atom_names):
                    name = atom_names[heavy_idx]
                    heavy_idx += 1
                else:
                    h_counter += 1
                    name = f"H{h_counter}"
                new_rows.append({
                    "structure_id": template["structure_id"],
                    "model": template["model"],
                    "chain": template["chain"],
                    "residue_name": template["residue_name"],
                    "residue_number": template["residue_number"],
                    "atom_name": name,
                    "x": pos.x, "y": pos.y, "z": pos.z,
                    "b_factor": template["b_factor"],
                    "element": elem
                })
            return pl.DataFrame(new_rows)

        # Sentinel: set once by the Meeko attempt and reused by the RDKit
        # fallback so the molecule is only built once.
        mol = None
        atom_names = None

        # 1. Try Meeko first.
        try:
            from meeko import MoleculePreparation
            from rdkit import Chem

            mol, atom_names, bonds_ok = _build_mol_from_df(ligand_heavy)

            if not bonds_ok:
                raise ValueError("No bonds detected, cannot use Meeko")

            preparator = MoleculePreparation(merge_these_atom_types=())
            setups = preparator.prepare(mol)

            if setups:
                mol_h = setups[0].mol
                n_h_new = sum(1 for a in mol_h.GetAtoms() if a.GetAtomicNum() == 1)

                if n_h_new == 0:
                    raise ValueError("Meeko added no hydrogens")

                # Reject degenerate output (all-zero or NaN coordinates).
                if mol_h.GetNumConformers() > 0:
                    positions = mol_h.GetConformer().GetPositions()
                    if not (np.all(positions == 0.0) or np.any(np.isnan(positions))):
                        meeko_success = True
                        new_df = _mol_to_df(mol_h, atom_names)

        except Exception:
            # Meeko unavailable or failed; fall through to RDKit.
            pass

        if meeko_success and new_df is not None:
            return new_df

        # 2. Try RDKit (fallback).
        rdkit_success = False
        new_df = None

        try:
            from rdkit import Chem
            from rdkit.Chem import AllChem

            if mol is None:
                mol, atom_names, bonds_ok = _build_mol_from_df(ligand_heavy)

            mol_h = Chem.AddHs(mol, addCoords=True)
            n_h_added = mol_h.GetNumAtoms() - mol.GetNumAtoms()

            if n_h_added == 0:
                raise ValueError("RDKit added no hydrogens")

            # Refine hydrogen positions while pinning heavy atoms in place.
            try:
                if mol_h.GetNumConformers() > 0:
                    conf = mol_h.GetConformer()
                    coord_map = {}
                    for atom in mol_h.GetAtoms():
                        if atom.GetAtomicNum() > 1:
                            idx = atom.GetIdx()
                            pos = conf.GetAtomPosition(idx)
                            coord_map[idx] = pos
                    AllChem.EmbedMolecule(mol_h, coordMap=coord_map, forceTol=0.01, useRandomCoords=True)
            except Exception:
                pass

            conf_h = mol_h.GetConformer()
            positions = conf_h.GetPositions()

            if not (np.all(positions == 0.0) or np.any(np.isnan(positions))):
                rdkit_success = True
                new_df = _mol_to_df(mol_h, atom_names)

        except Exception:
            # RDKit unavailable or failed; fall through to OpenBabel.
            pass

        if rdkit_success and new_df is not None:
            return new_df

        # 3. Fallback to OpenBabel (external binary, protonation at pH 7.4).
        if shutil.which("obabel"):
            run_id = str(uuid.uuid4())[:8]
            temp_in = f"temp_lig_{run_id}.pdb"
            temp_out = f"temp_lig_{run_id}.xyz"
            try:
                pdb_content = self._df_to_pdb(ligand_df)

                with open(temp_in, "w") as f:
                    f.write(pdb_content)

                cmd = ["obabel", temp_in, "-O", temp_out, "-h", "-p", "7.4"]
                subprocess.run(cmd, capture_output=True, check=True)

                if os.path.exists(temp_out):
                    with open(temp_out, "r") as f:
                        lines = f.readlines()
                    try:
                        # XYZ format sanity check: first line must be an atom count.
                        int(lines[0])
                        new_rows = []
                        for line in lines[2:]:
                            parts = line.split()
                            if len(parts) >= 4:
                                elem = parts[0]
                                x, y, z = float(parts[1]), float(parts[2]), float(parts[3])
                                new_rows.append({
                                    "structure_id": template["structure_id"],
                                    "model": template["model"],
                                    "chain": template["chain"],
                                    "residue_name": template["residue_name"],
                                    "residue_number": template["residue_number"],
                                    "atom_name": elem,
                                    "x": x, "y": y, "z": z,
                                    "b_factor": template["b_factor"],
                                    "element": elem
                                })
                        if len(new_rows) > 0:
                            return pl.DataFrame(new_rows)
                    except ValueError:
                        pass
            except Exception:
                # obabel failed or its output was unparseable; return input as-is.
                pass
            finally:
                # Temp files are always cleaned up, on every exit path.
                for tmp_file in (temp_in, temp_out):
                    try:
                        if os.path.exists(tmp_file):
                            os.remove(tmp_file)
                    except OSError:
                        pass

        # Every backend failed: return the ligand unchanged.
        return ligand_df

    def _parse_structure(self, file_path: Path, protonate: bool = False) -> Tuple[Optional[pl.DataFrame], Optional[pl.DataFrame], Optional[pl.DataFrame], Optional[pl.DataFrame]]:
        """
        Parses a single structure file (CIF or PDB) and extracts:
          - backbone: CA atoms only (fast alignment/RMSD)
          - heavy_atoms: All protein heavy atoms (contacts, pi-stacking)
          - hydrogens: Protein hydrogens only (energy scoring)
          - ligands: non-polymer, non-water atoms

        Args:
            file_path: Structure file to parse (.cif/.pdb, optionally .gz).
            protonate: If True, run PDBFixer protonation on the protein and the
                       hydrogen-adding pipeline on each ligand.

        Returns:
            Tuple of (backbone, heavy_atoms, hydrogens, ligands) DataFrames,
            each ``None`` when that category is empty; all four are ``None``
            on a parse failure.
        """
        temp_pdb = None
        try:
            parse_path = str(file_path)

            # --- PROTONATION STEP ---
            if protonate:
                try:
                    from pdbfixer import PDBFixer
                    from openmm.app import PDBFile
                    from openmm import Platform
                except ImportError:
                    print("Warning: PDBFixer/OpenMM not installed. Skipping protonation.")
                    # Fall back to normal parsing of the original file.
                else:
                    try:
                        # PDBFixer can read PDB and PDBx/mmCIF.
                        fixer = PDBFixer(filename=str(file_path))

                        # Apply standard fixes before adding hydrogens.
                        fixer.findMissingResidues()
                        fixer.findMissingAtoms()
                        fixer.addMissingAtoms()
                        fixer.addMissingHydrogens(7.4)

                        # Write the fixed structure to a temp PDB for gemmi.
                        import tempfile
                        fd, temp_pdb = tempfile.mkstemp(suffix=".pdb")
                        os.close(fd)

                        with open(temp_pdb, 'w') as f:
                            PDBFile.writeFile(fixer.topology, fixer.positions, f)

                        parse_path = temp_pdb

                    except Exception as e:
                        print(f"  Protonation failed for {file_path.name}: {e}")
                        # Fall back to the original, unprotonated file.
                        if temp_pdb and os.path.exists(temp_pdb):
                            os.remove(temp_pdb)
                        temp_pdb = None

            # --- GEMMI PARSING ---
            # Parse either the original file or the protonated temp PDB.
            structure = gemmi.read_structure(parse_path)

            # Only strip hydrogens when protonation was NOT requested; if
            # protonate=True the hydrogens are the whole point.
            if not protonate:
                structure.remove_hydrogens()

            structure_id = file_path.name.split('.')[0]

            backbone_data = []
            heavy_atom_data = []
            hydrogens_data = []
            ligand_data = []

            for model in structure:
                # Defensive: model.name should be a string but normalize anyway.
                model_name = getattr(model, 'name', '1')
                if not isinstance(model_name, str):
                    model_name = str(model_name)

                for chain in model:
                    for residue in chain:
                        # Classify the residue via gemmi's chemical-component table.
                        res_info = gemmi.find_tabulated_residue(residue.name)
                        is_amino_acid = res_info.is_amino_acid()
                        is_water = res_info.is_water()

                        for atom in residue:
                            atom_data = {
                                "structure_id": structure_id,
                                "model": model_name,
                                "chain": chain.name,
                                "residue_name": residue.name,
                                "residue_number": str(residue.seqid),
                                "atom_name": atom.name,
                                "x": atom.pos.x,
                                "y": atom.pos.y,
                                "z": atom.pos.z,
                                "b_factor": atom.b_iso,
                                "element": atom.element.name
                            }

                            if is_amino_acid:
                                if atom.element.name == "H":
                                    hydrogens_data.append(atom_data)
                                else:
                                    heavy_atom_data.append(atom_data)
                                    # CA atoms also go to backbone (fast alignment).
                                    if atom.name == "CA":
                                        backbone_data.append(atom_data)
                            elif not is_water:
                                ligand_data.append(atom_data)

            backbone_df = pl.DataFrame(backbone_data) if backbone_data else None
            heavy_atom_df = pl.DataFrame(heavy_atom_data) if heavy_atom_data else None
            hydrogens_df = pl.DataFrame(hydrogens_data) if hydrogens_data else None
            ligand_df = pl.DataFrame(ligand_data) if ligand_data else None

            # --- LIGAND PROTONATION ---
            if protonate and ligand_df is not None and ligand_df.height > 0:
                try:
                    # Process each ligand separately to add hydrogens.
                    protonated_ligands = []
                    unique_ligands = ligand_df.unique(subset=["chain", "residue_number", "residue_name"])

                    for row in unique_ligands.iter_rows(named=True):
                        # Extract the atoms of this single ligand.
                        sub_ligand = ligand_df.filter(
                            (pl.col("chain") == row["chain"]) &
                            (pl.col("residue_number") == row["residue_number"]) &
                            (pl.col("residue_name") == row["residue_name"])
                        )

                        try:
                            sub_ligand_h = self._add_hydrogens_to_ligand(sub_ligand, row["residue_name"])
                            protonated_ligands.append(sub_ligand_h)
                        except Exception:
                            # Keep the original ligand if hydrogen addition fails.
                            protonated_ligands.append(sub_ligand)

                    if protonated_ligands:
                        ligand_df = pl.concat(protonated_ligands)

                except Exception as e:
                    print(f"  Ligand protonation failed for {file_path.name}: {e}")

            # Apply the common schema so dtypes match across partitions.
            # (Rebinding each name is essential: with_columns returns a new
            # DataFrame; a bare loop variable would discard the casts.)
            def _apply_schema(df):
                if df is None:
                    return None
                casts = [pl.col(c).cast(t) for c, t in self._SCHEMA.items() if c in df.columns]
                return df.with_columns(casts) if casts else df

            return (
                _apply_schema(backbone_df),
                _apply_schema(heavy_atom_df),
                _apply_schema(hydrogens_df),
                _apply_schema(ligand_df),
            )

        except Exception as e:
            print(f"Error parsing {file_path}: {e}")
            return None, None, None, None
        finally:
            # Always clean up the protonated temp PDB, if one was created.
            if temp_pdb and os.path.exists(temp_pdb):
                try:
                    os.remove(temp_pdb)
                except OSError:
                    pass

    def _write_batch(self, df_list: List[pl.DataFrame], output_dir: Path, batch_index: int):
        """
        Concatenates a list of DataFrames and writes them to a single Parquet
        partition file; a no-op when the list is empty.

        Args:
            df_list: DataFrames to concatenate (all must share one schema).
            output_dir: Directory receiving the partition file.
            batch_index: Partition number used in the file name.
        """
        if not df_list:
            return

        batch_df = pl.concat(df_list)
        output_file = output_dir / f"part_{batch_index}.parquet"
        batch_df.write_parquet(output_file)

ingest_folder(input_folder, output_folder, batch_size=100, file_extension='cif', protonate=False)

Ingests all structure files in a folder and saves them as a partitioned Parquet dataset.

Parameters:

Name Type Description Default
input_folder str

Path to the folder containing structure files.

required
output_folder str

Path to the folder where Parquet files will be saved.

required
batch_size int

Number of structures to process before writing a partition.

100
file_extension str

Extension of files to ingest (e.g., "cif" or "pdb").

'cif'
protonate bool

If True, uses PDBFixer (OpenMM) to add hydrogens to the structure before parsing. This is slower but ensures consistent protonation for energy calculations.

False
Source code in src/sicifus/io.py
def ingest_folder(self, input_folder: str, output_folder: str, batch_size: int = 100, 
                  file_extension: str = "cif", protonate: bool = False):
    """
    Ingests all structure files in a folder and saves them as a partitioned Parquet dataset.

    Args:
        input_folder: Path to the folder containing structure files.
        output_folder: Path to the folder where Parquet files will be saved.
        batch_size: Number of structures to process before writing a partition.
        file_extension: Extension of files to ingest (e.g., "cif" or "pdb").
        protonate: If True, uses PDBFixer (OpenMM) to add hydrogens to the structure 
                   before parsing. This is slower but ensures consistent protonation 
                   for energy calculations.
    """
    input_path = Path(input_folder)
    output_path = Path(output_folder)
    output_path.mkdir(parents=True, exist_ok=True)

    # Create subdirectories for backbone, heavy_atoms, hydrogens, and ligands
    backbone_dir = output_path / "backbone"
    heavy_atom_dir = output_path / "heavy_atoms"
    hydrogens_dir = output_path / "hydrogens"
    ligands_dir = output_path / "ligands"

    backbone_dir.mkdir(exist_ok=True)
    heavy_atom_dir.mkdir(exist_ok=True)
    hydrogens_dir.mkdir(exist_ok=True)
    ligands_dir.mkdir(exist_ok=True)

    # Handle both .ext and .ext.gz
    files = list(input_path.glob(f"*.{file_extension}")) + list(input_path.glob(f"*.{file_extension}.gz"))
    print(f"Found {len(files)} {file_extension} files.")

    backbone_buffer = []
    heavy_atom_buffer = []
    hydrogens_buffer = []
    ligand_buffer = []

    batch_counter = 0

    for i, file_path in enumerate(files):
        try:
            # Parse structure (optionally protonating first)
            backbone_df, heavy_atom_df, hydrogens_df, ligand_df = self._parse_structure(file_path, protonate=protonate)

            if backbone_df is not None:
                backbone_buffer.append(backbone_df)
            if heavy_atom_df is not None:
                heavy_atom_buffer.append(heavy_atom_df)
            if hydrogens_df is not None:
                hydrogens_buffer.append(hydrogens_df)
            if ligand_df is not None:
                ligand_buffer.append(ligand_df)

            # Write batch if buffer is full or it's the last file
            if (len(backbone_buffer) >= batch_size) or (i == len(files) - 1):
                self._write_batch(backbone_buffer, backbone_dir, batch_counter)
                self._write_batch(heavy_atom_buffer, heavy_atom_dir, batch_counter)
                self._write_batch(hydrogens_buffer, hydrogens_dir, batch_counter)
                self._write_batch(ligand_buffer, ligands_dir, batch_counter)

                backbone_buffer = []
                heavy_atom_buffer = []
                hydrogens_buffer = []
                ligand_buffer = []
                print(f"Processed {i + 1}/{len(files)} files.")
                batch_counter += 1

        except Exception as e:
            print(f"Error processing {file_path}: {e}")