#cat <<'EOF' > mini_run.py
#!/usr/bin/env python3
import os
import re
import json
import zipfile
import hashlib
import tempfile
import subprocess
import argparse
import time
from tqdm import tqdm
from collections import Counter
from llama_cpp import Llama
import difflib
import psutil
import chardet
from func_timeout import func_timeout, FunctionTimedOut

# ===== Configuration =====
MODEL_PATH = "/home/davetmire85/gguf_models/mistral-7b-instruct-v0.2.Q4_K_M.gguf"
INPUT_PATH = "/home/davetmire85/sigil_inputs/rust_examples_utf8.jsonl"
OUTPUT_DIR = "enriched_outputs"
ZIP_PATH = "rich_results.zip"
BUCKET_URI = "gs://sigil-transfer-bucket/rich_results.zip"
# Read the token from the environment only; never hard-code credentials in shared scripts.
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN", "")

# Reduced for 26GB RAM environment
BATCH_SIZE = 2
MAX_TOKENS = 512
CTX_WINDOW = 4096
MAX_CODE_SIZE = 10000

BUG_TYPES = {
    "OperatorSwap": r"(\+|\-|\*|\/|==|!=)",
    "UnwrapToQuestionMark": r"\.unwrap\(\)",
    "BoundaryError": r"for\s.*\bin\b\s(0\.\.|\.\.=)?\d+",
    "LifetimeMissing": r"&\s*\w+\s*(?![:+])",
    "TypeMismatch": r"as\s+[A-Za-z0-9_]+",
    "IndexError": r"\[[A-Za-z0-9_]+\]",
}
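
# Note: the patterns above are deliberately coarse heuristics. classify_bug() further down
# labels an entry with the first pattern that matches the buggy code but no longer matches
# the fixed code, and falls back to "ComplexChange" when no single pattern explains the fix.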

# Resource monitoring
def memory_safe():
    mem = psutil.virtual_memory()
    return mem.available > (512 * 1024 * 1024)  # 512MB buffer

def disk_safe():
    usage = psutil.disk_usage('/tmp')
    return usage.free > (100 * 1024 * 1024)  # 100MB min free

def compute_uid(code: str) -> str:
    if len(code) > MAX_CODE_SIZE:
        raise ValueError(f"Code exceeds maximum size ({MAX_CODE_SIZE} chars)")
    return hashlib.sha256(code.encode("utf-8")).hexdigest()[:32]

def analyze_edits(original: str, fixed: str) -> dict:
    d = difflib.SequenceMatcher(None, original, fixed)
    metrics = {
        "similarity_score": d.ratio(),
        "equal_cnt": 0,
        "replace_cnt": 0,
        "delete_cnt": 0,
        "insert_cnt": 0,
        "fix_ops_cnt": 0,
        "changed_lines": []
    }
    for op, i1, i2, j1, j2 in d.get_opcodes():
        if op == "equal":
            metrics["equal_cnt"] += (i2 - i1)
        elif op == "replace":
            metrics["replace_cnt"] += (i2 - i1)
            metrics["fix_ops_cnt"] += 1
            metrics["changed_lines"].extend(range(i1, i2))
        elif op == "delete":
            metrics["delete_cnt"] += (i2 - i1)
            metrics["fix_ops_cnt"] += 1
            metrics["changed_lines"].extend(range(i1, i2))
        elif op == "insert":
            metrics["insert_cnt"] += (j2 - j1)
            metrics["fix_ops_cnt"] += 1
    # De-duplicate and sort once, after all opcodes have been processed.
    metrics["changed_lines"] = sorted(set(metrics["changed_lines"]))
    return metrics

def classify_bug(buggy: str, fixed: str) -> str:
    for bug_type, pattern in BUG_TYPES.items():
        if re.search(pattern, buggy) and not re.search(pattern, fixed):
            return bug_type
    return "ComplexChange"

def get_execution_status(code: str) -> dict:
    # Check resources before starting build
    if not memory_safe() or not disk_safe():
        return {"status": "RESOURCE", "error": "Insufficient system resources"}
    with tempfile.TemporaryDirectory(prefix="sigil_") as tmpdir:
        crate_name = f"sigil_{compute_uid(code)[:8]}"
        src_dir = os.path.join(tmpdir, "src")
        os.makedirs(src_dir, exist_ok=True)
        with open(os.path.join(tmpdir, "Cargo.toml"), "w") as f:
            f.write(f"""
[package]
name = "{crate_name}"
version = "0.1.0"
edition = "2021"

[profile.release]
opt-level = 'z'  # Optimize for size
lto = false
""")
        with open(os.path.join(src_dir, "main.rs"), "w") as f:
            f.write(code)
        # Constrained build environment
        env = os.environ.copy()
        env["CARGO_BUILD_JOBS"] = "1"  # Limit parallelism
        try:
            build = subprocess.run(
                ["cargo", "build", "--release"],
                cwd=tmpdir,
                capture_output=True,
                timeout=30,
                text=True,
                env=env
            )
            if build.returncode != 0:
                return {"status": "CE", "error": build.stderr[:500]}
            # Time the compiled binary so exec_time reports wall-clock seconds.
            start = time.monotonic()
            run = subprocess.run(
                [os.path.join(tmpdir, "target/release", crate_name)],
                capture_output=True,
                timeout=5,
                text=True
            )
            return {
                "status": "AC" if run.returncode == 0 else "RE",
                "returncode": run.returncode,
                "output": run.stdout[:500],
                "exec_time": time.monotonic() - start
            }
        except subprocess.TimeoutExpired:
            return {"status": "TLE", "error": "timeout"}
        except Exception as e:
            return {"status": "ERROR", "error": str(e)[:200]}

def load_model():
    return Llama(
        model_path=MODEL_PATH,
        n_ctx=CTX_WINDOW,
        n_threads=2,  # Fixed for 4 vCPU environment
        n_gpu_layers=35,
        n_batch=min(256, BATCH_SIZE * 2),
        f16_kv=True,
        use_mlock=True,
        verbose=False
    )
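
# Note: n_gpu_layers=35 only offloads layers when llama-cpp-python was built with GPU
# support; on a CPU-only build the setting appears to be ignored and inference falls back
# to the two CPU threads configured above.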

def detect_and_decode(filepath):
    with open(filepath, "rb") as f:
        raw = f.read()
    detection = chardet.detect(raw)
    encoding = detection["encoding"] or "utf-8"
    try:
        text = raw.decode(encoding)
    except UnicodeDecodeError:
        text = raw.decode("utf-8", errors="replace")
    return text.replace("\r\n", "\n")
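
# Input format: load_entries() below expects JSON Lines, one object per line, with the
# buggy source under "before" or "source_code" and an optional precomputed "code_uid";
# any remaining fields are carried along untouched as metadata.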

def load_entries(path):
    entries = []
    content = detect_and_decode(path)
    for line in content.splitlines():
        try:
            data = json.loads(line)
            code = data.get("before") or data.get("source_code")
            if not code:
                continue
            entries.append({
                "bug_code_uid": data.get("code_uid", compute_uid(code)),
                "bug_source_code": code,
                "metadata": {k: v for k, v in data.items() if k not in ["before", "source_code"]}
            })
        except Exception as e:
            print(f"[!] Failed to parse: {e}")
    return entries

def process_batch(llm: Llama, batch: list) -> list:
    results = []
    for entry in batch:
        try:
            # Check resources before each entry
            if not memory_safe() or not disk_safe():
                raise ResourceWarning("Insufficient system resources")
            prompt = f"""<s>[INST] You are a senior Rust developer. Fix this code:
{entry['bug_source_code'][:CTX_WINDOW//2]}
Provide ONLY the fixed Rust code without explanations. [/INST] Fixed Code:\n"""
            try:
                response = func_timeout(
                    120,
                    llm,
                    args=(prompt,),
                    kwargs={
                        "max_tokens": MAX_TOKENS,
                        "temperature": 0.1,
                        "top_p": 0.9,
                        "stop": ["</s>", "```", "\n\n\n", "[INST]"],
                        "echo": False
                    }
                )
                fix = response["choices"][0]["text"].strip().split("Fixed Code:")[-1].strip()
            except FunctionTimedOut:
                fix = entry['bug_source_code']  # Fallback to original
                print("Model timeout, using original code")
            if not fix:
                raise ValueError("Empty fix generated")
            diff_metrics = analyze_edits(entry['bug_source_code'], fix)
            execution = {
                "bug": get_execution_status(entry['bug_source_code']),
                "fix": get_execution_status(fix)
            }
            # Report whichever edit operation contributed the most opcodes to the diff.
            op_counts = {
                "replace": diff_metrics["replace_cnt"],
                "delete": diff_metrics["delete_cnt"],
                "insert": diff_metrics["insert_cnt"]
            }
            results.append({
                **entry,
                "fix_source_code": fix,
                "fix_code_uid": compute_uid(fix),
                "apr_id": compute_uid(entry['bug_code_uid'] + compute_uid(fix)),
                **diff_metrics,
                "bug_type": classify_bug(entry['bug_source_code'], fix),
                "execution": execution,
                "potential_dominant_fix_op": max(op_counts, key=op_counts.get),
                "resource_usage": {}
            })
        except Exception as e:
            results.append({
                "error": str(e),
                "bug_code_uid": entry.get('bug_code_uid', 'unknown'),
                "partial_data": entry.get('metadata', {})
            })
    return results
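
# Failures are captured per entry: the error message and any metadata are written to the
# shard instead of aborting the batch, so one bad sample cannot stall the whole run.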

def package_results(output_dir):
    with zipfile.ZipFile(ZIP_PATH, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for fname in os.listdir(output_dir):
            zipf.write(os.path.join(output_dir, fname), fname)

def upload_results():
    try:
        subprocess.run(["gsutil", "cp", ZIP_PATH, BUCKET_URI], check=True)
    except Exception as e:
        print(f"Upload failed: {str(e)}")

def main():
    parser = argparse.ArgumentParser(description="Rust Code Fix Pipeline")
    parser.add_argument("--input", default=INPUT_PATH, help="Input JSONL path")
    parser.add_argument("--output_dir", default=OUTPUT_DIR, help="Output directory")
    parser.add_argument("--skip_existing", action="store_true", help="Skip processed shards")
    args = parser.parse_args()

    # Only export the token when one was actually supplied via the environment.
    if HUGGINGFACE_TOKEN:
        os.environ["HUGGINGFACE_HUB_TOKEN"] = HUGGINGFACE_TOKEN
    os.makedirs(args.output_dir, exist_ok=True)

    print("Loading model...")
    llm = load_model()
    print("Model loaded")

    print(f"Reading input from: {args.input}")
    entries = load_entries(args.input)
    print(f"Loaded {len(entries)} valid entries")
    if not entries:
        print("[!] No entries loaded. Exiting.")
        return

    total_batches = (len(entries) + BATCH_SIZE - 1) // BATCH_SIZE
    with tqdm(total=total_batches, desc="Processing") as pbar:
        for batch_idx in range(0, len(entries), BATCH_SIZE):
            shard_path = os.path.join(args.output_dir, f"shard_{batch_idx//BATCH_SIZE:04d}.jsonl")
            if args.skip_existing and os.path.exists(shard_path):
                pbar.update(1)
                continue
            # System resource check
            if not memory_safe():
                print("Low memory, pausing for 30s...")
                time.sleep(30)
            batch = entries[batch_idx:batch_idx + BATCH_SIZE]
            results = process_batch(llm, batch)
            with open(shard_path, "w") as f:
                for result in results:
                    f.write(json.dumps(result, ensure_ascii=False) + "\n")
            pbar.update(1)

    print("Packaging results...")
    package_results(args.output_dir)
    print("Uploading results...")
    upload_results()
    print("Pipeline completed!")

if __name__ == "__main__":
    main()
#EOF
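
# Example run (uses the defaults defined in mini_run.py; adjust paths for your setup):
#   python3 mini_run.py \
#       --input /home/davetmire85/sigil_inputs/rust_examples_utf8.jsonl \
#       --output_dir enriched_outputs \
#       --skip_existing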