File size: 6,054 Bytes
84cfaba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# infra/message_store.py
from typing import Any, Dict, List, Tuple, Optional
import json
from kallam.infra.db import sqlite_conn
from datetime import datetime
import uuid

class MessageStore:
    def __init__(self, db_path: str): 
        self.db_path = db_path.replace("sqlite:///", "")

    def get_original_history(self, session_id: str, limit: int = 10) -> List[Dict[str, str]]:
        with sqlite_conn(self.db_path) as c:
            rows = c.execute("""

                select role, content as content

                from messages where session_id=? and role in ('user','assistant')

                order by id desc limit ?""", (session_id, limit)).fetchall()
        return [{"role": r["role"], "content": r["content"]} for r in reversed(rows)]

    def get_translated_history(self, session_id: str, limit: int = 10) -> List[Dict[str, str]]:
        with sqlite_conn(self.db_path) as c:
            rows = c.execute("""

                select role, coalesce(translated_content, content) as content

                from messages where session_id=? and role in ('user','assistant')

                order by id desc limit ?""", (session_id, limit)).fetchall()
        return [{"role": r["role"], "content": r["content"]} for r in reversed(rows)]

    def get_reasoning_traces(self, session_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        with sqlite_conn(self.db_path) as c:
            rows = c.execute("""

                select message_id, chain_of_thoughts from messages

                where session_id=? and chain_of_thoughts is not null

                order by id desc limit ?""", (session_id, limit)).fetchall()
        out = []
        for r in rows:
            try:
                out.append({"message_id": r["message_id"], "contents": json.loads(r["chain_of_thoughts"])})
            except json.JSONDecodeError:
                continue
        return out

    def append_user(self, session_id: str, content: str, translated: str | None,

                    flags: Dict[str, Any] | None, tokens_in: int) -> None:
        self._append(session_id, "user", content, translated, None, None, flags, tokens_in, 0)

    def append_assistant(self, session_id: str, content: str, translated: str | None,

                         reasoning: Dict[str, Any] | None, tokens_out: int) -> None:
        self._append(session_id, "assistant", content, translated, reasoning, None, None, 0, tokens_out)

    def _append(self, session_id, role, content, translated, reasoning, latency_ms, flags, tok_in, tok_out):
        message_id = f"MSG-{uuid.uuid4().hex[:8].upper()}"
        now = datetime.now().isoformat()
        with sqlite_conn(self.db_path) as c:
            c.execute("""insert into messages (session_id,message_id,timestamp,role,content,

                         translated_content,chain_of_thoughts,tokens_input,tokens_output,latency_ms,flags)

                         values (?,?,?,?,?,?,?,?,?,?,?)""",
                      (session_id, message_id, now, role, content,
                       translated, json.dumps(reasoning, ensure_ascii=False) if reasoning else None,
                       tok_in, tok_out, latency_ms, json.dumps(flags, ensure_ascii=False) if flags else None))
            if role == "user":
                c.execute("""update sessions set total_messages=total_messages+1,

                             total_user_messages=coalesce(total_user_messages,0)+1,

                             last_activity=? where session_id=?""", (now, session_id))
            elif role == "assistant":
                c.execute("""update sessions set total_messages=total_messages+1,

                             total_assistant_messages=coalesce(total_assistant_messages,0)+1,

                             last_activity=? where session_id=?""", (now, session_id))

    def aggregate_stats(self, session_id: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """

        Returns:

          stats: {

            "message_count": int,

            "total_tokens_in": int,

            "total_tokens_out": int,

            "avg_latency": float | None,

            "first_message": str | None,  # ISO timestamp

            "last_message": str | None,   # ISO timestamp

          }

          session: dict  # full row from sessions table (as a mapping)

        """
        with sqlite_conn(self.db_path) as c:
            # Roll up message stats
            row = c.execute(
                """

                SELECT

                  COUNT(*) AS message_count,

                  COALESCE(SUM(tokens_input), 0)  AS total_tokens_in,

                  COALESCE(SUM(tokens_output), 0) AS total_tokens_out,

                  AVG(CASE WHEN role='assistant' THEN latency_ms END) AS avg_latency,

                  MIN(timestamp)                  AS first_message,

                  MAX(timestamp)                  AS last_message

                FROM messages

                WHERE session_id = ?

                  AND role IN ('user','assistant')

                """,
                (session_id,),
            ).fetchone()

            stats = {
                "message_count": row["message_count"] or 0,
                "total_tokens_in": row["total_tokens_in"] or 0,
                "total_tokens_out": row["total_tokens_out"] or 0,
                # Normalize avg_latency to float if not None
                "avg_latency": float(row["avg_latency"]) if row["avg_latency"] is not None else None,
                "first_message": row["first_message"],
                "last_message": row["last_message"],
            }

            # Fetch session info (entire row)
            session = c.execute(
                "SELECT * FROM sessions WHERE session_id = ?",
                (session_id,),
            ).fetchone() or {}

        # Convert sqlite Row to plain dict if needed
        if hasattr(session, "keys"):
            session = {k: session[k] for k in session.keys()}

        return stats, session