Coverage for lmcat\lmcat.py: 71%
192 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-02-05 02:11 -0700
« prev ^ index » next coverage.py v7.6.10, created at 2025-02-05 02:11 -0700
1import argparse
2import io
3import json
5# from dataclasses import dataclass, field
6from pathlib import Path
7import sys
9from lmcat.processing_pipeline import ProcessingPipeline
12# Handle Python 3.11+ vs older Python for TOML parsing
13try:
14 import tomllib
15except ImportError:
16 try:
17 import tomli as tomllib # type: ignore
18 except ImportError:
19 tomllib = None # type: ignore[assignment]
21import igittigitt # noqa: E402
23from muutils.json_serialize import (
24 SerializableDataclass,
25 serializable_dataclass,
26 serializable_field,
27)
28from muutils.misc import shorten_numerical_to_str # noqa: E402
31from lmcat.file_stats import FileStats, TokenizerWrapper, TreeEntry, TOKENIZERS_PRESENT
32from lmcat.processing_pipeline import OnMultipleProcessors
35@serializable_dataclass(kw_only=True)
36class LMCatConfig(SerializableDataclass):
37 """Configuration dataclass for lmcat"""
39 content_divider: str = serializable_field(default="``````")
40 tree_only: bool = serializable_field(default=False)
42 # ignoring
43 ignore_patterns: list[str] = serializable_field(default_factory=list)
44 ignore_patterns_files: list[Path] = serializable_field(
45 default_factory=lambda: [Path(".gitignore"), Path(".lmignore")],
46 serialization_fn=lambda x: [p.as_posix() for p in x],
47 deserialize_fn=lambda x: [Path(p) for p in x],
48 )
50 # this file will be imported, and if the functions in it are decorated
51 # with one of the `register_*` decorators, they will be added to the functions
52 # which can be used in the processing pipeline
53 # --allow-plugins is a command line only option and must be set to true for this to work
54 plugins_file: Path | None = serializable_field(
55 default=None,
56 serialization_fn=lambda x: x.as_posix() if x else None,
57 deserialize_fn=lambda x: Path(x) if x else None,
58 )
59 allow_plugins: bool = serializable_field(
60 default=False,
61 deserialize_fn=lambda x: False, # this can only be overriden through the command line
62 )
64 # processing pipeline
65 glob_process: dict[str, str] = serializable_field(default_factory=dict)
66 decider_process: dict[str, str] = serializable_field(default_factory=dict)
67 on_multiple_processors: OnMultipleProcessors = serializable_field(
68 default="except",
69 assert_type=False,
70 )
72 # tokenization
73 tokenizer: str = serializable_field(
74 default="gpt2" if TOKENIZERS_PRESENT else "whitespace-split"
75 )
76 "Tokenizer to use for tokenizing the output. `gpt2` by default. passed to `tokenizers.Tokenizer.from_pretrained()`. If specified and `tokenizers` not installed, will throw exception. fallback `whitespace-split` used to avoid exception when `tokenizers` not installed."
78 # tree formatting
79 tree_divider: str = serializable_field(default="│ ")
80 tree_file_divider: str = serializable_field(default="├── ")
81 tree_indent: str = serializable_field(default=" ")
83 # output location
84 output: str | None = serializable_field(default=None)
86 def get_tokenizer_obj(self) -> TokenizerWrapper:
87 """Get the tokenizer object"""
88 return TokenizerWrapper(self.tokenizer)
90 def get_processing_pipeline(self) -> ProcessingPipeline:
91 """Get the processing pipeline object"""
92 plugins_file: Path | None = self.plugins_file if self.allow_plugins else None
93 return ProcessingPipeline(
94 plugins_file=plugins_file,
95 decider_process_keys=self.decider_process,
96 glob_process_keys=self.glob_process,
97 on_multiple_processors=self.on_multiple_processors,
98 )
100 @classmethod
101 def read(cls, root_dir: Path) -> "LMCatConfig":
102 """Attempt to read config from pyproject.toml, lmcat.toml, or lmcat.json."""
103 pyproject_path: Path = root_dir / "pyproject.toml"
104 lmcat_toml_path: Path = root_dir / "lmcat.toml"
105 lmcat_json_path: Path = root_dir / "lmcat.json"
107 if (
108 sum(
109 int(p.is_file())
110 for p in (pyproject_path, lmcat_toml_path, lmcat_json_path)
111 )
112 > 1
113 ):
114 raise ValueError(
115 "Multiple configuration files found. Please only use one of pyproject.toml, lmcat.toml, or lmcat.json."
116 )
118 # Try pyproject.toml first
119 if tomllib is not None and pyproject_path.is_file():
120 with pyproject_path.open("rb") as f:
121 pyproject_data = tomllib.load(f)
122 if "tool" in pyproject_data and "lmcat" in pyproject_data["tool"]:
123 return cls.load(pyproject_data["tool"]["lmcat"])
125 # Then try lmcat.toml
126 if tomllib is not None and lmcat_toml_path.is_file():
127 with lmcat_toml_path.open("rb") as f:
128 toml_data = tomllib.load(f)
129 return cls.load(toml_data)
131 # Finally try lmcat.json
132 if lmcat_json_path.is_file():
133 with lmcat_json_path.open("r", encoding="utf-8") as f:
134 json_data = json.load(f)
135 return cls.load(json_data)
137 # Fallback to defaults
138 return cls()
141class IgnoreHandler:
142 """Handles all ignore pattern matching using igittigitt"""
144 def __init__(self, root_dir: Path, config: LMCatConfig):
145 self.root_dir: Path = root_dir
146 self.config: LMCatConfig = config
148 # set up parser
149 self.parser: igittigitt.IgnoreParser = igittigitt.IgnoreParser()
151 # first from the files
152 for ignore_file in self.config.ignore_patterns_files:
153 self.parser.parse_rule_files(self.root_dir, filename=ignore_file.name)
155 # then from the config itself
156 for pattern in self.config.ignore_patterns:
157 self.parser.add_rule(pattern=pattern, base_path=self.root_dir)
159 def is_ignored(self, path: Path) -> bool:
160 """Check if a path should be ignored"""
161 # Never ignore the gitignore/lmignore files themselves
162 if path.name in {".gitignore", ".lmignore"}:
163 return True
165 # Use igittigitt's matching
166 return self.parser.match(path)
169def sorted_entries(directory: Path) -> list[Path]:
170 """Return directory contents sorted: directories first, then files"""
171 subdirs: list[Path] = sorted(
172 [p for p in directory.iterdir() if p.is_dir()], key=lambda x: x.name
173 )
174 files: list[Path] = sorted(
175 [p for p in directory.iterdir() if p.is_file()], key=lambda x: x.name
176 )
177 return subdirs + files
180def walk_dir(
181 directory: Path,
182 ignore_handler: IgnoreHandler,
183 config: LMCatConfig,
184 tokenizer: TokenizerWrapper,
185 prefix: str = "",
186) -> tuple[list[TreeEntry], list[Path]]:
187 """Recursively walk a directory, building tree lines and collecting file paths"""
188 tree_output: list[TreeEntry] = []
189 collected_files: list[Path] = []
191 entries: list[Path] = sorted_entries(directory)
192 for i, entry in enumerate(entries):
193 if ignore_handler.is_ignored(entry):
194 continue
196 is_last: bool = i == len(entries) - 1
197 connector: str = (
198 config.tree_file_divider
199 if not is_last
200 else config.tree_file_divider.replace("├", "└")
201 )
203 if entry.is_dir():
204 tree_output.append(TreeEntry(f"{prefix}{connector}{entry.name}", None))
205 extension: str = config.tree_divider if not is_last else config.tree_indent
206 sub_output: list[TreeEntry]
207 sub_files: list[Path]
208 sub_output, sub_files = walk_dir(
209 directory=entry,
210 ignore_handler=ignore_handler,
211 config=config,
212 tokenizer=tokenizer,
213 prefix=prefix + extension,
214 )
215 tree_output.extend(sub_output)
216 collected_files.extend(sub_files)
217 else:
218 stats: FileStats = FileStats.from_file(entry, tokenizer)
219 tree_output.append(TreeEntry(f"{prefix}{connector}{entry.name}", stats))
220 collected_files.append(entry)
222 return tree_output, collected_files
225def format_tree_with_stats(
226 entries: list[TreeEntry], show_tokens: bool = False
227) -> list[str]:
228 """Format tree entries with aligned statistics
230 # Parameters:
231 - `entries : list[TreeEntry]`
232 List of tree entries with optional stats
233 - `show_tokens : bool`
234 Whether to show token counts
236 # Returns:
237 - `list[str]`
238 Formatted tree lines with aligned stats
239 """
240 # Find max widths for alignment
241 max_line_len: int = max(len(entry.line) for entry in entries)
242 max_lines: int = max(
243 (len(f"{entry.stats.lines:,}") if entry.stats else 0) for entry in entries
244 )
245 max_chars: int = max(
246 (len(f"{entry.stats.chars:,}") if entry.stats else 0) for entry in entries
247 )
248 max_tokens: int = (
249 max(
250 (
251 len(f"{entry.stats.tokens:,}")
252 if entry.stats and entry.stats.tokens
253 else 0
254 )
255 for entry in entries
256 )
257 if show_tokens
258 else 0
259 )
261 formatted: list[str] = []
262 for entry in entries:
263 line: str = entry.line.ljust(max_line_len + 2)
264 if entry.stats:
265 lines_str: str = f"{entry.stats.lines:,}L".rjust(max_lines + 1)
266 chars_str: str = f"{entry.stats.chars:,}C".rjust(max_chars + 1)
267 stats_str: str = f"[{lines_str} {chars_str}"
268 if show_tokens and entry.stats.tokens is not None:
269 tokens_str: str = f"{entry.stats.tokens:,}T".rjust(max_tokens + 1)
270 stats_str += f" {tokens_str}"
271 stats_str += "]"
272 formatted.append(f"{line}{stats_str}")
273 else:
274 formatted.append(line)
276 return formatted
279def walk_and_collect(
280 root_dir: Path,
281 config: LMCatConfig,
282) -> tuple[list[str], list[Path]]:
283 """Walk filesystem from root_dir and gather tree listing plus file paths"""
284 if config is None:
285 config = LMCatConfig()
287 tokenizer: TokenizerWrapper = config.get_tokenizer_obj()
289 ignore_handler = IgnoreHandler(root_dir, config)
290 base_name = root_dir.resolve().name
292 # Start with root directory name
293 tree_output = [TreeEntry(base_name)]
295 # Walk the directory tree
296 sub_output, sub_files = walk_dir(
297 directory=root_dir,
298 ignore_handler=ignore_handler,
299 config=config,
300 tokenizer=tokenizer,
301 prefix="",
302 )
303 tree_output.extend(sub_output)
305 # Format tree with stats
306 formatted_tree = format_tree_with_stats(
307 tree_output, show_tokens=tokenizer is not None
308 )
310 return formatted_tree, sub_files
313def assemble_summary(
314 root_dir: Path,
315 config: LMCatConfig,
316) -> str:
317 """Assemble the summary output and return"""
319 processing_pipeline: ProcessingPipeline = config.get_processing_pipeline()
321 tree_output: list[str]
322 collected_files: list[Path]
323 tree_output, collected_files = walk_and_collect(
324 root_dir=root_dir,
325 config=config,
326 )
328 output: list[str] = []
329 output.append("# File Tree")
330 output.append("\n```")
331 output.extend(tree_output)
332 output.append("```\n")
334 # Add file contents if not suppressed
335 if not config.tree_only:
336 output.append("# File Contents")
338 for fpath in collected_files:
339 # get the path
340 relpath_posix: str = fpath.relative_to(root_dir).as_posix()
342 # process the contents
343 f_contents: str
344 p_name: str | None
345 f_contents, p_name = processing_pipeline.process_file(fpath)
346 processed_with: str = f'processed_with="{p_name}"' if p_name else ""
348 # start of file marker
349 pathspec_start: str = f'{ path="{relpath_posix}" {processed_with} } '
350 pathspec_end: str = f'{ end_of_file="{relpath_posix}" } '
351 output.append("")
352 output.append(config.content_divider + pathspec_start)
354 # process the actual contents of the file with the pipeline, and append
355 output.append(f_contents)
357 # add the end of file marker
358 output.append(config.content_divider + pathspec_end)
360 output_joined: str = "\n".join(output)
362 stats_dict_ints: dict[str, int] = {
363 "files": len(collected_files),
364 "lines": len(output_joined.splitlines()),
365 "chars": len(output_joined),
366 }
368 tokenizer: TokenizerWrapper = config.get_tokenizer_obj()
370 n_tokens: int = tokenizer.n_tokens(output_joined)
371 stats_dict_ints[f"`{tokenizer.name}` tokens"] = n_tokens
373 stats_header: list[str] = ["# Stats"]
374 for key, val in stats_dict_ints.items():
375 val_str: str = str(val)
376 val_short: str = shorten_numerical_to_str(val)
377 if val_str != val_short:
378 stats_header.append(f"- {val} ({val_short}) {key}")
379 else:
380 stats_header.append(f"- {val} {key}")
382 output_complete: str = "\n".join(stats_header) + "\n\n" + output_joined
384 return output_complete
387def main() -> None:
388 """Main entry point for the script"""
389 arg_parser = argparse.ArgumentParser(
390 description="lmcat - list tree and content, combining .gitignore + .lmignore",
391 add_help=False,
392 )
393 arg_parser.add_argument(
394 "-t",
395 "--tree-only",
396 action="store_true",
397 default=False,
398 help="Only print the tree, not the file contents.",
399 )
400 arg_parser.add_argument(
401 "-o",
402 "--output",
403 action="store",
404 default=None,
405 help="Output file to write the tree and contents to. set to 'STDOUT' to print to console if you want to override the config.",
406 )
407 arg_parser.add_argument(
408 "-h", "--help", action="help", help="Show this help message and exit."
409 )
410 arg_parser.add_argument(
411 "--print-cfg",
412 action="store_true",
413 default=False,
414 help="Print the configuration as json and exit.",
415 )
416 arg_parser.add_argument(
417 "--allow-plugins",
418 action="store_true",
419 default=False,
420 help="Allow plugins to be loaded from the plugins file. WARNING: this will execute arbitrary code found in the file pointed to by `config.plugins_file`, and **is a security risk**.",
421 )
423 args: argparse.Namespace = arg_parser.parse_known_args()[0]
424 root_dir: Path = Path(".").resolve()
425 config: LMCatConfig = LMCatConfig.read(root_dir)
427 # CLI overrides
428 if args.output == "STDOUT":
429 config.output = None
430 elif args.output is not None:
431 config.output = args.output
432 else:
433 assert args.output is None
435 config.tree_only = args.tree_only
436 config.allow_plugins = args.allow_plugins
438 # print cfg and exit if requested
439 if args.print_cfg:
440 print(json.dumps(config.serialize(), indent="\t"))
441 return
443 # assemble summary
444 summary: str = assemble_summary(root_dir=root_dir, config=config)
446 # Write output
447 if config.output:
448 output_path: Path = Path(config.output)
449 output_path.parent.mkdir(parents=True, exist_ok=True)
450 output_path.write_text(summary, encoding="utf-8")
451 else:
452 if sys.platform == "win32":
453 sys.stdout = io.TextIOWrapper(
454 sys.stdout.buffer, encoding="utf-8", errors="replace"
455 )
456 sys.stderr = io.TextIOWrapper(
457 sys.stderr.buffer, encoding="utf-8", errors="replace"
458 )
460 print(summary)
463if __name__ == "__main__":
464 main()