diff --git a/main.go b/main.go index a4e28ef..5810932 100644 --- a/main.go +++ b/main.go @@ -835,7 +835,18 @@ func handleStaleLinks(stales []string) error { reader := bufio.NewReader(os.Stdin) for _, path := range stales { fmt.Printf("stale: %s\n", path) - fmt.Print("action [p=prune, u=unlink, i=ignore]: ") + + canUnlink, err := staleHasRepoFile(path, repo) + if err != nil { + return err + } + + prompt := "action [p=prune, u=unlink, i=ignore]: " + if !canUnlink { + prompt = "action [p=prune, i=ignore]: " + } + + fmt.Print(prompt) choice, err := reader.ReadString('\n') if err != nil { return err @@ -851,7 +862,7 @@ func handleStaleLinks(stales []string) error { fmt.Printf("removed %s\n", path) continue } - if choice == "u" { + if choice == "u" && canUnlink { if err := unlinkStale(path, repo); err != nil { return err } @@ -863,56 +874,30 @@ func handleStaleLinks(stales []string) error { return nil } -func unlinkStale(targetPath, repo string) error { - info, err := os.Lstat(targetPath) +func staleHasRepoFile(targetPath, repo string) (bool, error) { + repoPath, err := repoPathForTarget(targetPath, repo) if err != nil { - return err + return false, err } - if info.Mode()&os.ModeSymlink == 0 { - return nil + if repoPath == "" { + return false, nil } - - src, err := os.Readlink(targetPath) - if err != nil { - return err - } - if !filepath.IsAbs(src) { - src = filepath.Join(filepath.Dir(targetPath), src) - } - src = filepath.Clean(src) - - repoAbs, err := filepath.Abs(repo) - if err != nil { - return err - } - - rel, err := filepath.Rel(repoAbs, src) - if err != nil { - return err - } - if strings.HasPrefix(rel, "..") { - return nil - } - - parts := strings.Split(rel, string(os.PathSeparator)) - if len(parts) < 3 { - return nil - } - - if parts[1] != filesDirName { - return nil - } - - repoPath := filepath.Join(repoAbs, rel) if _, err := os.Stat(repoPath); errors.Is(err, os.ErrNotExist) { - if err := os.Remove(targetPath); err != nil { - return err - } - fmt.Printf("removed %s (missing repo file)\n", targetPath) - return nil + return false, nil } else if err != nil { + return false, err + } + return true, nil +} + +func unlinkStale(targetPath, repo string) error { + repoPath, err := repoPathForTarget(targetPath, repo) + if err != nil { return err } + if repoPath == "" { + return nil + } if err := os.Remove(targetPath); err != nil { return err @@ -930,6 +915,48 @@ func unlinkStale(targetPath, repo string) error { return nil } +func repoPathForTarget(targetPath, repo string) (string, error) { + info, err := os.Lstat(targetPath) + if err != nil { + return "", err + } + if info.Mode()&os.ModeSymlink == 0 { + return "", nil + } + + src, err := os.Readlink(targetPath) + if err != nil { + return "", err + } + if !filepath.IsAbs(src) { + src = filepath.Join(filepath.Dir(targetPath), src) + } + src = filepath.Clean(src) + + repoAbs, err := filepath.Abs(repo) + if err != nil { + return "", err + } + + rel, err := filepath.Rel(repoAbs, src) + if err != nil { + return "", err + } + if strings.HasPrefix(rel, "..") { + return "", nil + } + + parts := strings.Split(rel, string(os.PathSeparator)) + if len(parts) < 3 { + return "", nil + } + if parts[1] != filesDirName { + return "", nil + } + + return filepath.Join(repoAbs, rel), nil +} + func restorePackage(filesDir, targetRoot string, dryRun bool) error { filesAbs, err := filepath.Abs(filesDir) if err != nil {