Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import copy 

2import re 

3import unittest 

4from datetime import datetime 

5from io import BytesIO, StringIO 

6from pathlib import Path 

7from typing import Dict, List, Optional, Union 

8from warnings import warn 

9 

10from Bio import Phylo as Phylo 

11from Bio.Align import MultipleSeqAlignment 

12from Bio.Phylo.BaseTree import Clade 

13from Bio.Phylo.BaseTree import Tree as BioTree 

14from dateutil.parser import parse 

15from pytest_html import extras 

16from treetime import GTR, TreeTime 

17from treetime.utils import DateConversion, datetime_from_numeric, numeric_date 

18 

19from ..utils import ( 

20 PhytestAssertion, 

21 PhytestObject, 

22 PhytestWarning, 

23 assert_or_warn, 

24 default_date_patterns, 

25) 

26 

27 

28class Tree(PhytestObject, BioTree): 

@classmethod
def read(cls, tree_path, tree_format) -> 'Tree':
    """
    Reads exactly one tree and wraps it in an instance of this class.

    Args:
        tree_path: The file path or handle that is forwarded to `Phylo.read`.
        tree_format: The tree format name that is forwarded to `Phylo.read`.

    Returns:
        Tree: A new instance carrying the root, rooted flag, id and name of the parsed tree.
    """
    parsed = Phylo.read(tree_path, tree_format)
    attributes = dict(root=parsed.root, rooted=parsed.rooted, id=parsed.id, name=parsed.name)
    return cls(**attributes)

33 

@classmethod
def parse(cls, tree_path, tree_format) -> 'Tree':
    """
    Parses every tree found in the input and wraps each one in an instance of this class.

    NOTE(review): the return annotation says 'Tree', but the value returned is a generator
    that yields one instance per parsed tree.

    Args:
        tree_path: The file path or handle that is forwarded to `Phylo.parse`.
        tree_format: The tree format name that is forwarded to `Phylo.parse`.

    Returns:
        A generator of instances, each carrying the root, rooted flag, id and name of a parsed tree.
    """
    parsed_trees = Phylo.parse(tree_path, tree_format)

    def convert(parsed):
        # Re-wrap the parsed tree so that the assertion methods of this class are available on it.
        return cls(root=parsed.root, rooted=parsed.rooted, id=parsed.id, name=parsed.name)

    return (convert(parsed) for parsed in parsed_trees)

38 

@classmethod
def read_str(cls, tree_str: str, tree_format: str = "newick") -> 'Tree':
    """
    Builds a tree from a string instead of a file.

    Args:
        tree_str (str): The text representation of the tree.
        tree_format (str): The format of the text. Defaults to "newick".

    Returns:
        Tree: The result of `read` applied to an in-memory text buffer holding `tree_str`.
    """
    return cls.read(StringIO(tree_str), tree_format)

43 

@property
def tips(self):
    """The terminal nodes of this tree, exactly as returned by `get_terminals`."""
    terminals = self.get_terminals()
    return terminals

47 

def parse_tip_dates(
    self,
    *,
    patterns=None,
    date_format: Optional[str] = None,
    decimal_year: bool = False,
):
    """
    Extracts a date for each tip by searching the tip name with regular expressions.

    For every tip the patterns are tried in order and the first one that matches is used.
    Tips whose name matches none of the patterns (or that have no name) are left out of the result.

    Args:
        patterns (str or list of str, optional): The regex pattern(s) used to locate the date in a tip name.
            If not set, the patterns returned by `default_date_patterns` are used.
        date_format (str, optional): A `datetime.strptime` format for the matched text.
            If set, it is tried first; when the matched text does not conform to it, the text is
            interpreted as described below. Defaults to None.
        decimal_year (bool): If True, the dates are converted with `numeric_date` before being returned.
            Defaults to False.

    Returns:
        dict: The tip names mapped to their dates.
    """
    patterns = patterns or default_date_patterns()
    if isinstance(patterns, str):
        patterns = [patterns]

    compiled_patterns = [re.compile(pattern_string) for pattern_string in patterns]
    numeric_pattern = re.compile(r"^\d+\.?\d*$")

    def to_date(matched_str):
        # The explicit format must be applied with strptime. Passing it positionally to
        # `dateutil.parser.parse` would bind it to the `parserinfo` parameter, not to a format.
        if date_format is not None:
            try:
                return datetime.strptime(matched_str, date_format)
            except ValueError:
                # The text does not follow the requested format, so interpret it without the format.
                pass
        if numeric_pattern.match(matched_str):
            return datetime_from_numeric(float(matched_str))
        return parse(matched_str)

    dates = {}
    for tip in self.find_elements(terminal=True):
        if tip.name is None:
            # There is nothing to search and no key to store a date under.
            continue
        for pattern in compiled_patterns:
            m = pattern.search(tip.name)
            if m:
                dates[tip.name] = to_date(m.group(0))
                break

    if decimal_year:
        dates = {key: numeric_date(value) for key, value in dates.items()}

    return dates

79 

def assert_number_of_tips(
    self,
    tips: Optional[int] = None,
    *,
    min: Optional[int] = None,
    max: Optional[int] = None,
    warning: bool = False,
):
    """
    Checks the number of tips in the tree against an exact value and/or a range.

    The checks are run in the order: exact value, minimum, maximum. Any criterion left as None is skipped.

    Args:
        tips (int, optional): The exact number of tips required. Defaults to None.
        min (int, optional): The smallest number of tips allowed (inclusive). Defaults to None.
        max (int, optional): The largest number of tips allowed (inclusive). Defaults to None.
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
    """
    observed = len(self.get_terminals())

    if tips is not None:
        is_exact = observed == tips
        message = f"The number of tips ({observed}) which is different from the required number of tips ({tips})."
        assert_or_warn(is_exact, warning, message)

    if min is not None:
        is_enough = observed >= min
        message = f"The number of tips ({observed}) is less than the minimum ({min})."
        assert_or_warn(is_enough, warning, message)

    if max is not None:
        is_within_limit = observed <= max
        message = f"The number of tips ({observed}) is greater than the maximum ({max})."
        assert_or_warn(is_within_limit, warning, message)

117 

def assert_unique_tips(self, *, warning: bool = False):
    """
    Checks that no tip name occurs more than once in the tree.

    Args:
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
    """
    names = [terminal.name for terminal in self.get_terminals()]
    total = len(names)
    distinct = len(set(names))
    assert_or_warn(
        total == distinct,
        warning,
        f"The tree contains {total} tips, however, {distinct} are unique.",
    )

133 

def assert_is_rooted(self, *, warning: bool = False):
    """
    Checks the `rooted` attribute of the tree.

    Args:
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
    """
    is_rooted = self.rooted
    assert_or_warn(is_rooted, warning, "The tree is not rooted.")

147 

def assert_is_bifurcating(self, *, warning: bool = False):
    """
    Checks that `is_bifurcating` holds for the tree.

    The root may have 3 descendents and still be considered part of a bifurcating tree, because it has no ancestor.

    Args:
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
    """
    bifurcating = self.is_bifurcating()
    assert_or_warn(bifurcating, warning, "The tree is not bifurcating.")

163 

def assert_is_monophyletic(self, tips: List[Clade], *, warning: bool = False):
    """
    Checks that the given tips make up a monophyletic group.

    The result of `is_monophyletic(tips)` is used as the condition of the check.

    Args:
        tips (List[Clade]): The terminal nodes (tips) that should form the group.
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
    """
    outcome = self.is_monophyletic(tips)
    group_label = ', '.join([tip.name for tip in tips])
    assert_or_warn(outcome, warning, f"The group '{group_label}' is paraphyletic!")

178 

def assert_branch_lengths(
    self,
    *,
    min: Optional[float] = None,
    max: Optional[float] = None,
    terminal: Optional[bool] = None,
    warning: bool = False,
):
    """
    Asserts that all the branch lengths meet the specified criteria.

    The root clade is never checked. Clades that have no branch length (None) are skipped
    because there is no value to compare against the limits.

    Args:
        min (float, optional): If set, then each branch length must be equal to or greater than this value. Defaults to None.
        max (float, optional): If set, then each branch length must be equal to or less than this value. Defaults to None.
        terminal (bool, optional): True searches for only terminal nodes, False excludes terminal nodes, and the default, None,
            searches both terminal and non-terminal nodes, as well as any tree elements lacking the is_terminal method.
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
    """
    root = self.root
    for node in self.find_clades(terminal=terminal):
        # Exclude the root by identity. Dropping the first search result instead would discard a
        # real tip whenever only terminal nodes are searched, and fails when the search is empty.
        if node is root:
            continue

        branch_length = node.branch_length
        if branch_length is None:
            continue

        branch_label = "A terminal branch" if node.is_terminal() else "An internal branch"
        if min is not None:
            assert_or_warn(
                branch_length >= min,
                warning,
                f"{branch_label} in the tree is less than the minimum ({min}).",
            )
        if max is not None:
            assert_or_warn(
                branch_length <= max,
                warning,
                f"{branch_label} in the tree is greater than the maximum ({max}).",
            )

213 

def assert_terminal_branch_lengths(
    self,
    *,
    min: Optional[float] = None,
    max: Optional[float] = None,
    warning: bool = False,
):
    """
    Checks the lengths of the terminal branches only.

    This delegates to `assert_branch_lengths` with `terminal=True`.

    Args:
        min (float, optional): The smallest terminal branch length allowed (inclusive). Defaults to None.
        max (float, optional): The largest terminal branch length allowed (inclusive). Defaults to None.
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
    """
    limits = dict(min=min, max=max)
    self.assert_branch_lengths(terminal=True, warning=warning, **limits)

231 

def assert_internal_branch_lengths(
    self,
    *,
    min: Optional[float] = None,
    max: Optional[float] = None,
    warning: bool = False,
):
    """
    Checks the lengths of the internal branches only.

    This delegates to `assert_branch_lengths` with `terminal=False`.

    Args:
        min (float, optional): The smallest internal branch length allowed (inclusive). Defaults to None.
        max (float, optional): The largest internal branch length allowed (inclusive). Defaults to None.
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
    """
    limits = dict(min=min, max=max)
    self.assert_branch_lengths(terminal=False, warning=warning, **limits)

249 

def assert_no_negatives(
    self,
    *,
    terminal: Optional[bool] = None,
    warning: bool = False,
):
    """
    Checks that no branch has a negative length.

    This delegates to `assert_branch_lengths` with a minimum of 0.

    Args:
        terminal (bool, optional): True searches for only terminal nodes, False excludes terminal nodes, and the default, None,
            searches both terminal and non-terminal nodes, as well as any tree elements lacking the is_terminal method.
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
    """
    lowest_allowed = 0
    self.assert_branch_lengths(terminal=terminal, warning=warning, min=lowest_allowed)

266 

def assert_total_branch_length(
    self,
    length: Optional[float] = None,
    *,
    min: Optional[float] = None,
    max: Optional[float] = None,
    warning: bool = False,
):
    """
    Asserts that the total branch length meets the specified criteria.

    Args:
        length (float, optional): If set, then total branch length must be equal to this value,
            within the default relative tolerance of `math.isclose`. Defaults to None.
        min (float, optional): If set, then total branch length must be equal to or greater than this value. Defaults to None.
        max (float, optional): If set, then total branch length must be equal to or less than this value. Defaults to None.
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
    """
    import math

    total_branch_length = self.total_branch_length()
    if length is not None:
        # The total is a sum of floats, so an exact `==` can fail on rounding error alone.
        assert_or_warn(
            math.isclose(total_branch_length, length),
            warning,
            f"The total branch length ({total_branch_length}) is not equal to the required length ({length}).",
        )
    if min is not None:
        assert_or_warn(
            total_branch_length >= min,
            warning,
            f"The total branch length ({total_branch_length}) is less than the minimum ({min}).",
        )
    if max is not None:
        assert_or_warn(
            total_branch_length <= max,
            warning,
            f"The total branch length ({total_branch_length}) is greater than the maximum ({max}).",
        )

304 

def assert_tip_regex(
    self,
    patterns: Union[List[str], str],
    *,
    warning: bool = False,
):
    """
    Checks that the name of every tip is matched by at least one regular expression.

    Matching is done with `search`, so a pattern may match anywhere in the tip name.

    Args:
        patterns (Union[List[str], str]): The regex pattern(s) to match to.
            A single string is treated as a list holding that one pattern.
            With a list, a tip passes as soon as one of the patterns matches.
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
    """
    pattern_list = [patterns] if isinstance(patterns, str) else patterns
    regexes = [re.compile(pattern_string) for pattern_string in pattern_list]

    for tip in self.find_elements(terminal=True):
        has_match = any(regex.search(tip.name) for regex in regexes)
        assert_or_warn(
            has_match,
            warning,
            f"Tip {tip.name} does not match any of the regex patterns in: '{pattern_list}'.",
        )

337 

def assert_tip_names(self, names: List[str], warning=False):
    """
    Asserts that the tree tip names match the supplied names.

    The comparison ignores order but respects multiplicity: a name must occur the same number
    of times among the tips as in `names`. Names missing from either side are reported.

    Args:
        names (List[str]): The names to match.
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
    """
    from collections import Counter

    tip_names = [t.name for t in self.get_terminals()]
    assert_or_warn(
        len(tip_names) == len(names),
        warning,
        f"The tree contains {len(tip_names)} tips, however, {len(names)} names were supplied.",
    )

    tip_counts = Counter(tip_names)
    supplied_counts = Counter(names)
    # Counter subtraction keeps only the surplus on the left-hand side, so adding both directions
    # gives every name that is missing from, or over-represented in, either collection.
    mismatched = (tip_counts - supplied_counts) + (supplied_counts - tip_counts)
    diff = sorted(str(name) for name in mismatched)
    assert_or_warn(
        not diff,
        warning,
        f"There was a difference ({', '.join(diff)}) between the supplied names and tree tip names.",
    )

359 

def copy(self):
    """Returns an independent deep copy of this tree."""
    from copy import deepcopy

    return deepcopy(self)

364 

def root_to_tip_regression(
    self,
    *,
    dates: Optional[Dict] = None,
    alignment: Optional[MultipleSeqAlignment] = None,
    sequence_length: Optional[int] = None,
    clock_filter: float = 3.0,
    gtr: Union[GTR, str] = 'JC69',
    root_method: str = 'least-squares',
    allow_negative_rate: bool = False,
    keep_root: bool = False,
    covariation: bool = False,
):
    """
    Performs a root-to-tip regression to determine how clock-like a tree is.

    The regression is run on a copy of this tree, so this tree itself is passed to TreeTime unmodified.

    Args:
        dates (Dict, optional): The tip dates as a dictionary with the tip name as the key and the date as the value.
            If not set, then it parses the tip dates to generate this dictionary using the `parse_tip_dates` method.
        alignment (MultipleSeqAlignment, optional): The alignment associated with this tree. Defaults to None.
        sequence_length (int, optional): The sequence length of the alignment. Defaults to None.
        clock_filter (float, optional): The number of interquartile ranges from regression beyond which to ignore.
            This provides a way to ignore tips that don't follow a loose clock.
            A falsy value disables the filter. Defaults to 3.0.
        gtr (GTR, str, optional): The molecular evolution model. Defaults to 'JC69'.
        root_method (str, optional): The method used to reroot the tree if `keep_root` is False.
            Valid choices are: 'min_dev', 'least-squares', and 'oldest'.
            Defaults to 'least-squares'.
        allow_negative_rate (bool, optional): Whether or not a negative clock rate is allowed.
            For trees with little temporal signal, it can be set to True to achieve essentially mid-point rooting.
            Defaults to False.
        keep_root (bool, optional): Keeps the current root of the tree.
            If False, then a new optimal root is sought. Defaults to False.
        covariation (bool, optional): Accounts for covariation when estimating rates or rerooting. Defaults to False.

    Returns:
        TreeTime: The regression object after `get_clock_model` has been called on it.

    Raises:
        PhytestAssertion: If `covariation` is True and neither an alignment nor a sequence length is given.
        AssertionError: If `keep_root` is False and `root_method` is not one of the valid choices.
    """
    if covariation and alignment is None and sequence_length is None:
        raise PhytestAssertion(
            "Cannot perform root-to-tip regression with `covariation` as True if no alignment or sequence length is provided."
        )

    dates = dates or self.parse_tip_dates()

    # Convert datetimes to floats with decimal years if necessary
    dates = {name: numeric_date(date) if isinstance(date, datetime) else date for name, date in dates.items()}

    regression = TreeTime(
        dates=dates,
        tree=self.copy(),
        aln=alignment,
        gtr=gtr,
        seq_len=sequence_length,
    )

    if clock_filter:
        # Record which tips are flagged before filtering so that only newly flagged tips are reported.
        bad_nodes = [node.name for node in regression.tree.get_terminals() if node.bad_branch]
        regression.clock_filter(n_iqd=clock_filter, reroot=root_method or 'least-squares')
        bad_nodes_after = [node.name for node in regression.tree.get_terminals() if node.bad_branch]
        if len(bad_nodes_after) > len(bad_nodes):
            warn(
                "The following leaves don't follow a loose clock and "
                "will be ignored in rate estimation:\n\t" + "\n\t".join(set(bad_nodes_after).difference(bad_nodes)),
                PhytestWarning,
            )

    if not keep_root:
        if covariation:  # this requires branch length estimates
            regression.run(root="least-squares", max_iter=0, use_covariation=covariation)

        valid_root_methods = ('min_dev', 'least-squares', 'oldest')
        if root_method not in valid_root_methods:
            # Raised explicitly (rather than with an `assert` statement) so that the check is
            # still performed when Python runs with optimizations enabled.
            raise AssertionError(f"`root_method` must be one of {valid_root_methods}, but got {root_method!r}.")
        regression.reroot(root_method, force_positive=not allow_negative_rate)

    regression.get_clock_model(covariation=covariation)
    return regression

439 

def plot_root_to_tip(
    self,
    filename: Union[str, Path],
    *,
    format: Optional[str] = None,
    regression: Optional[TreeTime] = None,
    add_internal: bool = False,
    label: bool = True,
    ax=None,
    **kwargs,
):
    """
    Plots a root-to-tip regression and saves the plot.

    Args:
        filename (str, Path): The path to save the plot as an image.
            It is handed to `pyplot.savefig`; a `Path` is converted to a string first.
            Other objects (e.g. a file-like buffer) are passed through unchanged.
        format (str, optional): The image format passed to `pyplot.savefig`. Defaults to None.
        regression (TreeTime, optional): The root-to-tip regression for this tree.
            If None, then this regression is calculated using the `root_to_tip_regression` method.
        add_internal (bool): Whether or not to plot the internal node positions. Default: False.
        label (bool): Whether or not to label the points. Default: True.
        ax (matplotlib axes): Uses matplotlib axes if provided. Default: None.
        **kwargs: Keyword arguments for the `root_to_tip_regression` method.
    """
    regression = regression or self.root_to_tip_regression(**kwargs)
    from matplotlib import pyplot as plt

    regression.plot_root_to_tip(add_internal=add_internal, label=label, ax=ax)
    if isinstance(filename, Path):
        filename = str(filename)

    try:
        plt.savefig(filename, format=format)
    finally:
        if ax is None:
            # No axes were supplied, so the current figure was created for this call only.
            # Close it so that repeated calls do not accumulate open figures.
            plt.close()

471 

def assert_root_to_tip(
    self,
    *,
    regression: Optional[TreeTime] = None,
    min_r_squared: Optional[float] = None,
    min_rate: Optional[float] = None,
    max_rate: Optional[float] = None,
    min_root_date: Optional[float] = None,
    max_root_date: Optional[float] = None,
    valid_confidence: Optional[bool] = None,
    extra: Optional[List] = None,
    warning: bool = False,
    **kwargs,
):
    """
    Checks inferred values from a root-to-tip regression.

    Args:
        regression (TreeTime, optional): The root-to-tip regression for this tree.
            If None, then this regression is calculated using the `root_to_tip_regression` method.
        min_r_squared (float, optional): If set, then R^2 must be equal or greater than this value. Defaults to None.
        min_rate (float, optional): If set, then the clock rate must be equal or greater than this value. Defaults to None.
        max_rate (float, optional): If set, then the clock rate must be equal or less than this value. Defaults to None.
        min_root_date (float, optional): If set, then the interpolated root date must be equal or greater than this value. Defaults to None.
        max_root_date (float, optional): If set, then the interpolated root date must be equal or less than this value. Defaults to None.
        valid_confidence (bool, optional): Checks that the `valid_confidence` value in the regression is equal to this boolean value.
            Defaults to None which does not perform a check.
        extra (List): The pytest-html extra fixture for adding in root-to-tip regression plot.
            If given, an SVG rendering of the regression is appended to it before any check is run.
        warning (bool): If True, raise a warning instead of an exception. Defaults to False.
            This flag can be set by running this method with the prefix `warn_` instead of `assert_`.
        **kwargs: Keyword arguments for the `root_to_tip_regression` method.
    """
    regression = regression or self.root_to_tip_regression(**kwargs)
    clock_model = DateConversion.from_regression(regression.clock_model)
    # The root date is the date that the clock model gives for a distance of zero from the root.
    root_date = clock_model.numdate_from_dist2root(0.0)

    if extra is not None:
        f = StringIO()
        self.plot_root_to_tip(filename=f, format="svg", regression=regression)
        svg = f.getvalue()
        extra.append(extras.html(svg))

    if min_r_squared is not None:
        r_squared = clock_model.r_val**2
        # Both halves of the message are f-strings so that the threshold is interpolated too.
        assert_or_warn(
            r_squared >= min_r_squared,
            warning,
            f"The R-squared value from the root-to-tip regression '{r_squared}' "
            f"is less than the minimum allowed R-squared '{min_r_squared}'.",
        )

    if min_rate is not None:
        assert_or_warn(
            clock_model.clock_rate >= min_rate,
            warning,
            f"Inferred clock rate '{clock_model.clock_rate}' is less than the minimum allowed clock rate '{min_rate}'.",
        )

    if max_rate is not None:
        assert_or_warn(
            clock_model.clock_rate <= max_rate,
            warning,
            f"Inferred clock rate '{clock_model.clock_rate}' is greater than the maximum allowed clock rate '{max_rate}'.",
        )

    if min_root_date is not None:
        assert_or_warn(
            root_date >= min_root_date,
            warning,
            f"Inferred root date '{root_date}' is less than the minimum allowed root date '{min_root_date}'.",
        )

    if max_root_date is not None:
        assert_or_warn(
            root_date <= max_root_date,
            warning,
            f"Inferred root date '{root_date}' is greater than the maximum allowed root date: '{max_root_date}'.",
        )

    if valid_confidence is not None:
        assert_or_warn(
            clock_model.valid_confidence == valid_confidence,
            warning,
            f"The `clock_model.valid_confidence` variable is not {valid_confidence}.",
        )

553 warning, 

554 f"The `clock_model.valid_confidence` variable is not {valid_confidence}.", 

555 )