mirror of
https://github.com/patriceckhart/zot.git
synced 2026-06-26 21:36:31 +02:00
Merge pull request #30 from jameswei/fix/session-fork-after-compaction
fix(core): fork sessions from effective compacted transcript
This commit is contained in:
commit
7f954ceaa3
2 changed files with 129 additions and 18 deletions
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/patriceckhart/zot/packages/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PortableExt is the filesystem extension used for exported sessions.
|
// PortableExt is the filesystem extension used for exported sessions.
|
||||||
|
|
@ -321,43 +322,76 @@ func BranchSession(parentPath, root, cwd, version string, upToMessageIdx int) (s
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy message rows up to the cut point, plus all usage rows
|
// Reconstruct the effective transcript the same way OpenSession
|
||||||
// that land before the cut (they describe the cost of those
|
// does: message rows append, and compaction rows replace everything
|
||||||
// messages). Rewind and use the large-row-safe JSONL reader.
|
// before them. The fork index is defined over that effective stream,
|
||||||
|
// not over the raw audit rows kept on disk before a compaction.
|
||||||
if _, err := src.Seek(0, io.SeekStart); err != nil {
|
if _, err := src.Seek(0, io.SeekStart); err != nil {
|
||||||
return "", fmt.Errorf("branch: rewind parent: %w", err)
|
return "", fmt.Errorf("branch: rewind parent: %w", err)
|
||||||
}
|
}
|
||||||
msgCount := 0
|
var effective []provider.Message
|
||||||
|
var nonCompactedRows [][]byte
|
||||||
|
effectiveCount := 0
|
||||||
|
sawCompaction := false
|
||||||
if err := forEachJSONLLine(src, func(line []byte) error {
|
if err := forEachJSONLLine(src, func(line []byte) error {
|
||||||
if msgCount >= upToMessageIdx {
|
|
||||||
return io.EOF
|
|
||||||
}
|
|
||||||
var h sessionLineHead
|
var h sessionLineHead
|
||||||
if err := json.Unmarshal(line, &h); err != nil {
|
if err := json.Unmarshal(line, &h); err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
switch h.Type {
|
switch h.Type {
|
||||||
case "message":
|
case "message":
|
||||||
if _, err := bw.Write(line); err != nil {
|
if msg, err := hydrateMessage(line); err == nil && len(msg.Content) > 0 {
|
||||||
return err
|
effective = append(effective, msg)
|
||||||
|
if !sawCompaction && effectiveCount < upToMessageIdx {
|
||||||
|
raw := append([]byte(nil), line...)
|
||||||
|
nonCompactedRows = append(nonCompactedRows, raw)
|
||||||
|
}
|
||||||
|
effectiveCount++
|
||||||
}
|
}
|
||||||
if err := bw.WriteByte('\n'); err != nil {
|
case "compaction":
|
||||||
return err
|
if compacted, err := hydrateCompaction(line); err == nil {
|
||||||
|
effective = compacted
|
||||||
|
effectiveCount = len(effective)
|
||||||
|
sawCompaction = true
|
||||||
}
|
}
|
||||||
msgCount++
|
|
||||||
case "usage":
|
case "usage":
|
||||||
if _, err := bw.Write(line); err != nil {
|
if !sawCompaction && effectiveCount < upToMessageIdx {
|
||||||
return err
|
raw := append([]byte(nil), line...)
|
||||||
|
nonCompactedRows = append(nonCompactedRows, raw)
|
||||||
}
|
}
|
||||||
if err := bw.WriteByte('\n'); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// don't increment msgCount for usage rows
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}); err != nil && err != io.EOF {
|
}); err != nil && err != io.EOF {
|
||||||
return "", fmt.Errorf("branch: read parent: %w", err)
|
return "", fmt.Errorf("branch: read parent: %w", err)
|
||||||
}
|
}
|
||||||
|
if sawCompaction {
|
||||||
|
limit := upToMessageIdx
|
||||||
|
if limit > len(effective) {
|
||||||
|
limit = len(effective)
|
||||||
|
}
|
||||||
|
for i := 0; i < limit; i++ {
|
||||||
|
msg := effective[i]
|
||||||
|
line, err := json.Marshal(sessionLine{Type: "message", Message: &msg})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("branch: marshal message: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := bw.Write(line); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := bw.WriteByte('\n'); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, row := range nonCompactedRows {
|
||||||
|
if _, err := bw.Write(row); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := bw.WriteByte('\n'); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if err := bw.Flush(); err != nil {
|
if err := bw.Flush(); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -231,6 +231,71 @@ func TestBranchSessionCopiesPrefix(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBranchSessionUsesEffectiveTranscriptAfterCompaction(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
cwd := "/project"
|
||||||
|
parent, err := NewSession(root, cwd, "anthropic", "claude-opus-4-7", "0.0.0-test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, text := range []string{"raw-a", "raw-b", "raw-c", "raw-d"} {
|
||||||
|
_ = parent.AppendMessage(provider.Message{
|
||||||
|
Role: provider.RoleUser,
|
||||||
|
Content: []provider.Content{provider.TextBlock{Text: text}},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ = parent.AppendCompaction([]provider.Message{
|
||||||
|
{Role: provider.RoleAssistant, Content: []provider.Content{provider.TextBlock{Text: "summary"}}},
|
||||||
|
{Role: provider.RoleUser, Content: []provider.Content{provider.TextBlock{Text: "tail-c"}}},
|
||||||
|
{Role: provider.RoleAssistant, Content: []provider.Content{provider.TextBlock{Text: "tail-d"}}},
|
||||||
|
})
|
||||||
|
_ = parent.AppendMessage(provider.Message{
|
||||||
|
Role: provider.RoleUser,
|
||||||
|
Content: []provider.Content{provider.TextBlock{Text: "after-compact"}},
|
||||||
|
})
|
||||||
|
_ = parent.Close()
|
||||||
|
|
||||||
|
opened, msgs, err := OpenSession(parent.Path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("OpenSession parent: %v", err)
|
||||||
|
}
|
||||||
|
_ = opened.Close()
|
||||||
|
assertMessageTexts(t, msgs, []string{"summary", "tail-c", "tail-d", "after-compact"})
|
||||||
|
|
||||||
|
branchPath, err := BranchSession(parent.Path, root, cwd, "0.0.0-test", 4)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BranchSession: %v", err)
|
||||||
|
}
|
||||||
|
branch, branchMsgs, err := OpenSession(branchPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("OpenSession branch: %v", err)
|
||||||
|
}
|
||||||
|
defer branch.Close()
|
||||||
|
|
||||||
|
assertMessageTexts(t, branchMsgs, []string{"summary", "tail-c", "tail-d", "after-compact"})
|
||||||
|
if branch.Meta.Parent != parent.Meta.ID {
|
||||||
|
t.Errorf("parent id: want %q, got %q", parent.Meta.ID, branch.Meta.Parent)
|
||||||
|
}
|
||||||
|
if branch.Meta.ForkPoint != 4 {
|
||||||
|
t.Errorf("fork_point: want 4, got %d", branch.Meta.ForkPoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
shortBranchPath, err := BranchSession(parent.Path, root, cwd, "0.0.0-test", 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BranchSession short fork: %v", err)
|
||||||
|
}
|
||||||
|
shortBranch, shortBranchMsgs, err := OpenSession(shortBranchPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("OpenSession short branch: %v", err)
|
||||||
|
}
|
||||||
|
defer shortBranch.Close()
|
||||||
|
|
||||||
|
assertMessageTexts(t, shortBranchMsgs, []string{"summary", "tail-c"})
|
||||||
|
if shortBranch.Meta.ForkPoint != 2 {
|
||||||
|
t.Errorf("short branch fork_point: want 2, got %d", shortBranch.Meta.ForkPoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestBuildSessionTree verifies parent/child edges are rebuilt
|
// TestBuildSessionTree verifies parent/child edges are rebuilt
|
||||||
// from meta + sibling-scan.
|
// from meta + sibling-scan.
|
||||||
func TestBuildSessionTree(t *testing.T) {
|
func TestBuildSessionTree(t *testing.T) {
|
||||||
|
|
@ -266,3 +331,15 @@ func TestBuildSessionTree(t *testing.T) {
|
||||||
t.Errorf("want 2 children, got %d", len(rootNode.Children))
|
t.Errorf("want 2 children, got %d", len(rootNode.Children))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func assertMessageTexts(t *testing.T, msgs []provider.Message, want []string) {
|
||||||
|
t.Helper()
|
||||||
|
if len(msgs) != len(want) {
|
||||||
|
t.Fatalf("message count: want %d, got %d", len(want), len(msgs))
|
||||||
|
}
|
||||||
|
for i, msg := range msgs {
|
||||||
|
if got := extractText(msg); got != want[i] {
|
||||||
|
t.Errorf("message %d: want %q, got %q", i, want[i], got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue