diff --git a/packages/core/session_portable.go b/packages/core/session_portable.go index afbcaee..3332140 100644 --- a/packages/core/session_portable.go +++ b/packages/core/session_portable.go @@ -13,6 +13,7 @@ import ( "time" "github.com/google/uuid" + "github.com/patriceckhart/zot/packages/provider" ) // PortableExt is the filesystem extension used for exported sessions. @@ -321,43 +322,76 @@ func BranchSession(parentPath, root, cwd, version string, upToMessageIdx int) (s return "", err } - // Copy message rows up to the cut point, plus all usage rows - // that land before the cut (they describe the cost of those - // messages). Rewind and use the large-row-safe JSONL reader. + // Reconstruct the effective transcript the same way OpenSession + // does: message rows append, and compaction rows replace everything + // before them. The fork index is defined over that effective stream, + // not over the raw audit rows kept on disk before a compaction. if _, err := src.Seek(0, io.SeekStart); err != nil { return "", fmt.Errorf("branch: rewind parent: %w", err) } - msgCount := 0 + var effective []provider.Message + var nonCompactedRows [][]byte + effectiveCount := 0 + sawCompaction := false if err := forEachJSONLLine(src, func(line []byte) error { - if msgCount >= upToMessageIdx { - return io.EOF - } var h sessionLineHead if err := json.Unmarshal(line, &h); err != nil { return nil } switch h.Type { case "message": - if _, err := bw.Write(line); err != nil { - return err + if msg, err := hydrateMessage(line); err == nil && len(msg.Content) > 0 { + effective = append(effective, msg) + if !sawCompaction && effectiveCount < upToMessageIdx { + raw := append([]byte(nil), line...) + nonCompactedRows = append(nonCompactedRows, raw) + } + effectiveCount++ } - if err := bw.WriteByte('\n'); err != nil { - return err + case "compaction": + if compacted, err := hydrateCompaction(line); err == nil { + effective = compacted + effectiveCount = len(effective) + sawCompaction = true } - msgCount++ case "usage": - if _, err := bw.Write(line); err != nil { - return err + if !sawCompaction && effectiveCount < upToMessageIdx { + raw := append([]byte(nil), line...) + nonCompactedRows = append(nonCompactedRows, raw) } - if err := bw.WriteByte('\n'); err != nil { - return err - } - // don't increment msgCount for usage rows } return nil }); err != nil && err != io.EOF { return "", fmt.Errorf("branch: read parent: %w", err) } + if sawCompaction { + limit := upToMessageIdx + if limit > len(effective) { + limit = len(effective) + } + for i := 0; i < limit; i++ { + msg := effective[i] + line, err := json.Marshal(sessionLine{Type: "message", Message: &msg}) + if err != nil { + return "", fmt.Errorf("branch: marshal message: %w", err) + } + if _, err := bw.Write(line); err != nil { + return "", err + } + if err := bw.WriteByte('\n'); err != nil { + return "", err + } + } + } else { + for _, row := range nonCompactedRows { + if _, err := bw.Write(row); err != nil { + return "", err + } + if err := bw.WriteByte('\n'); err != nil { + return "", err + } + } + } if err := bw.Flush(); err != nil { return "", err } diff --git a/packages/core/session_portable_test.go b/packages/core/session_portable_test.go index cfb36ae..ebc89dc 100644 --- a/packages/core/session_portable_test.go +++ b/packages/core/session_portable_test.go @@ -231,6 +231,71 @@ func TestBranchSessionCopiesPrefix(t *testing.T) { } } +func TestBranchSessionUsesEffectiveTranscriptAfterCompaction(t *testing.T) { + root := t.TempDir() + cwd := "/project" + parent, err := NewSession(root, cwd, "anthropic", "claude-opus-4-7", "0.0.0-test") + if err != nil { + t.Fatal(err) + } + for _, text := range []string{"raw-a", "raw-b", "raw-c", "raw-d"} { + _ = parent.AppendMessage(provider.Message{ + Role: provider.RoleUser, + Content: []provider.Content{provider.TextBlock{Text: text}}, + }) + } + _ = parent.AppendCompaction([]provider.Message{ + {Role: provider.RoleAssistant, Content: []provider.Content{provider.TextBlock{Text: "summary"}}}, + {Role: provider.RoleUser, Content: []provider.Content{provider.TextBlock{Text: "tail-c"}}}, + {Role: provider.RoleAssistant, Content: []provider.Content{provider.TextBlock{Text: "tail-d"}}}, + }) + _ = parent.AppendMessage(provider.Message{ + Role: provider.RoleUser, + Content: []provider.Content{provider.TextBlock{Text: "after-compact"}}, + }) + _ = parent.Close() + + opened, msgs, err := OpenSession(parent.Path) + if err != nil { + t.Fatalf("OpenSession parent: %v", err) + } + _ = opened.Close() + assertMessageTexts(t, msgs, []string{"summary", "tail-c", "tail-d", "after-compact"}) + + branchPath, err := BranchSession(parent.Path, root, cwd, "0.0.0-test", 4) + if err != nil { + t.Fatalf("BranchSession: %v", err) + } + branch, branchMsgs, err := OpenSession(branchPath) + if err != nil { + t.Fatalf("OpenSession branch: %v", err) + } + defer branch.Close() + + assertMessageTexts(t, branchMsgs, []string{"summary", "tail-c", "tail-d", "after-compact"}) + if branch.Meta.Parent != parent.Meta.ID { + t.Errorf("parent id: want %q, got %q", parent.Meta.ID, branch.Meta.Parent) + } + if branch.Meta.ForkPoint != 4 { + t.Errorf("fork_point: want 4, got %d", branch.Meta.ForkPoint) + } + + shortBranchPath, err := BranchSession(parent.Path, root, cwd, "0.0.0-test", 2) + if err != nil { + t.Fatalf("BranchSession short fork: %v", err) + } + shortBranch, shortBranchMsgs, err := OpenSession(shortBranchPath) + if err != nil { + t.Fatalf("OpenSession short branch: %v", err) + } + defer shortBranch.Close() + + assertMessageTexts(t, shortBranchMsgs, []string{"summary", "tail-c"}) + if shortBranch.Meta.ForkPoint != 2 { + t.Errorf("short branch fork_point: want 2, got %d", shortBranch.Meta.ForkPoint) + } +} + // TestBuildSessionTree verifies parent/child edges are rebuilt // from meta + sibling-scan. func TestBuildSessionTree(t *testing.T) { @@ -266,3 +331,15 @@ func TestBuildSessionTree(t *testing.T) { t.Errorf("want 2 children, got %d", len(rootNode.Children)) } } + +func assertMessageTexts(t *testing.T, msgs []provider.Message, want []string) { + t.Helper() + if len(msgs) != len(want) { + t.Fatalf("message count: want %d, got %d", len(want), len(msgs)) + } + for i, msg := range msgs { + if got := extractText(msg); got != want[i] { + t.Errorf("message %d: want %q, got %q", i, want[i], got) + } + } +}