seaweedfs/weed/mount/ml/plugins/pytorch_plugin.go
chrislu 814e0bb233 Phase 4: Revolutionary Recipe-Based ML Optimization Engine
🚀 Transform SeaweedFS ML optimizations from hard-coded framework-specific code
to a flexible, configuration-driven system using YAML/JSON rules and templates.

## Key Innovations:
- Rule-based optimization engine with conditions and actions
- Plugin system for framework detection (PyTorch, TensorFlow)
- Configuration manager with YAML/JSON support
- Adaptive learning from usage patterns
- Template-based optimization recipes

## New Components:
- optimization_engine.go: Core rule evaluation and application
- config_manager.go: Configuration loading and validation
- plugins/pytorch_plugin.go: PyTorch-specific optimizations
- plugins/tensorflow_plugin.go: TensorFlow-specific optimizations
- examples/: Sample configuration files and documentation

## Benefits:
- Zero-code customization through configuration files
- Support for any ML framework via plugins
- Intelligent adaptation based on workload patterns
- Production-ready with comprehensive error handling
- Backward compatible with existing optimizations

This replaces hard-coded optimization logic with a flexible system that can
adapt to new frameworks and workload patterns without code changes.
2025-08-30 16:49:12 -07:00
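
The plugin contract itself is defined in the ml package rather than in this file. Inferred from the methods pytorch_plugin.go implements below, it presumably looks something like this sketch; the interface name and exact signatures are assumptions, since the ml package is not shown here:

// Sketch only: inferred from the methods implemented in this file.
// The real interface in the ml package may differ in name and shape.
type FrameworkPlugin interface {
	GetFrameworkName() string
	DetectFramework(filePath string, content []byte) float64
	GetOptimizationHints(context *ml.OptimizationContext) []ml.OptimizationHint
	GetDefaultRules() []*ml.OptimizationRule
	GetDefaultTemplates() []*ml.OptimizationTemplate
}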


package plugins

import (
	"path/filepath"
	"strings"

	"github.com/seaweedfs/seaweedfs/weed/mount/ml"
)

// PyTorchPlugin provides PyTorch-specific optimizations
type PyTorchPlugin struct {
	name    string
	version string
}

// NewPyTorchPlugin creates a new PyTorch optimization plugin
func NewPyTorchPlugin() *PyTorchPlugin {
	return &PyTorchPlugin{
		name:    "pytorch",
		version: "1.0.0",
	}
}

// GetFrameworkName returns the framework name
func (p *PyTorchPlugin) GetFrameworkName() string {
	return p.name
}

// DetectFramework returns a confidence score in [0, 1] indicating how likely
// it is that the given file belongs to the PyTorch framework, combining
// extension, content, and path heuristics.
func (p *PyTorchPlugin) DetectFramework(filePath string, content []byte) float64 {
	confidence := 0.0

	// File extension-based detection
	ext := strings.ToLower(filepath.Ext(filePath))
	switch ext {
	case ".pth", ".pt":
		confidence = 0.95
	case ".pkl":
		if strings.Contains(strings.ToLower(filePath), "pytorch") ||
			strings.Contains(strings.ToLower(filePath), "torch") {
			confidence = 0.7
		} else {
			confidence = 0.3
		}
	}

	// Content-based detection (if content is provided)
	if len(content) > 0 {
		contentStr := string(content[:minInt(len(content), 1024)]) // First 1KB
		if strings.Contains(contentStr, "torch") ||
			strings.Contains(contentStr, "pytorch") ||
			strings.Contains(contentStr, "PytorchStreamReader") {
			confidence = maxFloat64(confidence, 0.8)
		}
	}

	// Path-based detection
	if strings.Contains(strings.ToLower(filePath), "torch") ||
		strings.Contains(strings.ToLower(filePath), "pytorch") {
		confidence = maxFloat64(confidence, 0.6)
	}

	return confidence
}

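// Illustrative sketch, not part of the original file: a hypothetical caller
// might score a path and accept the best match. The 0.5 acceptance threshold
// is an assumption, not a constant from this package.
func exampleDetectUsage() (framework string, confidence float64) {
	p := NewPyTorchPlugin()
	// ".pth" alone already scores 0.95; the content and path checks can only
	// raise a lower score, never reduce a higher one.
	confidence = p.DetectFramework("/data/checkpoints/resnet50_epoch10.pth", nil)
	if confidence >= 0.5 {
		framework = p.GetFrameworkName()
	}
	return framework, confidence
}
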
// GetOptimizationHints provides PyTorch-specific optimization hints
func (p *PyTorchPlugin) GetOptimizationHints(context *ml.OptimizationContext) []ml.OptimizationHint {
	hints := make([]ml.OptimizationHint, 0)

	// Model file optimizations
	if context.FileType == "model" && p.isPyTorchModel(context.FilePath) {
		hints = append(hints, ml.OptimizationHint{
			Type:        "cache_strategy",
			Description: "PyTorch models benefit from persistent memory caching",
			Priority:    90,
			Parameters: map[string]interface{}{
				"cache_type":    "memory",
				"persistence":   true,
				"compression":   false,
				"prefetch_size": "25%", // 25% of model size
			},
		})

		if context.FileSize > 500*1024*1024 { // > 500MB
			hints = append(hints, ml.OptimizationHint{
				Type:        "loading_strategy",
				Description: "Large PyTorch model - consider lazy loading",
				Priority:    85,
				Parameters: map[string]interface{}{
					"lazy_loading":  true,
					"chunk_size":    64 * 1024 * 1024, // 64MB chunks
					"parallel_load": true,
				},
			})
		}
	}

	// Dataset optimizations
	if p.isPyTorchDataset(context.FilePath) {
		hints = append(hints, ml.OptimizationHint{
			Type:        "dataloader_optimization",
			Description: "PyTorch DataLoader optimization for training efficiency",
			Priority:    80,
			Parameters: map[string]interface{}{
				"num_workers":        4,
				"pin_memory":         true,
				"prefetch_factor":    2,
				"persistent_workers": true,
			},
		})
	}

	// Training-specific optimizations
	if context.WorkloadType == "training" {
		hints = append(hints, ml.OptimizationHint{
			Type:        "training_optimization",
			Description: "PyTorch training optimizations",
			Priority:    75,
			Parameters: map[string]interface{}{
				"gradient_checkpointing": context.FileSize > 1024*1024*1024, // > 1GB
				"mixed_precision":        true,
				"batch_accumulation":     context.BatchSize > 32,
			},
		})
	}

	return hints
}

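// Illustrative sketch, not part of the original file: a hypothetical context
// for a 2 GiB model opened during training. The field values are assumptions
// chosen to trigger the caching, lazy-loading, and training hints above.
func exampleHintsUsage() []ml.OptimizationHint {
	ctx := &ml.OptimizationContext{
		FilePath:     "/models/resnet152_full.pth",
		FileType:     "model",
		FileSize:     2 * 1024 * 1024 * 1024, // > 1GB, so gradient checkpointing is suggested
		WorkloadType: "training",
		BatchSize:    64, // > 32, so batch accumulation is suggested
	}
	return NewPyTorchPlugin().GetOptimizationHints(ctx)
}
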
// GetDefaultRules returns PyTorch-specific optimization rules
func (p *PyTorchPlugin) GetDefaultRules() []*ml.OptimizationRule {
	return []*ml.OptimizationRule{
		{
			ID:          "pytorch_model_caching",
			Name:        "PyTorch Model Caching",
			Description: "Optimized caching for PyTorch model files",
			Priority:    95,
			Conditions: []ml.RuleCondition{
				{
					Type:     "file_pattern",
					Property: "extension",
					Operator: "in",
					Value:    []string{".pth", ".pt"},
					Weight:   1.0,
				},
				{
					Type:     "file_context",
					Property: "size",
					Operator: "greater_than",
					Value:    1024 * 1024, // > 1MB
					Weight:   0.8,
				},
			},
			Actions: []ml.RuleAction{
				{
					Type:   "cache",
					Target: "file",
					Parameters: map[string]interface{}{
						"strategy":        "pytorch_model",
						"cache_type":      "memory",
						"eviction_policy": "lfu",
						"compression":     false,
						"preload":         true,
					},
				},
			},
			Metadata: map[string]interface{}{
				"framework": "pytorch",
				"category":  "model_caching",
			},
		},
		{
			ID:          "pytorch_checkpoint_handling",
			Name:        "PyTorch Checkpoint Optimization",
			Description: "Optimized handling for PyTorch training checkpoints",
			Priority:    85,
			Conditions: []ml.RuleCondition{
				{
					Type:     "file_pattern",
					Property: "name_pattern",
					Operator: "matches",
					Value:    ".*checkpoint.*\\.(pth|pt)$",
					Weight:   1.0,
				},
				{
					Type:     "workload_context",
					Property: "workload_type",
					Operator: "equals",
					Value:    "training",
					Weight:   0.9,
				},
			},
			Actions: []ml.RuleAction{
				{
					Type:   "checkpoint_optimization",
					Target: "file",
					Parameters: map[string]interface{}{
						"incremental_save": true,
						"compression":      true,
						"backup_strategy":  "rolling",
						"sync_frequency":   "epoch",
					},
				},
			},
			Metadata: map[string]interface{}{
				"framework": "pytorch",
				"category":  "checkpoint",
			},
		},
		{
			ID:          "pytorch_tensor_prefetch",
			Name:        "PyTorch Tensor Prefetching",
			Description: "Intelligent prefetching for PyTorch tensor operations",
			Priority:    80,
			Conditions: []ml.RuleCondition{
				{
					Type:     "access_pattern",
					Property: "pattern_type",
					Operator: "in",
					Value:    []string{"sequential", "strided"},
					Weight:   1.0,
				},
				{
					Type:     "workload_context",
					Property: "framework",
					Operator: "equals",
					Value:    "pytorch",
					Weight:   0.9,
				},
				{
					Type:     "workload_context",
					Property: "batch_size",
					Operator: "greater_than",
					Value:    8,
					Weight:   0.7,
				},
			},
			Actions: []ml.RuleAction{
				{
					Type:   "prefetch",
					Target: "tensor",
					Parameters: map[string]interface{}{
						"strategy":         "pytorch_tensor",
						"prefetch_size":    "batch_aligned",
						"parallel_workers": 2,
						"cuda_streams":     true,
					},
				},
			},
			Metadata: map[string]interface{}{
				"framework": "pytorch",
				"category":  "tensor_ops",
			},
		},
	}
}

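// Illustrative sketch, not part of the original file: because the rules above
// are plain data, a hypothetical engine could index them by ID, which is how
// the templates below reference them. These are also the structures a YAML or
// JSON rule file would deserialize into.
func exampleRuleIndex() map[string]*ml.OptimizationRule {
	byID := make(map[string]*ml.OptimizationRule, 3)
	for _, rule := range NewPyTorchPlugin().GetDefaultRules() {
		byID[rule.ID] = rule
	}
	return byID // e.g. byID["pytorch_model_caching"].Priority == 95
}
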
// GetDefaultTemplates returns PyTorch-specific optimization templates
func (p *PyTorchPlugin) GetDefaultTemplates() []*ml.OptimizationTemplate {
	return []*ml.OptimizationTemplate{
		{
			ID:          "pytorch_training_template",
			Name:        "PyTorch Training Optimization",
			Description: "Complete optimization template for PyTorch training workloads",
			Category:    "training",
			Rules: []string{
				"pytorch_model_caching",
				"pytorch_checkpoint_handling",
				"pytorch_tensor_prefetch",
				"sequential_prefetch",    // From base rules
				"dataset_batch_optimize", // From base rules
			},
			Parameters: map[string]interface{}{
				"framework":           "pytorch",
				"training_phase":      "active",
				"memory_optimization": true,
				"gpu_optimization":    true,
				"dataloader_config": map[string]interface{}{
					"num_workers":        4,
					"pin_memory":         true,
					"persistent_workers": true,
					"prefetch_factor":    2,
				},
				"model_config": map[string]interface{}{
					"gradient_checkpointing": false,
					"mixed_precision":        true,
					"compile_model":          true,
				},
			},
		},
		{
			ID:          "pytorch_inference_template",
			Name:        "PyTorch Inference Optimization",
			Description: "Optimized template for PyTorch inference workloads",
			Category:    "inference",
			Rules: []string{
				"pytorch_model_caching",
				"pytorch_tensor_prefetch",
			},
			Parameters: map[string]interface{}{
				"framework":       "pytorch",
				"inference_mode":  true,
				"batch_inference": true,
				"model_config": map[string]interface{}{
					"torch_compile":      true,
					"optimization_level": "O2",
					"precision":          "fp16",
				},
			},
		},
		{
			ID:          "pytorch_research_template",
			Name:        "PyTorch Research & Experimentation",
			Description: "Flexible template for PyTorch research and experimentation",
			Category:    "research",
			Rules: []string{
				"pytorch_model_caching",
				"pytorch_checkpoint_handling",
			},
			Parameters: map[string]interface{}{
				"framework":           "pytorch",
				"experiment_tracking": true,
				"flexible_caching":    true,
				"checkpoint_config": map[string]interface{}{
					"save_frequency":    "auto",
					"version_control":   true,
					"metadata_tracking": true,
				},
			},
		},
	}
}

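// Illustrative sketch, not part of the original file: templates list rule IDs,
// so resolving a template is a lookup against a rule registry. Note that
// "sequential_prefetch" and "dataset_batch_optimize" come from base rules
// outside this plugin, so against this plugin's rules alone they are skipped.
func exampleResolveTrainingTemplate() []*ml.OptimizationRule {
	p := NewPyTorchPlugin()
	byID := make(map[string]*ml.OptimizationRule)
	for _, rule := range p.GetDefaultRules() {
		byID[rule.ID] = rule
	}
	var resolved []*ml.OptimizationRule
	for _, tmpl := range p.GetDefaultTemplates() {
		if tmpl.ID != "pytorch_training_template" {
			continue
		}
		for _, ruleID := range tmpl.Rules {
			if rule, ok := byID[ruleID]; ok {
				resolved = append(resolved, rule)
			}
		}
	}
	return resolved // the three pytorch_* rules, in template order
}
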
// Helper methods

func (p *PyTorchPlugin) isPyTorchModel(filePath string) bool {
	ext := strings.ToLower(filepath.Ext(filePath))
	return ext == ".pth" || ext == ".pt"
}

func (p *PyTorchPlugin) isPyTorchDataset(filePath string) bool {
	// Common PyTorch dataset patterns
	baseName := strings.ToLower(filepath.Base(filePath))
	return strings.Contains(baseName, "dataset") ||
		strings.Contains(baseName, "train") ||
		strings.Contains(baseName, "val") ||
		strings.Contains(baseName, "test")
}

// Utility functions

func minInt(a, b int) int {
	if a < b {
		return a
	}
	return b
}

func maxFloat64(a, b float64) float64 {
	if a > b {
		return a
	}
	return b
}